feat: 完成 Rust 滑块匹配算法,修复透明留白导致的坐标偏移
- 实现灰度与边缘两种匹配模式 - 对齐 OpenCV NCC 算法逻辑 - 优化图像灰度化与 Alpha 通道转换 - 提升坐标计算精度至像素级
This commit is contained in:
34
README.md
34
README.md
@@ -2,8 +2,42 @@
|
||||
|
||||
带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。
|
||||
|
||||
🧩 滑块识别算法核心知识点总结
|
||||
本项目实现了两种核心匹配模式,其底层逻辑与 OpenCV 的对齐情况如下:
|
||||
|
||||
1. 匹配模式对比 (Match Modes)
|
||||
|**模式**|**算法原理**|**适用场景**|**备注**|
|
||||
|---|---|---|---|
|
||||
|**边缘模式** (Edge-based)|基于 **Canny 边缘检测** 提取轮廓后再进行匹配。|**推荐方案**
|
||||
。适用于绝大多数拼图滑块。|天然免疫拼图周边的透明/黑色留白干扰,坐标最精准。|
|
||||
|**简单模式** (Simple/Gray)|直接基于 **灰度像素值** 进行归一化互相关计算。|适用于无明显边缘、靠颜色差异识别的场景。|对背景和透明边框敏感,可能存在重心偏移。|
|
||||
|
||||
2. 数学公式差异 (NCC vs. CCOEFF)
|
||||
在简单模式下,本项目采用的是 归一化互相关 (NCC),对应 OpenCV 中的 TM_CCORR_NORMED。
|
||||
|
||||
逻辑对齐:Rust 的 match_template 结果与 Python cv2.TM_CCORR_NORMED 完全一致。
|
||||
|
||||
关于偏移:若拼图原始图片(Target)四周包含大量的透明留白:
|
||||
|
||||
CCORR (本项目):会将留白视为图像的一部分,计算出的是整张图片框的中心。
|
||||
|
||||
CCOEFF (OpenCV 默认):会自动进行“均值中心化”,在一定程度上能削弱留白的影响。
|
||||
|
||||
最佳实践:若发现坐标有固定位移,建议优先切换至 边缘模式,或对滑块图进行 Bounding Box 裁剪 后再匹配。
|
||||
|
||||
3. 图像预处理一致性
|
||||
|
||||
为确保识别精度,本项目在 Rust 中完美复刻了 Python OpenCV 的预处理链路:
|
||||
|
||||
- **灰度化权重**:采用 OpenCV 标准感光公式 $0.299R + 0.587G + 0.114B$。
|
||||
|
||||
- **Alpha 处理**:在将 PNG 转为 RGB 时,自动将透明区域填充为黑色,确保与 PIL (Python Imaging Library) 行为一致。
|
||||
|
||||
- **坐标定义**:所有返回坐标均为匹配区域的 **几何中心点** $(x + w/2, y + h/2)$。
|
||||
|
||||
💡 开发者建议:
|
||||
|
||||
如果识别结果在 $X$ 轴上有大约 $10px$ 左右的固定误差,通常是因为滑块原图自带了透明边距(留白)。此时请确保 simple_target=false。该模式会通过 Canny 边缘检测 提取轮廓特征,能自动锁定拼图实体并忽略背景留白的像素干扰。
|
||||
鸣谢 (Credits)
|
||||
|
||||
- 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。
|
||||
|
||||
BIN
samples/ken.jpg
Normal file
BIN
samples/ken.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.3 KiB |
BIN
samples/kenyuan.jpg
Normal file
BIN
samples/kenyuan.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.0 KiB |
13
src/cv2.rs
Normal file
13
src/cv2.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use tract_onnx::prelude::tract_ndarray::{Array2, ArrayView3};
|
||||
|
||||
/// RGB 到灰度转换
|
||||
pub fn rgb_to_gray(rgb: ArrayView3<u8>) -> Array2<u8> {
|
||||
let (h, w, _) = rgb.dim();
|
||||
Array2::from_shape_fn((h, w), |(y, x)| {
|
||||
let r = rgb[[y, x, 0]] as f32;
|
||||
let g = rgb[[y, x, 1]] as f32;
|
||||
let b = rgb[[y, x, 2]] as f32;
|
||||
// 完全忽略 a,只按权重计算
|
||||
(0.299 * r + 0.587 * g + 0.114 * b) as u8
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
|
||||
use image::{DynamicImage, GenericImageView, ImageBuffer, Luma, Rgb, RgbImage};
|
||||
use std::path::{Path, PathBuf};
|
||||
use tract_onnx::prelude::tract_ndarray::Array3;
|
||||
|
||||
@@ -60,3 +60,51 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
|
||||
|
||||
DynamicImage::ImageRgb8(background)
|
||||
}
|
||||
|
||||
pub fn image_to_ndarray(img: &DynamicImage) -> Array3<u8> {
|
||||
let (width, height) = img.dimensions();
|
||||
|
||||
// 1. 强制转为 RGB8 (丢弃 Alpha 通道,与 Python 的 target_mode='RGB' 对齐)
|
||||
let rgb_img = img.to_rgb8();
|
||||
|
||||
// 2. 获取原始像素数据
|
||||
let raw_data = rgb_img.into_raw();
|
||||
|
||||
// 3. 构造数组 (通道数改为 3)
|
||||
Array3::from_shape_vec((height as usize, width as usize, 3), raw_data)
|
||||
.expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
fn save_rust_result(result: &ImageBuffer<Luma<f32>, Vec<f32>>, filename: &str) {
|
||||
let (width, height) = result.dimensions();
|
||||
|
||||
// 1. 寻找最值进行归一化
|
||||
let mut max_val = f32::MIN;
|
||||
let mut min_val = f32::MAX;
|
||||
for p in result.pixels() {
|
||||
if p.0[0] > max_val {
|
||||
max_val = p.0[0];
|
||||
}
|
||||
if p.0[0] < min_val {
|
||||
min_val = p.0[0];
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 创建 8 位灰度图
|
||||
let mut out_buf = ImageBuffer::new(width, height);
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let val = result.get_pixel(x, y).0[0];
|
||||
let normalized = if max_val > min_val {
|
||||
((val - min_val) / (max_val - min_val) * 255.0) as u8
|
||||
} else {
|
||||
0u8
|
||||
};
|
||||
out_buf.put_pixel(x, y, Luma([normalized]));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 保存
|
||||
DynamicImage::ImageLuma8(out_buf).save(filename).unwrap();
|
||||
println!("Rust 结果热力图已保存至: {}", filename);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ mod model_loader;
|
||||
mod ocr_model;
|
||||
mod utils;
|
||||
pub mod slide_model;
|
||||
mod cv2;
|
||||
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
|
||||
@@ -2,11 +2,20 @@ use anyhow::{Context, Result, anyhow};
|
||||
use image::{DynamicImage, GenericImageView};
|
||||
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
|
||||
use imageproc::template_matching::{match_template, MatchTemplateMethod};
|
||||
use image::{ImageBuffer, Luma};
|
||||
use crate::image_io::image_to_ndarray;
|
||||
use crate::cv2::rgb_to_gray;
|
||||
use imageproc::edges::canny;
|
||||
use imageproc::distance_transform::Norm;
|
||||
use imageproc::morphology::{close, open};
|
||||
use imageproc::region_labelling::{connected_components, Connectivity};
|
||||
use std::cmp::{max, min};
|
||||
|
||||
pub struct SlideResult {
|
||||
pub target: [i32; 2],
|
||||
pub target_x: i32,
|
||||
pub target_y: i32,
|
||||
pub confidence: f32,
|
||||
pub confidence: f64,
|
||||
}
|
||||
|
||||
pub struct Slide;
|
||||
@@ -21,12 +30,12 @@ impl Slide {
|
||||
&self,
|
||||
target_pil: &DynamicImage,
|
||||
background_pil: &DynamicImage,
|
||||
_simple_target: bool,
|
||||
simple_target: bool,
|
||||
) -> Result<SlideResult> {
|
||||
let target_array = self.image_to_ndarray(target_pil);
|
||||
let background_array = self.image_to_ndarray(background_pil);
|
||||
let target_array = image_to_ndarray(target_pil);
|
||||
let background_array = image_to_ndarray(background_pil);
|
||||
|
||||
self.perform_slide_match(target_array.view(), background_array.view())
|
||||
self.perform_slide_match(target_array.view(), background_array.view(),simple_target)
|
||||
.map_err(|e| anyhow!("滑块匹配失败: {}", e))
|
||||
}
|
||||
/// 对应 Python: slide_comparison
|
||||
@@ -37,15 +46,15 @@ impl Slide {
|
||||
background_pil: &DynamicImage,
|
||||
) -> Result<SlideResult> {
|
||||
// 1. 转换为 ndarray (HWC RGB)
|
||||
let target_array = self.image_to_ndarray(target_pil);
|
||||
let background_array = self.image_to_ndarray(background_pil);
|
||||
let target_array = image_to_ndarray(target_pil);
|
||||
let background_array = image_to_ndarray(background_pil);
|
||||
|
||||
// 2. 执行比较逻辑 (对应 _perform_slide_comparison)
|
||||
self.perform_slide_comparison(target_array.view(), background_array.view())
|
||||
.map_err(|e| anyhow!("滑块比较执行失败: {}", e))
|
||||
}
|
||||
/// 对应 Python: _perform_slide_comparison
|
||||
fn perform_slide_comparison(
|
||||
pub fn perform_slide_comparison(
|
||||
&self,
|
||||
target: ArrayView3<u8>,
|
||||
background: ArrayView3<u8>,
|
||||
@@ -53,118 +62,108 @@ impl Slide {
|
||||
let (h, w, _) = target.dim();
|
||||
|
||||
// 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor)
|
||||
let mut diff_gray = Array2::<u8>::zeros((h, w));
|
||||
// 使用 OpenCV 标准权重公式:0.299R + 0.587G + 0.114B
|
||||
let mut diff_buffer = ImageBuffer::new(w as u32, h as u32);
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs();
|
||||
let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs();
|
||||
let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs();
|
||||
let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32;
|
||||
let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32;
|
||||
let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32;
|
||||
|
||||
// 取三通道差异的平均值作为灰度差异
|
||||
diff_gray[[y, x]] = ((r_diff + g_diff + b_diff) / 3) as u8;
|
||||
let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8;
|
||||
diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff]));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 二值化 (对应 cv2.threshold(diff_gray, 30, 255, cv2.THRESH_BINARY))
|
||||
let binary = diff_gray.mapv(|x| if x > 30 { 255u8 } else { 0u8 });
|
||||
// 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY))
|
||||
let mut binary = ImageBuffer::new(w as u32, h as u32);
|
||||
for (x, y, pixel) in diff_buffer.enumerate_pixels() {
|
||||
let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 };
|
||||
binary.put_pixel(x, y, Luma([val]));
|
||||
}
|
||||
|
||||
// 3. 形态学去噪 (由于不引入 imageproc,我们通过简单的“中值滤波”或“区域平滑”模拟)
|
||||
// 在滑块场景中,若差异明显,直接寻找最大包围盒通常已经足够准确
|
||||
let binary_cleaned = self.simple_denoise(binary.view());
|
||||
// 3. 形态学操作去噪 (对应 cv2.morphologyEx)
|
||||
// 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞
|
||||
// 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点
|
||||
let norm = Norm::LInf; // 对应 3x3 的矩形内核
|
||||
let radius = 1u8; // 1 表示 3x3 的范围,2 表示 5x5 的范围
|
||||
let closed = close(&binary, norm, radius);
|
||||
let cleaned = open(&closed, norm, radius);
|
||||
|
||||
// 4. 寻找最大变动区域 (对应 findContours + max contour + boundingRect)
|
||||
self.find_largest_component_center(binary_cleaned.view())
|
||||
}
|
||||
/// 辅助:简单的去噪逻辑(模拟形态学操作)
|
||||
/// 检查像素周围,如果孤立点过多则抹除
|
||||
fn simple_denoise(&self, binary: ArrayView2<u8>) -> Array2<u8> {
|
||||
let (h, w) = binary.dim();
|
||||
let mut output = binary.to_owned();
|
||||
// 简单实现:如果一个点周围没有足够多的邻居,则认为是噪点(类似腐蚀)
|
||||
for y in 1..h - 1 {
|
||||
for x in 1..w - 1 {
|
||||
if binary[[y, x]] == 255 {
|
||||
let mut neighbors = 0;
|
||||
for ny in y - 1..=y + 1 {
|
||||
for nx in x - 1..=x + 1 {
|
||||
if binary[[ny, nx]] == 255 {
|
||||
neighbors += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if neighbors < 3 {
|
||||
output[[y, x]] = 0;
|
||||
}
|
||||
}
|
||||
// 4. 寻找最大连通区域 (对应 findContours + max area)
|
||||
// connected_components 会给每个独立的白色区域打上不同的标签 (ID)
|
||||
let background_label = Luma([0u8]);
|
||||
let labelled = connected_components(&cleaned, Connectivity::Eight, background_label);
|
||||
|
||||
// 统计每个标签出现的频率(即面积)
|
||||
let mut max_label = 0;
|
||||
let mut max_area = 0;
|
||||
let mut areas = std::collections::HashMap::new();
|
||||
|
||||
for pixel in labelled.pixels() {
|
||||
let label = pixel.0[0];
|
||||
if label == 0 { continue; } // 跳过背景
|
||||
let count = areas.entry(label).or_insert(0);
|
||||
*count += 1;
|
||||
if *count > max_area {
|
||||
max_area = *count;
|
||||
max_label = label;
|
||||
}
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
/// 辅助:寻找二值图中“最大块”的中心点
|
||||
fn find_largest_component_center(&self, binary: ArrayView2<u8>) -> Result<SlideResult> {
|
||||
let (h, w) = binary.dim();
|
||||
let mut min_x = w;
|
||||
if max_label == 0 {
|
||||
return Ok(SlideResult { target: [0, 0], target_x: 0, target_y: 0, confidence: 0.0 });
|
||||
}
|
||||
|
||||
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
|
||||
let mut min_x = w as u32;
|
||||
let mut max_x = 0;
|
||||
let mut min_y = h;
|
||||
let mut min_y = h as u32;
|
||||
let mut max_y = 0;
|
||||
let mut found = false;
|
||||
|
||||
// 遍历寻找所有白色像素的边界
|
||||
for ((y, x), &val) in binary.indexed_iter() {
|
||||
if val == 255 {
|
||||
if x < min_x {
|
||||
min_x = x;
|
||||
}
|
||||
if x > max_x {
|
||||
max_x = x;
|
||||
}
|
||||
if y < min_y {
|
||||
min_y = y;
|
||||
}
|
||||
if y > max_y {
|
||||
max_y = y;
|
||||
}
|
||||
found = true;
|
||||
for (x, y, pixel) in labelled.enumerate_pixels() {
|
||||
if pixel.0[0] == max_label {
|
||||
min_x = min(min_x, x);
|
||||
max_x = max(max_x, x);
|
||||
min_y = min(min_y, y);
|
||||
max_y = max(max_y, y);
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return Ok(SlideResult {
|
||||
target: [0, 0],
|
||||
target_x: 0,
|
||||
target_y: 0,
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
let center_x = ((min_x + max_x) / 2) as i32;
|
||||
let center_y = ((min_y + max_y) / 2) as i32;
|
||||
// 6. 计算中心点
|
||||
let rect_w = max_x - min_x;
|
||||
let rect_h = max_y - min_y;
|
||||
let center_x = (min_x + rect_w / 2) as i32;
|
||||
let center_y = (min_y + rect_h / 2) as i32;
|
||||
|
||||
Ok(SlideResult {
|
||||
target: [center_x, center_y],
|
||||
target_x: center_x,
|
||||
target_y: center_y,
|
||||
confidence: 1.0,
|
||||
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
|
||||
})
|
||||
}
|
||||
|
||||
/// 对应 Python: _perform_slide_match
|
||||
// 在 SlideEngine 中修改此入口进行测试
|
||||
fn perform_slide_match(
|
||||
&self,
|
||||
target: ArrayView3<u8>,
|
||||
background: ArrayView3<u8>,
|
||||
simple_target: bool, // 增加这个参数
|
||||
) -> Result<SlideResult> {
|
||||
// 1. 转换为灰度
|
||||
let target_gray = self.rgb_to_gray(target);
|
||||
let background_gray = self.rgb_to_gray(background);
|
||||
// 1. 统一灰度化
|
||||
let target_gray = rgb_to_gray(target);
|
||||
let background_gray = rgb_to_gray(background);
|
||||
|
||||
// 2. 提取边缘 (Sobel)
|
||||
let target_edges = self.sobel_edge_detection(target_gray.view());
|
||||
let background_edges = self.sobel_edge_detection(background_gray.view());
|
||||
if simple_target {
|
||||
// 2a. 简单模式:直接在灰度图上匹配
|
||||
self.simple_template_match(target_gray.view(), background_gray.view())
|
||||
} else {
|
||||
// 2b. 复杂模式:先提取边缘,再匹配
|
||||
|
||||
// 3. 在边缘图上进行匹配 (这是对齐 Python [237, 77] 的关键)
|
||||
self.simple_template_match(target_edges.view(), background_edges.view())
|
||||
self.edge_based_match(target_gray.view(), background_gray.view())
|
||||
}
|
||||
}
|
||||
/// 对应 Python: _simple_template_match
|
||||
/// 使用 SAD (Sum of Absolute Differences) 算法
|
||||
@@ -174,201 +173,118 @@ impl Slide {
|
||||
target: ArrayView2<u8>,
|
||||
background: ArrayView2<u8>,
|
||||
) -> Result<SlideResult> {
|
||||
// 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换)
|
||||
let (th, tw) = target.dim();
|
||||
let (bh, bw) = background.dim();
|
||||
|
||||
let mut min_sad = i64::MAX;
|
||||
let mut best_x = 0;
|
||||
let mut best_y = 0;
|
||||
// 转换逻辑 (假设你已经有方法转回 ImageBuffer)
|
||||
let t_buf = self.ndarray_to_luma8(target);
|
||||
let b_buf = self.ndarray_to_luma8(background);
|
||||
t_buf.save("debug_rust_target.png").unwrap();
|
||||
|
||||
// 1. 寻找滑块真正的“内容边界”(排除透明边距干扰)
|
||||
let mut content_left = tw;
|
||||
let mut content_right = 0;
|
||||
for r in 0..th {
|
||||
for c in 0..tw {
|
||||
if target[[r, c]] > 50 { // 假设边缘值大于50是有效内容
|
||||
if c < content_left { content_left = c; }
|
||||
if c > content_right { content_right = c; }
|
||||
}
|
||||
}
|
||||
}
|
||||
let content_width = if content_right > content_left { content_right - content_left } else { tw };
|
||||
|
||||
// 2. 遍历搜索
|
||||
// 技巧:y 从 10 开始,避开背景图最顶部的导航栏阴影干扰
|
||||
for y in 10..=(bh - th) {
|
||||
for x in 0..=(bw - tw) {
|
||||
let window = background.slice(s![y..y + th, x..x + tw]);
|
||||
let mut current_sad: i64 = 0;
|
||||
let mut count: i64 = 0;
|
||||
// 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED)
|
||||
let result = match_template(&b_buf, &t_buf, MatchTemplateMethod::CrossCorrelationNormalized);
|
||||
// save_rust_result(&result, "debug_rust_target2.png");
|
||||
// 3. 寻找最大值 (等价于 cv2.minMaxLoc)
|
||||
let mut max_val: f32 = -1.0;
|
||||
let mut max_loc = (0, 0);
|
||||
|
||||
for r in 0..th {
|
||||
for c in 0..tw {
|
||||
let t_val = target[[r, c]];
|
||||
if t_val > 50 {
|
||||
let b_val = window[[r, c]];
|
||||
current_sad += (t_val as i16 - b_val as i16).abs() as i64;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// 惩罚项:如果 Y 坐标太靠上,给它一个额外的权重负担(防止误判 Y=0)
|
||||
let penalty = if y < 20 { 1000 } else { 0 };
|
||||
let score = (current_sad * 100 / count) + penalty;
|
||||
|
||||
if score < min_sad {
|
||||
min_sad = score;
|
||||
best_x = x;
|
||||
best_y = y;
|
||||
}
|
||||
}
|
||||
for (x, y, score) in result.enumerate_pixels() {
|
||||
let s = score.0[0];
|
||||
// 这里的 x, y 是左上角坐标
|
||||
if s > max_val {
|
||||
max_val = s;
|
||||
max_loc = (x, y);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 坐标转换:对齐 Python 的中心点逻辑
|
||||
// Python 237 = Rust 214 + (滑块有效宽度 46 / 2)
|
||||
let res_x = (best_x + (tw / 2)) as i32;
|
||||
let res_y = (best_y + (th / 2)) as i32;
|
||||
|
||||
// 4. 计算中心点 (与 Python 逻辑完全一致)
|
||||
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
|
||||
let center_y = max_loc.1 as i32 + (th as i32 / 2);
|
||||
// println!("Rust Target Width (tw): {}", tw);
|
||||
// println!("Rust Best Max Loc X: {}", max_loc.0);
|
||||
// println!("Rust Final Center X: {}", center_x);
|
||||
Ok(SlideResult {
|
||||
target: [res_x, res_y],
|
||||
target_x: res_x,
|
||||
target_y: res_y,
|
||||
confidence: 0.98,
|
||||
target: [center_x, center_y],
|
||||
target_x: center_x,
|
||||
target_y: center_y,
|
||||
confidence: max_val as f64,
|
||||
})
|
||||
}
|
||||
|
||||
fn ndarray_to_luma8(&self, array: ArrayView2<u8>) -> ImageBuffer<Luma<u8>, Vec<u8>> {
|
||||
let (height, width) = array.dim();
|
||||
let mut buffer = ImageBuffer::new(width as u32, height as u32);
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]]));
|
||||
}
|
||||
}
|
||||
buffer
|
||||
}
|
||||
/// 对应 Python: _edge_based_match
|
||||
fn edge_based_match(
|
||||
/// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match)
|
||||
pub fn edge_based_match(
|
||||
&self,
|
||||
target: ArrayView2<u8>,
|
||||
background: ArrayView2<u8>,
|
||||
) -> Result<SlideResult> {
|
||||
// 1. 提取边缘(只保留轮廓)
|
||||
let target_edges = self.sobel_edge_detection(target);
|
||||
println!("target_edges:{}", target_edges);
|
||||
let background_edges = self.sobel_edge_detection(background);
|
||||
// 1. 将 ndarray 转换为 ImageBuffer
|
||||
// 注意:Canny 和 match_template 需要 ImageBuffer 格式
|
||||
let t_buf = self.ndarray_to_luma8(target);
|
||||
let b_buf = self.ndarray_to_luma8(background);
|
||||
|
||||
// 2. 在边缘图上进行匹配(边缘图背景是黑的,线条是白的,SAD 会极其精准)
|
||||
// 注意:这里调用我们改进后的 simple_template_match
|
||||
self.simple_template_match(target_edges.view(), background_edges.view())
|
||||
}
|
||||
/// 模拟 image_to_numpy: DynamicImage -> Array3<u8> (HWC)
|
||||
fn image_to_ndarray(&self, img: &DynamicImage) -> Array3<u8> {
|
||||
let (width, height) = img.dimensions();
|
||||
let rgba_img = img.to_rgba8();
|
||||
let raw_data = rgba_img.into_raw();
|
||||
Array3::from_shape_vec((height as usize, width as usize, 4), raw_data)
|
||||
.unwrap_or_else(|_| Array3::zeros((height as usize, width as usize, 4)))
|
||||
}
|
||||
fn image_to_ndarray_with_mask(&self, img: &DynamicImage) -> (Array2<u8>, Array2<u8>) {
|
||||
let (width, height) = img.dimensions();
|
||||
let rgba_img = img.to_rgba8();
|
||||
// 2. 边缘检测 (完全对齐 cv2.Canny(50, 150))
|
||||
// 这步会生成黑底白线的二值化边缘图
|
||||
let target_edges = canny(&t_buf, 50.0, 150.0);
|
||||
let background_edges = canny(&b_buf, 50.0, 150.0);
|
||||
|
||||
let mut gray = Array2::zeros((height as usize, width as usize));
|
||||
let mut mask = Array2::zeros((height as usize, width as usize));
|
||||
// target_edges.save("debug_target_edges.png").ok();
|
||||
// background_edges.save("debug_bg_edges.png").ok();
|
||||
|
||||
for (x, y, pixel) in rgba_img.enumerate_pixels() {
|
||||
// 简单的灰度转换
|
||||
let g = (0.299 * pixel[0] as f32 + 0.587 * pixel[1] as f32 + 0.114 * pixel[2] as f32) as u8;
|
||||
gray[[y as usize, x as usize]] = g;
|
||||
// 只有不透明度大于 0 的才作为有效匹配区域
|
||||
mask[[y as usize, x as usize]] = if pixel[3] > 0 { 1 } else { 0 };
|
||||
}
|
||||
(gray, mask)
|
||||
}
|
||||
/// RGB 到灰度转换
|
||||
fn rgb_to_gray(&self, rgba: ArrayView3<u8>) -> Array2<u8> {
|
||||
let (h, w, _) = rgba.dim();
|
||||
Array2::from_shape_fn((h, w), |(y, x)| {
|
||||
let r = rgba[[y, x, 0]] as f32;
|
||||
let g = rgba[[y, x, 1]] as f32;
|
||||
let b = rgba[[y, x, 2]] as f32;
|
||||
let a = rgba[[y, x, 3]] as f32;
|
||||
|
||||
// 如果 Alpha 是 0,强制背景为黑色
|
||||
if a < 128.0 {
|
||||
0
|
||||
} else {
|
||||
(0.299 * r + 0.587 * g + 0.114 * b) as u8
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// 简单的 Sobel 边缘检测实现
|
||||
fn sobel_edge_detection(&self, input: ArrayView2<u8>) -> Array2<u8> {
|
||||
let (h, w) = input.dim();
|
||||
let mut output = Array2::zeros((h, w));
|
||||
for y in 1..h - 1 {
|
||||
for x in 1..w - 1 {
|
||||
let gx = (input[[y - 1, x + 1]] as i32 + 2 * input[[y, x + 1]] as i32 + input[[y + 1, x + 1]] as i32)
|
||||
- (input[[y - 1, x - 1]] as i32 + 2 * input[[y, x - 1]] as i32 + input[[y + 1, x - 1]] as i32);
|
||||
let gy = (input[[y + 1, x - 1]] as i32 + 2 * input[[y + 1, x]] as i32 + input[[y + 1, x + 1]] as i32)
|
||||
- (input[[y - 1, x - 1]] as i32 + 2 * input[[y - 1, x]] as i32 + input[[y - 1, x + 1]] as i32);
|
||||
|
||||
let mag = ((gx.pow(2) + gy.pow(2)) as f32).sqrt();
|
||||
// 强化边缘:稍微提高对比度
|
||||
output[[y, x]] = (mag.min(255.0)) as u8;
|
||||
}
|
||||
}
|
||||
output
|
||||
}
|
||||
fn calculate_confidence(&self, sad: i64, area: usize) -> f32 {
|
||||
let avg_error = sad as f32 / area as f32;
|
||||
(1.0 - (avg_error / 255.0)).max(0.0)
|
||||
}
|
||||
pub fn slide_match_v2(
|
||||
&self,
|
||||
target_pil: &DynamicImage, // 你的滑块图
|
||||
background_pil: &DynamicImage, // 你的背景图
|
||||
) -> Result<SlideResult> {
|
||||
|
||||
// 1. 转换为灰度图 (Luma8)
|
||||
let t_gray = target_pil.to_luma8();
|
||||
let b_gray = background_pil.to_luma8();
|
||||
|
||||
// 2. 使用 CrossCorrelationNormed (NCC 算法)
|
||||
// 这种算法对亮度不敏感,专门对付有干扰、带阴影的“蜜蜂图”
|
||||
// 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
|
||||
// 在边缘图上计算归一化互相关系数
|
||||
let result_map = match_template(
|
||||
&b_gray,
|
||||
&t_gray,
|
||||
&background_edges,
|
||||
&target_edges,
|
||||
MatchTemplateMethod::CrossCorrelationNormalized
|
||||
);
|
||||
|
||||
let (tw, th) = target_pil.dimensions();
|
||||
let mut best_score = -1.0;
|
||||
let mut best_x = 0;
|
||||
let mut best_y = 0;
|
||||
// 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc)
|
||||
let mut max_val: f32 = -1.0;
|
||||
let mut max_loc = (0, 0);
|
||||
|
||||
// 3. 智能过滤:解决 X=23 的干扰问题
|
||||
// 遍历匹配得分图
|
||||
for (x, y, score) in result_map.enumerate_pixels() {
|
||||
let score_val = score.0[0];
|
||||
let s = score.0[0];
|
||||
|
||||
// 核心逻辑:跳过起始干扰区域。
|
||||
// 通常滑块移动距离不会小于 20 像素。
|
||||
// 如果那个 X=23 是干扰项,跳过它就能找到右边真正的坑位。
|
||||
if x < 20 {
|
||||
continue;
|
||||
}
|
||||
// 可以在此处加入你之前验证过的起始位过滤
|
||||
// if x < 15 { continue; }
|
||||
|
||||
if score_val > best_score {
|
||||
best_score = score_val;
|
||||
best_x = x;
|
||||
best_y = y;
|
||||
if s > max_val {
|
||||
max_val = s;
|
||||
max_loc = (x, y);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 坐标对齐 (对齐 Python ddddocr 的中心点返回习惯)
|
||||
// Python 237 = 我们的左边缘 214 + (滑块宽度 46 / 2)
|
||||
let res_x = (best_x + tw / 2) as i32;
|
||||
let res_y = (best_y + th / 2) as i32;
|
||||
// 5. 计算中心位置 (对齐 Python 逻辑)
|
||||
// target_w, target_h 来自输入数组的维度
|
||||
let (th, tw) = target.dim();
|
||||
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
|
||||
let center_y = max_loc.1 as i32 + (th as i32 / 2);
|
||||
|
||||
// 打印调试信息,方便与 Python 对比
|
||||
// println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc);
|
||||
println!("-Rust Target Width (tw): {}", tw);
|
||||
println!("-Rust Best Max Loc X: {}", max_loc.0);
|
||||
println!("-Rust Final Center X: {}", center_x);
|
||||
Ok(SlideResult {
|
||||
target: [res_x, res_y],
|
||||
target_x: res_x,
|
||||
target_y: res_y,
|
||||
confidence: best_score as f64 as f32,
|
||||
target: [center_x, center_y],
|
||||
target_x: center_x,
|
||||
target_y: center_y,
|
||||
confidence: max_val as f64,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -119,3 +119,35 @@ fn test_real_slide_match() {
|
||||
assert_eq!(result.target_y, 77);
|
||||
assert!(result.confidence > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_real_slide_comparison() {
|
||||
let engine = Slide::new();
|
||||
|
||||
// 1. 加载你准备好的测试图
|
||||
// 假设图片放在项目根目录下的 assets 文件夹
|
||||
let target_img = load_image("samples/ken.jpg")
|
||||
.expect("请确保 samples/ken.jpg 存在");
|
||||
let bg_img = load_image("samples/kenyuan.jpg")
|
||||
.expect("请确保 samples/kenyuan.jpg 存在");
|
||||
|
||||
// 2. 执行匹配
|
||||
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
|
||||
let start = std::time::Instant::now();
|
||||
let result = engine.slide_comparison(&target_img, &bg_img)
|
||||
.expect("Slide match 执行失败");
|
||||
let duration = start.elapsed();
|
||||
|
||||
// 3. 打印结果
|
||||
println!("-------------------------------------------");
|
||||
println!("滑块匹配测试结果:");
|
||||
println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y);
|
||||
println!("置信度: {:.4}", result.confidence);
|
||||
println!("耗时: {:?}", duration);
|
||||
println!("-------------------------------------------");
|
||||
|
||||
// 验证基本逻辑:坐标不应为 0 (除非匹配失败)
|
||||
assert_eq!(result.target_x, 171);
|
||||
assert_eq!(result.target_y, 91);
|
||||
assert!(result.confidence > 0.0);
|
||||
}
|
||||
Reference in New Issue
Block a user