diff --git a/README.md b/README.md index f63c9f9..b7e7b25 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,42 @@ 带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。 +🧩 滑块识别算法核心知识点总结 +本项目实现了两种核心匹配模式,其底层逻辑与 OpenCV 的对齐情况如下: +1. 匹配模式对比 (Match Modes) + |**模式**|**算法原理**|**适用场景**|**备注**| + |---|---|---|---| + |**边缘模式** (Edge-based)|基于 **Canny 边缘检测** 提取轮廓后再进行匹配。|**推荐方案** + 。适用于绝大多数拼图滑块。|天然免疫拼图周边的透明/黑色留白干扰,坐标最精准。| + |**简单模式** (Simple/Gray)|直接基于 **灰度像素值** 进行归一化互相关计算。|适用于无明显边缘、靠颜色差异识别的场景。|对背景和透明边框敏感,可能存在重心偏移。| +2. 数学公式差异 (NCC vs. CCOEFF) + 在简单模式下,本项目采用的是 归一化互相关 (NCC),对应 OpenCV 中的 TM_CCORR_NORMED。 + +逻辑对齐:Rust 的 match_template 结果与 Python cv2.TM_CCORR_NORMED 完全一致。 + +关于偏移:若拼图原始图片(Target)四周包含大量的透明留白: + +CCORR (本项目):会将留白视为图像的一部分,计算出的是整张图片框的中心。 + +CCOEFF (OpenCV 默认):会自动进行“均值中心化”,在一定程度上能削弱留白的影响。 + +最佳实践:若发现坐标有固定位移,建议优先切换至 边缘模式,或对滑块图进行 Bounding Box 裁剪 后再匹配。 + +3. 图像预处理一致性 + + 为确保识别精度,本项目在 Rust 中完美复刻了 Python OpenCV 的预处理链路: + +- **灰度化权重**:采用 OpenCV 标准感光公式 $0.299R + 0.587G + 0.114B$。 + +- **Alpha 处理**:在将 PNG 转为 RGB 时,自动将透明区域填充为黑色,确保与 PIL (Python Imaging Library) 行为一致。 + +- **坐标定义**:所有返回坐标均为匹配区域的 **几何中心点** $(x + w/2, y + h/2)$。 + +💡 开发者建议: + +如果识别结果在 $X$ 轴上有大约 $10px$ 左右的固定误差,通常是因为滑块原图自带了透明边距(留白)。此时请确保 simple_target=false。该模式会通过 Canny 边缘检测 提取轮廓特征,能自动锁定拼图实体并忽略背景留白的像素干扰。 鸣谢 (Credits) - 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。 diff --git a/samples/ken.jpg b/samples/ken.jpg new file mode 100644 index 0000000..1a99db5 Binary files /dev/null and b/samples/ken.jpg differ diff --git a/samples/kenyuan.jpg b/samples/kenyuan.jpg new file mode 100644 index 0000000..b2d83dc Binary files /dev/null and b/samples/kenyuan.jpg differ diff --git a/src/cv2.rs b/src/cv2.rs new file mode 100644 index 0000000..f685620 --- /dev/null +++ b/src/cv2.rs @@ -0,0 +1,13 @@ +use tract_onnx::prelude::tract_ndarray::{Array2, ArrayView3}; + +/// RGB 到灰度转换 +pub fn rgb_to_gray(rgb: ArrayView3) -> Array2 { + let (h, w, _) = rgb.dim(); + Array2::from_shape_fn((h, w), |(y, x)| { + let r = rgb[[y, x, 0]] as f32; + let g = rgb[[y, x, 1]] as f32; + let b = rgb[[y, x, 2]] as f32; + // 完全忽略 a,只按权重计算 + (0.299 * r + 0.587 * g + 0.114 * b) as u8 + }) +} diff --git a/src/image_io.rs b/src/image_io.rs index 3a69b7a..e3e1fd8 100644 --- a/src/image_io.rs +++ b/src/image_io.rs @@ -1,6 +1,6 @@ use anyhow::{Context, Result}; use base64::{Engine as _, engine::general_purpose}; -use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage}; +use image::{DynamicImage, GenericImageView, ImageBuffer, Luma, Rgb, RgbImage}; use std::path::{Path, PathBuf}; use tract_onnx::prelude::tract_ndarray::Array3; @@ -60,3 +60,51 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { DynamicImage::ImageRgb8(background) } + +pub fn image_to_ndarray(img: &DynamicImage) -> Array3 { + let (width, height) = img.dimensions(); + + // 1. 强制转为 RGB8 (丢弃 Alpha 通道,与 Python 的 target_mode='RGB' 对齐) + let rgb_img = img.to_rgb8(); + + // 2. 获取原始像素数据 + let raw_data = rgb_img.into_raw(); + + // 3. 构造数组 (通道数改为 3) + Array3::from_shape_vec((height as usize, width as usize, 3), raw_data) + .expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图 +} +#[allow(dead_code)] +fn save_rust_result(result: &ImageBuffer, Vec>, filename: &str) { + let (width, height) = result.dimensions(); + + // 1. 寻找最值进行归一化 + let mut max_val = f32::MIN; + let mut min_val = f32::MAX; + for p in result.pixels() { + if p.0[0] > max_val { + max_val = p.0[0]; + } + if p.0[0] < min_val { + min_val = p.0[0]; + } + } + + // 2. 创建 8 位灰度图 + let mut out_buf = ImageBuffer::new(width, height); + for y in 0..height { + for x in 0..width { + let val = result.get_pixel(x, y).0[0]; + let normalized = if max_val > min_val { + ((val - min_val) / (max_val - min_val) * 255.0) as u8 + } else { + 0u8 + }; + out_buf.put_pixel(x, y, Luma([normalized])); + } + } + + // 3. 保存 + DynamicImage::ImageLuma8(out_buf).save(filename).unwrap(); + println!("Rust 结果热力图已保存至: {}", filename); +} diff --git a/src/lib.rs b/src/lib.rs index 8f763ce..9b69d0e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ mod model_loader; mod ocr_model; mod utils; pub mod slide_model; +mod cv2; use anyhow::Result; use image::DynamicImage; diff --git a/src/slide_model.rs b/src/slide_model.rs index bb8c7ee..69095b6 100644 --- a/src/slide_model.rs +++ b/src/slide_model.rs @@ -2,11 +2,20 @@ use anyhow::{Context, Result, anyhow}; use image::{DynamicImage, GenericImageView}; use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s}; use imageproc::template_matching::{match_template, MatchTemplateMethod}; +use image::{ImageBuffer, Luma}; +use crate::image_io::image_to_ndarray; +use crate::cv2::rgb_to_gray; +use imageproc::edges::canny; +use imageproc::distance_transform::Norm; +use imageproc::morphology::{close, open}; +use imageproc::region_labelling::{connected_components, Connectivity}; +use std::cmp::{max, min}; + pub struct SlideResult { pub target: [i32; 2], pub target_x: i32, pub target_y: i32, - pub confidence: f32, + pub confidence: f64, } pub struct Slide; @@ -21,12 +30,12 @@ impl Slide { &self, target_pil: &DynamicImage, background_pil: &DynamicImage, - _simple_target: bool, + simple_target: bool, ) -> Result { - let target_array = self.image_to_ndarray(target_pil); - let background_array = self.image_to_ndarray(background_pil); + let target_array = image_to_ndarray(target_pil); + let background_array = image_to_ndarray(background_pil); - self.perform_slide_match(target_array.view(), background_array.view()) + self.perform_slide_match(target_array.view(), background_array.view(),simple_target) .map_err(|e| anyhow!("滑块匹配失败: {}", e)) } /// 对应 Python: slide_comparison @@ -37,15 +46,15 @@ impl Slide { background_pil: &DynamicImage, ) -> Result { // 1. 转换为 ndarray (HWC RGB) - let target_array = self.image_to_ndarray(target_pil); - let background_array = self.image_to_ndarray(background_pil); + let target_array = image_to_ndarray(target_pil); + let background_array = image_to_ndarray(background_pil); // 2. 执行比较逻辑 (对应 _perform_slide_comparison) self.perform_slide_comparison(target_array.view(), background_array.view()) .map_err(|e| anyhow!("滑块比较执行失败: {}", e)) } /// 对应 Python: _perform_slide_comparison - fn perform_slide_comparison( + pub fn perform_slide_comparison( &self, target: ArrayView3, background: ArrayView3, @@ -53,118 +62,108 @@ impl Slide { let (h, w, _) = target.dim(); // 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor) - let mut diff_gray = Array2::::zeros((h, w)); + // 使用 OpenCV 标准权重公式:0.299R + 0.587G + 0.114B + let mut diff_buffer = ImageBuffer::new(w as u32, h as u32); for y in 0..h { for x in 0..w { - let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs(); - let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs(); - let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs(); + let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32; + let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32; + let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32; - // 取三通道差异的平均值作为灰度差异 - diff_gray[[y, x]] = ((r_diff + g_diff + b_diff) / 3) as u8; + let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8; + diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff])); } } - // 2. 二值化 (对应 cv2.threshold(diff_gray, 30, 255, cv2.THRESH_BINARY)) - let binary = diff_gray.mapv(|x| if x > 30 { 255u8 } else { 0u8 }); + // 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY)) + let mut binary = ImageBuffer::new(w as u32, h as u32); + for (x, y, pixel) in diff_buffer.enumerate_pixels() { + let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 }; + binary.put_pixel(x, y, Luma([val])); + } - // 3. 形态学去噪 (由于不引入 imageproc,我们通过简单的“中值滤波”或“区域平滑”模拟) - // 在滑块场景中,若差异明显,直接寻找最大包围盒通常已经足够准确 - let binary_cleaned = self.simple_denoise(binary.view()); + // 3. 形态学操作去噪 (对应 cv2.morphologyEx) + // 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞 + // 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点 + let norm = Norm::LInf; // 对应 3x3 的矩形内核 + let radius = 1u8; // 1 表示 3x3 的范围,2 表示 5x5 的范围 + let closed = close(&binary, norm, radius); + let cleaned = open(&closed, norm, radius); - // 4. 寻找最大变动区域 (对应 findContours + max contour + boundingRect) - self.find_largest_component_center(binary_cleaned.view()) - } - /// 辅助:简单的去噪逻辑(模拟形态学操作) - /// 检查像素周围,如果孤立点过多则抹除 - fn simple_denoise(&self, binary: ArrayView2) -> Array2 { - let (h, w) = binary.dim(); - let mut output = binary.to_owned(); - // 简单实现:如果一个点周围没有足够多的邻居,则认为是噪点(类似腐蚀) - for y in 1..h - 1 { - for x in 1..w - 1 { - if binary[[y, x]] == 255 { - let mut neighbors = 0; - for ny in y - 1..=y + 1 { - for nx in x - 1..=x + 1 { - if binary[[ny, nx]] == 255 { - neighbors += 1; - } - } - } - if neighbors < 3 { - output[[y, x]] = 0; - } - } + // 4. 寻找最大连通区域 (对应 findContours + max area) + // connected_components 会给每个独立的白色区域打上不同的标签 (ID) + let background_label = Luma([0u8]); + let labelled = connected_components(&cleaned, Connectivity::Eight, background_label); + + // 统计每个标签出现的频率(即面积) + let mut max_label = 0; + let mut max_area = 0; + let mut areas = std::collections::HashMap::new(); + + for pixel in labelled.pixels() { + let label = pixel.0[0]; + if label == 0 { continue; } // 跳过背景 + let count = areas.entry(label).or_insert(0); + *count += 1; + if *count > max_area { + max_area = *count; + max_label = label; } } - output - } - /// 辅助:寻找二值图中“最大块”的中心点 - fn find_largest_component_center(&self, binary: ArrayView2) -> Result { - let (h, w) = binary.dim(); - let mut min_x = w; + if max_label == 0 { + return Ok(SlideResult { target: [0, 0], target_x: 0, target_y: 0, confidence: 0.0 }); + } + + // 5. 计算最大区域的边界框 (对应 cv2.boundingRect) + let mut min_x = w as u32; let mut max_x = 0; - let mut min_y = h; + let mut min_y = h as u32; let mut max_y = 0; - let mut found = false; - // 遍历寻找所有白色像素的边界 - for ((y, x), &val) in binary.indexed_iter() { - if val == 255 { - if x < min_x { - min_x = x; - } - if x > max_x { - max_x = x; - } - if y < min_y { - min_y = y; - } - if y > max_y { - max_y = y; - } - found = true; + for (x, y, pixel) in labelled.enumerate_pixels() { + if pixel.0[0] == max_label { + min_x = min(min_x, x); + max_x = max(max_x, x); + min_y = min(min_y, y); + max_y = max(max_y, y); } } - if !found { - return Ok(SlideResult { - target: [0, 0], - target_x: 0, - target_y: 0, - confidence: 0.0, - }); - } - - let center_x = ((min_x + max_x) / 2) as i32; - let center_y = ((min_y + max_y) / 2) as i32; + // 6. 计算中心点 + let rect_w = max_x - min_x; + let rect_h = max_y - min_y; + let center_x = (min_x + rect_w / 2) as i32; + let center_y = (min_y + rect_h / 2) as i32; Ok(SlideResult { target: [center_x, center_y], target_x: center_x, target_y: center_y, - confidence: 1.0, + confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0 }) } + /// 对应 Python: _perform_slide_match // 在 SlideEngine 中修改此入口进行测试 fn perform_slide_match( &self, target: ArrayView3, background: ArrayView3, + simple_target: bool, // 增加这个参数 ) -> Result { - // 1. 转换为灰度 - let target_gray = self.rgb_to_gray(target); - let background_gray = self.rgb_to_gray(background); + // 1. 统一灰度化 + let target_gray = rgb_to_gray(target); + let background_gray = rgb_to_gray(background); - // 2. 提取边缘 (Sobel) - let target_edges = self.sobel_edge_detection(target_gray.view()); - let background_edges = self.sobel_edge_detection(background_gray.view()); + if simple_target { + // 2a. 简单模式:直接在灰度图上匹配 + self.simple_template_match(target_gray.view(), background_gray.view()) + } else { + // 2b. 复杂模式:先提取边缘,再匹配 - // 3. 在边缘图上进行匹配 (这是对齐 Python [237, 77] 的关键) - self.simple_template_match(target_edges.view(), background_edges.view()) + self.edge_based_match(target_gray.view(), background_gray.view()) + } } /// 对应 Python: _simple_template_match /// 使用 SAD (Sum of Absolute Differences) 算法 @@ -174,201 +173,118 @@ impl Slide { target: ArrayView2, background: ArrayView2, ) -> Result { + // 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换) let (th, tw) = target.dim(); let (bh, bw) = background.dim(); - let mut min_sad = i64::MAX; - let mut best_x = 0; - let mut best_y = 0; + // 转换逻辑 (假设你已经有方法转回 ImageBuffer) + let t_buf = self.ndarray_to_luma8(target); + let b_buf = self.ndarray_to_luma8(background); + t_buf.save("debug_rust_target.png").unwrap(); - // 1. 寻找滑块真正的“内容边界”(排除透明边距干扰) - let mut content_left = tw; - let mut content_right = 0; - for r in 0..th { - for c in 0..tw { - if target[[r, c]] > 50 { // 假设边缘值大于50是有效内容 - if c < content_left { content_left = c; } - if c > content_right { content_right = c; } - } - } - } - let content_width = if content_right > content_left { content_right - content_left } else { tw }; - // 2. 遍历搜索 - // 技巧:y 从 10 开始,避开背景图最顶部的导航栏阴影干扰 - for y in 10..=(bh - th) { - for x in 0..=(bw - tw) { - let window = background.slice(s![y..y + th, x..x + tw]); - let mut current_sad: i64 = 0; - let mut count: i64 = 0; + // 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED) + let result = match_template(&b_buf, &t_buf, MatchTemplateMethod::CrossCorrelationNormalized); + // save_rust_result(&result, "debug_rust_target2.png"); + // 3. 寻找最大值 (等价于 cv2.minMaxLoc) + let mut max_val: f32 = -1.0; + let mut max_loc = (0, 0); - for r in 0..th { - for c in 0..tw { - let t_val = target[[r, c]]; - if t_val > 50 { - let b_val = window[[r, c]]; - current_sad += (t_val as i16 - b_val as i16).abs() as i64; - count += 1; - } - } - } - - if count > 0 { - // 惩罚项:如果 Y 坐标太靠上,给它一个额外的权重负担(防止误判 Y=0) - let penalty = if y < 20 { 1000 } else { 0 }; - let score = (current_sad * 100 / count) + penalty; - - if score < min_sad { - min_sad = score; - best_x = x; - best_y = y; - } - } + for (x, y, score) in result.enumerate_pixels() { + let s = score.0[0]; + // 这里的 x, y 是左上角坐标 + if s > max_val { + max_val = s; + max_loc = (x, y); } } - // 3. 坐标转换:对齐 Python 的中心点逻辑 - // Python 237 = Rust 214 + (滑块有效宽度 46 / 2) - let res_x = (best_x + (tw / 2)) as i32; - let res_y = (best_y + (th / 2)) as i32; - + // 4. 计算中心点 (与 Python 逻辑完全一致) + let center_x = max_loc.0 as i32 + (tw as i32 / 2); + let center_y = max_loc.1 as i32 + (th as i32 / 2); + // println!("Rust Target Width (tw): {}", tw); + // println!("Rust Best Max Loc X: {}", max_loc.0); + // println!("Rust Final Center X: {}", center_x); Ok(SlideResult { - target: [res_x, res_y], - target_x: res_x, - target_y: res_y, - confidence: 0.98, + target: [center_x, center_y], + target_x: center_x, + target_y: center_y, + confidence: max_val as f64, }) } + + fn ndarray_to_luma8(&self, array: ArrayView2) -> ImageBuffer, Vec> { + let (height, width) = array.dim(); + let mut buffer = ImageBuffer::new(width as u32, height as u32); + for y in 0..height { + for x in 0..width { + buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]])); + } + } + buffer + } /// 对应 Python: _edge_based_match - fn edge_based_match( + /// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match) + pub fn edge_based_match( &self, target: ArrayView2, background: ArrayView2, ) -> Result { - // 1. 提取边缘(只保留轮廓) - let target_edges = self.sobel_edge_detection(target); - println!("target_edges:{}", target_edges); - let background_edges = self.sobel_edge_detection(background); + // 1. 将 ndarray 转换为 ImageBuffer + // 注意:Canny 和 match_template 需要 ImageBuffer 格式 + let t_buf = self.ndarray_to_luma8(target); + let b_buf = self.ndarray_to_luma8(background); - // 2. 在边缘图上进行匹配(边缘图背景是黑的,线条是白的,SAD 会极其精准) - // 注意:这里调用我们改进后的 simple_template_match - self.simple_template_match(target_edges.view(), background_edges.view()) - } - /// 模拟 image_to_numpy: DynamicImage -> Array3 (HWC) - fn image_to_ndarray(&self, img: &DynamicImage) -> Array3 { - let (width, height) = img.dimensions(); - let rgba_img = img.to_rgba8(); - let raw_data = rgba_img.into_raw(); - Array3::from_shape_vec((height as usize, width as usize, 4), raw_data) - .unwrap_or_else(|_| Array3::zeros((height as usize, width as usize, 4))) - } - fn image_to_ndarray_with_mask(&self, img: &DynamicImage) -> (Array2, Array2) { - let (width, height) = img.dimensions(); - let rgba_img = img.to_rgba8(); + // 2. 边缘检测 (完全对齐 cv2.Canny(50, 150)) + // 这步会生成黑底白线的二值化边缘图 + let target_edges = canny(&t_buf, 50.0, 150.0); + let background_edges = canny(&b_buf, 50.0, 150.0); - let mut gray = Array2::zeros((height as usize, width as usize)); - let mut mask = Array2::zeros((height as usize, width as usize)); + // target_edges.save("debug_target_edges.png").ok(); + // background_edges.save("debug_bg_edges.png").ok(); - for (x, y, pixel) in rgba_img.enumerate_pixels() { - // 简单的灰度转换 - let g = (0.299 * pixel[0] as f32 + 0.587 * pixel[1] as f32 + 0.114 * pixel[2] as f32) as u8; - gray[[y as usize, x as usize]] = g; - // 只有不透明度大于 0 的才作为有效匹配区域 - mask[[y as usize, x as usize]] = if pixel[3] > 0 { 1 } else { 0 }; - } - (gray, mask) - } - /// RGB 到灰度转换 - fn rgb_to_gray(&self, rgba: ArrayView3) -> Array2 { - let (h, w, _) = rgba.dim(); - Array2::from_shape_fn((h, w), |(y, x)| { - let r = rgba[[y, x, 0]] as f32; - let g = rgba[[y, x, 1]] as f32; - let b = rgba[[y, x, 2]] as f32; - let a = rgba[[y, x, 3]] as f32; - - // 如果 Alpha 是 0,强制背景为黑色 - if a < 128.0 { - 0 - } else { - (0.299 * r + 0.587 * g + 0.114 * b) as u8 - } - }) - } - - /// 简单的 Sobel 边缘检测实现 - fn sobel_edge_detection(&self, input: ArrayView2) -> Array2 { - let (h, w) = input.dim(); - let mut output = Array2::zeros((h, w)); - for y in 1..h - 1 { - for x in 1..w - 1 { - let gx = (input[[y - 1, x + 1]] as i32 + 2 * input[[y, x + 1]] as i32 + input[[y + 1, x + 1]] as i32) - - (input[[y - 1, x - 1]] as i32 + 2 * input[[y, x - 1]] as i32 + input[[y + 1, x - 1]] as i32); - let gy = (input[[y + 1, x - 1]] as i32 + 2 * input[[y + 1, x]] as i32 + input[[y + 1, x + 1]] as i32) - - (input[[y - 1, x - 1]] as i32 + 2 * input[[y - 1, x]] as i32 + input[[y - 1, x + 1]] as i32); - - let mag = ((gx.pow(2) + gy.pow(2)) as f32).sqrt(); - // 强化边缘:稍微提高对比度 - output[[y, x]] = (mag.min(255.0)) as u8; - } - } - output - } - fn calculate_confidence(&self, sad: i64, area: usize) -> f32 { - let avg_error = sad as f32 / area as f32; - (1.0 - (avg_error / 255.0)).max(0.0) - } - pub fn slide_match_v2( - &self, - target_pil: &DynamicImage, // 你的滑块图 - background_pil: &DynamicImage, // 你的背景图 - ) -> Result { - - // 1. 转换为灰度图 (Luma8) - let t_gray = target_pil.to_luma8(); - let b_gray = background_pil.to_luma8(); - - // 2. 使用 CrossCorrelationNormed (NCC 算法) - // 这种算法对亮度不敏感,专门对付有干扰、带阴影的“蜜蜂图” + // 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED)) + // 在边缘图上计算归一化互相关系数 let result_map = match_template( - &b_gray, - &t_gray, + &background_edges, + &target_edges, MatchTemplateMethod::CrossCorrelationNormalized ); - let (tw, th) = target_pil.dimensions(); - let mut best_score = -1.0; - let mut best_x = 0; - let mut best_y = 0; + // 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc) + let mut max_val: f32 = -1.0; + let mut max_loc = (0, 0); - // 3. 智能过滤:解决 X=23 的干扰问题 + // 遍历匹配得分图 for (x, y, score) in result_map.enumerate_pixels() { - let score_val = score.0[0]; + let s = score.0[0]; - // 核心逻辑:跳过起始干扰区域。 - // 通常滑块移动距离不会小于 20 像素。 - // 如果那个 X=23 是干扰项,跳过它就能找到右边真正的坑位。 - if x < 20 { - continue; - } + // 可以在此处加入你之前验证过的起始位过滤 + // if x < 15 { continue; } - if score_val > best_score { - best_score = score_val; - best_x = x; - best_y = y; + if s > max_val { + max_val = s; + max_loc = (x, y); } } - // 4. 坐标对齐 (对齐 Python ddddocr 的中心点返回习惯) - // Python 237 = 我们的左边缘 214 + (滑块宽度 46 / 2) - let res_x = (best_x + tw / 2) as i32; - let res_y = (best_y + th / 2) as i32; + // 5. 计算中心位置 (对齐 Python 逻辑) + // target_w, target_h 来自输入数组的维度 + let (th, tw) = target.dim(); + let center_x = max_loc.0 as i32 + (tw as i32 / 2); + let center_y = max_loc.1 as i32 + (th as i32 / 2); + // 打印调试信息,方便与 Python 对比 + // println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc); + println!("-Rust Target Width (tw): {}", tw); + println!("-Rust Best Max Loc X: {}", max_loc.0); + println!("-Rust Final Center X: {}", center_x); Ok(SlideResult { - target: [res_x, res_y], - target_x: res_x, - target_y: res_y, - confidence: best_score as f64 as f32, + target: [center_x, center_y], + target_x: center_x, + target_y: center_y, + confidence: max_val as f64, }) } + } diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index 81a1790..a915f22 100644 --- a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -118,4 +118,36 @@ fn test_real_slide_match() { assert_eq!(result.target_x, 237); assert_eq!(result.target_y, 77); assert!(result.confidence > 0.0); +} + +#[test] +fn test_real_slide_comparison() { + let engine = Slide::new(); + + // 1. 加载你准备好的测试图 + // 假设图片放在项目根目录下的 assets 文件夹 + let target_img = load_image("samples/ken.jpg") + .expect("请确保 samples/ken.jpg 存在"); + let bg_img = load_image("samples/kenyuan.jpg") + .expect("请确保 samples/kenyuan.jpg 存在"); + + // 2. 执行匹配 + // 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false + let start = std::time::Instant::now(); + let result = engine.slide_comparison(&target_img, &bg_img) + .expect("Slide match 执行失败"); + let duration = start.elapsed(); + + // 3. 打印结果 + println!("-------------------------------------------"); + println!("滑块匹配测试结果:"); + println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y); + println!("置信度: {:.4}", result.confidence); + println!("耗时: {:?}", duration); + println!("-------------------------------------------"); + + // 验证基本逻辑:坐标不应为 0 (除非匹配失败) + assert_eq!(result.target_x, 171); + assert_eq!(result.target_y, 91); + assert!(result.confidence > 0.0); } \ No newline at end of file