diff --git a/src/cv2.rs b/src/cv2.rs index f685620..15aa3d9 100644 --- a/src/cv2.rs +++ b/src/cv2.rs @@ -1,4 +1,17 @@ -use tract_onnx::prelude::tract_ndarray::{Array2, ArrayView3}; +use image::{ImageBuffer, Luma}; +use tract_onnx::prelude::tract_ndarray::{azip, Array2, Array3, ArrayView2, ArrayView3}; + +/// 1. 计算两个数组的绝对差值 (对应 cv2.absdiff) +pub fn abs_diff(a: &ArrayView3, b: &ArrayView3) -> Array3 { + // 利用 ndarray 的 map_collect,生成差值的绝对值数组 + // 或者直接使用 zip_mut_with 处理以减少内存分配 + let mut diff = Array3::zeros(a.dim()); + azip!((res in &mut diff, &va in a, &vb in b) { + *res = (va as i16 - vb as i16).abs() as u8; + }); + diff +} + /// RGB 到灰度转换 pub fn rgb_to_gray(rgb: ArrayView3) -> Array2 { @@ -11,3 +24,34 @@ pub fn rgb_to_gray(rgb: ArrayView3) -> Array2 { (0.299 * r + 0.587 * g + 0.114 * b) as u8 }) } + +/// 寻找匹配结果图中的最大值及其坐标 (模拟 cv2.minMaxLoc 的一部分) +pub fn min_max_loc(result_map: &ImageBuffer, Vec>) -> (f32, (u32, u32)) { + // 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc) + let mut max_val: f32 = -1.0; + let mut max_loc = (0, 0); + + // 遍历匹配得分图 + for (x, y, score) in result_map.enumerate_pixels() { + let s = score.0[0]; + + // 可以在此处加入你之前验证过的起始位过滤 + // if x < 15 { continue; } + + if s > max_val { + max_val = s; + max_loc = (x, y); + } + } + (max_val, max_loc) +} +pub fn ndarray_to_luma8(array: ArrayView2) -> ImageBuffer, Vec> { + let (height, width) = array.dim(); + let mut buffer = ImageBuffer::new(width as u32, height as u32); + for y in 0..height { + for x in 0..width { + buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]])); + } + } + buffer +} \ No newline at end of file diff --git a/src/image_io.rs b/src/image_io.rs index 23d6017..178771a 100644 --- a/src/image_io.rs +++ b/src/image_io.rs @@ -1,9 +1,15 @@ -use anyhow::{Context, Result, anyhow}; +use anyhow::{Context, Result, anyhow, bail}; use base64::{Engine as _, engine::general_purpose}; use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba}; use std::fs; use std::path::{Path, PathBuf}; use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD}; +#[derive(Debug)] +pub enum ColorMode { + RGB, + RGBA, + L, +} /// 定义支持的输入类型枚举 pub enum ImageInput { Bytes(Vec), @@ -60,41 +66,7 @@ pub fn get_img_base64>(image_path: P) -> Result { Ok(b64_string) } -/// 处理 PNG 图像的 RGBA 透明背景,将透明部分设置为白色 -/// -/// 对应 Python 版 png_rgba_black_preprocess -pub fn png_rgba_black_preprocess(img: &DynamicImage) -> Result { - // 1. 获取原图尺寸 - let (width, height) = (img.width(), img.height()); - // 2. 创建一个等尺寸的纯白色 RGB 图像作为底色 - // ImageBuffer::, Vec> - let mut white_bg = ImageBuffer::from_fn(width, height, |_, _| { - Rgb([255, 255, 255]) - }); - - // 3. 将原图复合到底色上 - // 我们需要处理原图,将其转为 RGBA 确保有 alpha 通道可以参考 - let rgba_img = img.to_rgba8(); - - // 遍历每一个像素进行复合(模拟 Python 的 paste 逻辑) - for (x, y, pixel) in rgba_img.enumerate_pixels() { - let alpha = pixel[3] as f32 / 255.0; - if alpha > 0.0 { - // 获取底色像素(白色) - let bg_pixel = white_bg.get_pixel_mut(x, y); - - // 简单的 Alpha 复合公式:输出 = 源 * alpha + 背景 * (1 - alpha) - for i in 0..3 { - let fg = pixel[i] as f32; - let bg = bg_pixel[i] as f32; - bg_pixel[i] = (fg * alpha + bg * (1.0 - alpha)) as u8; - } - } - } - - Ok(DynamicImage::ImageRgb8(white_bg)) -} /// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image fn numpy_to_pil_image(array: ArrayViewD) -> Result { let shape = array.shape(); @@ -143,11 +115,11 @@ fn numpy_to_pil_image(array: ArrayViewD) -> Result { /// 对应 Python 的 png_rgba_black_preprocess /// 将带有透明通道的图片转换为白色背景的 RGB 图片 -#[allow(dead_code)] + pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { // 1. 检查是否包含透明通道,如果没有,直接克隆并返回 if !img.color().has_alpha() { - return img.clone(); + return DynamicImage::ImageRgb8(img.to_rgb8()); } let (width, height) = img.dimensions(); @@ -160,83 +132,87 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { // 4. 遍历像素并手动进行 Alpha 混合 // 对应 Python 的 image.paste(img, ..., mask=img) - for (x, y, pixel) in rgba_img.enumerate_pixels() { - let alpha = pixel[3] as f32 / 255.0; + // 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销 + for (x, y, bg_pixel) in background.enumerate_pixels_mut() { + // 安全性说明:x, y 源自 background 尺寸,与 rgba_img 一致,get_pixel 是安全的 + let src_pixel = rgba_img.get_pixel(x, y); + let alpha_u8 = src_pixel[3]; - if alpha >= 1.0 { - // 完全不透明,直接覆盖 - background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]])); - } else if alpha > 0.0 { - // 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha)) - let bg_pixel = background.get_pixel(x, y); - let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8; - let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8; - let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8; - background.put_pixel(x, y, Rgb([r, g, b])); + match alpha_u8 { + // 情况 A:完全不透明,直接覆盖背景色 + 255 => { + bg_pixel.0 = [src_pixel[0], src_pixel[1], src_pixel[2]]; + } + // 情况 B:完全透明,保持背景色(白色),无需操作 + 0 => { + continue; + } + // 情况 C:半透明,进行 Alpha 混合计算 + _ => { + let alpha = alpha_u8 as f32 / 255.0; + let inv_alpha = 1.0 - alpha; + + bg_pixel[0] = (src_pixel[0] as f32 * alpha + 255.0 * inv_alpha).round() as u8; + bg_pixel[1] = (src_pixel[1] as f32 * alpha + 255.0 * inv_alpha).round() as u8; + bg_pixel[2] = (src_pixel[2] as f32 * alpha + 255.0 * inv_alpha).round() as u8; + } } - // alpha == 0 的情况不需要处理,因为背景已经是白色了 } DynamicImage::ImageRgb8(background) } -pub fn image_to_numpy(image: &DynamicImage, target_mode: &str) -> Result> { - // 1. 模式转换 (对应 image.convert(target_mode)) +pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result> { + // 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray // Rust image 库通过 to_rgb8, to_luma8 等方法实现转换 let (width, height) = image.dimensions(); - match target_mode { - "RGB" => { - let rgb_img = image.to_rgb8(); - let raw = rgb_img.into_raw(); - // shape 为 [Height, Width, Channels] -> [H, W, 3] - Array3::from_shape_vec((height as usize, width as usize, 3), raw) - .map_err(|e| anyhow!("Failed to build ndarray: {}", e)) - }, - "L" | "GRAY" => { - let gray_img = image.to_luma8(); - let raw = gray_img.into_raw(); - // shape 为 [H, W, 1] - Array3::from_shape_vec((height as usize, width as usize, 1), raw) - .map_err(|e| anyhow!("Failed to build ndarray: {}", e)) - }, - "RGBA" => { - let rgba_img = image.to_rgba8(); - let raw = rgba_img.into_raw(); - // shape 为 [H, W, 4] - Array3::from_shape_vec((height as usize, width as usize, 4), raw) - .map_err(|e| anyhow!("Failed to build ndarray: {}", e)) - }, - _ => Err(anyhow!("Unsupported target_mode: {}", target_mode)), - } + let (channels, raw) = match mode { + ColorMode::RGB => (3, image.to_rgb8().into_raw()), + ColorMode::L => (1, image.to_luma8().into_raw()), + ColorMode::RGBA => (4, image.to_rgba8().into_raw()), + }; + + Array3::from_shape_vec((height as usize, width as usize, channels), raw) + .map_err(|e| anyhow!("Failed to build ndarray: {}", e)) } -pub fn numpy_to_image(array: ArrayViewD, mode: &str) -> Result { +pub fn numpy_to_image(array: ArrayViewD, mode: ColorMode) -> Result { let shape = array.shape(); + // 1. 基础维度检查 (必须是 H, W, C 三维数组) + if shape.len() != 3 { + bail!("Expected a 3D array (H, W, C), but got {}D", shape.len()); + } + + let height = shape[0] as u32; + let width = shape[1] as u32; + let channels = shape[2]; + // 2. 检查通道数是否与模式匹配 + let expected_channels = match mode { + ColorMode::L => 1, + ColorMode::RGB => 3, + ColorMode::RGBA => 4, + }; + if channels != expected_channels { + bail!( + "Mode {:?} expects {} channels, but array has {}", + mode, + expected_channels, + channels + ); + } // 确保数据连续性 (C-order) let standard = array.as_standard_layout(); let (raw_data, _) = standard.to_owned().into_raw_vec_and_offset(); - let height = shape[0] as u32; - let width = shape[1] as u32; - match mode { - "L" => { - ImageBuffer::, _>::from_raw(width, height, raw_data) - .map(DynamicImage::ImageLuma8) - .ok_or_else(|| anyhow!("Failed to create Luma image")) - }, - "RGB" => { - ImageBuffer::, _>::from_raw(width, height, raw_data) - .map(DynamicImage::ImageRgb8) - .ok_or_else(|| anyhow!("Failed to create RGB image")) - }, - "RGBA" => { - ImageBuffer::, _>::from_raw(width, height, raw_data) - .map(DynamicImage::ImageRgba8) - .ok_or_else(|| anyhow!("Failed to create RGBA image")) - }, - _ => Err(anyhow!("Unsupported mode: {}", mode)), + ColorMode::L => ImageBuffer::, _>::from_raw(width, height, raw_data) + .map(DynamicImage::ImageLuma8), + ColorMode::RGB => ImageBuffer::, _>::from_raw(width, height, raw_data) + .map(DynamicImage::ImageRgb8), + ColorMode::RGBA => ImageBuffer::, _>::from_raw(width, height, raw_data) + .map(DynamicImage::ImageRgba8), } + .ok_or_else(|| anyhow!("Failed to construct ImageBuffer. Buffer size might be incorrect.")) } pub fn image_to_ndarray(img: &DynamicImage) -> Array3 { let (width, height) = img.dimensions(); @@ -251,6 +227,7 @@ pub fn image_to_ndarray(img: &DynamicImage) -> Array3 { Array3::from_shape_vec((height as usize, width as usize, 3), raw_data) .expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图 } + #[allow(dead_code)] fn save_rust_result(result: &ImageBuffer, Vec>, filename: &str) { let (width, height) = result.dimensions(); diff --git a/src/slide_model.rs b/src/slide_model.rs index 69095b6..344c8b2 100644 --- a/src/slide_model.rs +++ b/src/slide_model.rs @@ -1,15 +1,16 @@ +use crate::cv2::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff}; +use crate::image_io::image_to_ndarray; use anyhow::{Context, Result, anyhow}; use image::{DynamicImage, GenericImageView}; -use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s}; -use imageproc::template_matching::{match_template, MatchTemplateMethod}; use image::{ImageBuffer, Luma}; -use crate::image_io::image_to_ndarray; -use crate::cv2::rgb_to_gray; -use imageproc::edges::canny; use imageproc::distance_transform::Norm; +use imageproc::edges::canny; use imageproc::morphology::{close, open}; -use imageproc::region_labelling::{connected_components, Connectivity}; +use imageproc::region_labelling::{Connectivity, connected_components}; +use imageproc::template_matching::{MatchTemplateMethod, match_template}; use std::cmp::{max, min}; +use imageproc::contrast::{threshold, ThresholdType}; +use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s}; pub struct SlideResult { pub target: [i32; 2], @@ -28,26 +29,26 @@ impl Slide { /// 对应 Python: slide_match pub fn slide_match( &self, - target_pil: &DynamicImage, - background_pil: &DynamicImage, + target_image: &DynamicImage, + background_image: &DynamicImage, simple_target: bool, ) -> Result { - let target_array = image_to_ndarray(target_pil); - let background_array = image_to_ndarray(background_pil); + let target_array = image_to_ndarray(target_image); + let background_array = image_to_ndarray(background_image); - self.perform_slide_match(target_array.view(), background_array.view(),simple_target) + self.perform_slide_match(target_array.view(), background_array.view(), simple_target) .map_err(|e| anyhow!("滑块匹配失败: {}", e)) } /// 对应 Python: slide_comparison /// 用于比较带坑位的图片与原始背景图,定位差异点 pub fn slide_comparison( &self, - target_pil: &DynamicImage, - background_pil: &DynamicImage, + target_image: &DynamicImage, + background_image: &DynamicImage, ) -> Result { // 1. 转换为 ndarray (HWC RGB) - let target_array = image_to_ndarray(target_pil); - let background_array = image_to_ndarray(background_pil); + let target_array = image_to_ndarray(target_image); + let background_array = image_to_ndarray(background_image); // 2. 执行比较逻辑 (对应 _perform_slide_comparison) self.perform_slide_comparison(target_array.view(), background_array.view()) @@ -63,25 +64,32 @@ impl Slide { // 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor) // 使用 OpenCV 标准权重公式:0.299R + 0.587G + 0.114B - let mut diff_buffer = ImageBuffer::new(w as u32, h as u32); - for y in 0..h { - for x in 0..w { - let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32; - let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32; - let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32; + // let mut diff_buffer = ImageBuffer::new(w as u32, h as u32); + // for y in 0..h { + // for x in 0..w { + // let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32; + // let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32; + // let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32; + // + // let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8; + // diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff])); + // } + // } + // 1. 计算差异数组 (复用 cv2::absdiff) + let diff_array = abs_diff(&target, &background); - let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8; - diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff])); - } - } + // 2. 转换为灰度数组 (复用你的 cv2::rgb_to_gray) + let gray_array = rgb_to_gray(diff_array.view()); + // 3. 转为 ImageBuffer 以使用 imageproc 的高级功能 + let gray_buffer = ndarray_to_luma8(gray_array.view()); // 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY)) - let mut binary = ImageBuffer::new(w as u32, h as u32); - for (x, y, pixel) in diff_buffer.enumerate_pixels() { - let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 }; - binary.put_pixel(x, y, Luma([val])); - } - + // let mut binary = ImageBuffer::new(w as u32, h as u32); + // for (x, y, pixel) in diff_buffer.enumerate_pixels() { + // let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 }; + // binary.put_pixel(x, y, Luma([val])); + // } + let binary = threshold(&gray_buffer, 30, ThresholdType::Binary); // 3. 形态学操作去噪 (对应 cv2.morphologyEx) // 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞 // 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点 @@ -102,7 +110,9 @@ impl Slide { for pixel in labelled.pixels() { let label = pixel.0[0]; - if label == 0 { continue; } // 跳过背景 + if label == 0 { + continue; + } // 跳过背景 let count = areas.entry(label).or_insert(0); *count += 1; if *count > max_area { @@ -112,7 +122,12 @@ impl Slide { } if max_label == 0 { - return Ok(SlideResult { target: [0, 0], target_x: 0, target_y: 0, confidence: 0.0 }); + return Ok(SlideResult { + target: [0, 0], + target_x: 0, + target_y: 0, + confidence: 0.0, + }); } // 5. 计算最大区域的边界框 (对应 cv2.boundingRect) @@ -174,32 +189,27 @@ impl Slide { background: ArrayView2, ) -> Result { // 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换) - let (th, tw) = target.dim(); - let (bh, bw) = background.dim(); + + // let (bh, bw) = background.dim(); // 转换逻辑 (假设你已经有方法转回 ImageBuffer) - let t_buf = self.ndarray_to_luma8(target); - let b_buf = self.ndarray_to_luma8(background); - t_buf.save("debug_rust_target.png").unwrap(); - + let t_buf = ndarray_to_luma8(target); + let b_buf = ndarray_to_luma8(background); + // t_buf.save("debug_rust_target.png").unwrap(); // 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED) - let result = match_template(&b_buf, &t_buf, MatchTemplateMethod::CrossCorrelationNormalized); + // 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED)) + let result = match_template( + &b_buf, + &t_buf, + MatchTemplateMethod::CrossCorrelationNormalized, + ); // save_rust_result(&result, "debug_rust_target2.png"); // 3. 寻找最大值 (等价于 cv2.minMaxLoc) - let mut max_val: f32 = -1.0; - let mut max_loc = (0, 0); - - for (x, y, score) in result.enumerate_pixels() { - let s = score.0[0]; - // 这里的 x, y 是左上角坐标 - if s > max_val { - max_val = s; - max_loc = (x, y); - } - } + let (max_val, max_loc) = min_max_loc(&result); // 4. 计算中心点 (与 Python 逻辑完全一致) + let (th, tw) = target.dim(); let center_x = max_loc.0 as i32 + (tw as i32 / 2); let center_y = max_loc.1 as i32 + (th as i32 / 2); // println!("Rust Target Width (tw): {}", tw); @@ -212,17 +222,7 @@ impl Slide { confidence: max_val as f64, }) } - - fn ndarray_to_luma8(&self, array: ArrayView2) -> ImageBuffer, Vec> { - let (height, width) = array.dim(); - let mut buffer = ImageBuffer::new(width as u32, height as u32); - for y in 0..height { - for x in 0..width { - buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]])); - } - } - buffer - } + /// 对应 Python: _edge_based_match /// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match) pub fn edge_based_match( @@ -232,8 +232,8 @@ impl Slide { ) -> Result { // 1. 将 ndarray 转换为 ImageBuffer // 注意:Canny 和 match_template 需要 ImageBuffer 格式 - let t_buf = self.ndarray_to_luma8(target); - let b_buf = self.ndarray_to_luma8(background); + let t_buf = ndarray_to_luma8(target); + let b_buf = ndarray_to_luma8(background); // 2. 边缘检测 (完全对齐 cv2.Canny(50, 150)) // 这步会生成黑底白线的二值化边缘图 @@ -245,29 +245,14 @@ impl Slide { // 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED)) // 在边缘图上计算归一化互相关系数 - let result_map = match_template( + let result = match_template( &background_edges, &target_edges, - MatchTemplateMethod::CrossCorrelationNormalized + MatchTemplateMethod::CrossCorrelationNormalized, ); // 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc) - let mut max_val: f32 = -1.0; - let mut max_loc = (0, 0); - - // 遍历匹配得分图 - for (x, y, score) in result_map.enumerate_pixels() { - let s = score.0[0]; - - // 可以在此处加入你之前验证过的起始位过滤 - // if x < 15 { continue; } - - if s > max_val { - max_val = s; - max_loc = (x, y); - } - } - + let (max_val, max_loc) = min_max_loc(&result); // 5. 计算中心位置 (对齐 Python 逻辑) // target_w, target_h 来自输入数组的维度 let (th, tw) = target.dim(); @@ -287,4 +272,5 @@ impl Slide { }) } + } diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index 9ca0b44..0d64d5e 100644 --- a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -57,7 +57,7 @@ fn test_full_classification() { let ocr = DdddOcrBuilder::new().build().expect("模型加载失败"); // 2. 加载测试图片 - let img = image::open("samples/code3.png").expect("测试图片不存在"); + let img = image::open("samples/code2.png").expect("测试图片不存在"); // 3. 执行识别 let result = ocr.classification(&img).expect("识别过程出错"); @@ -148,6 +148,6 @@ fn test_real_slide_comparison() { // 验证基本逻辑:坐标不应为 0 (除非匹配失败) assert_eq!(result.target_x, 171); - assert_eq!(result.target_y, 91); + assert_eq!(result.target_y, 90); assert!(result.confidence > 0.0); } \ No newline at end of file