// ddddocr-rs/src/slide_model.rs
// CNWei a51147c888 refactor: 优化 slide_model.rs — 新增 cv2.rs 模拟 opencv
// 2026-05-09 17:52:34 +08:00

use crate::cv2::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff};
use crate::image_io::image_to_ndarray;
use anyhow::{Context, Result, anyhow};
use image::{DynamicImage, GenericImageView};
use image::{ImageBuffer, Luma};
use imageproc::distance_transform::Norm;
use imageproc::edges::canny;
use imageproc::morphology::{close, open};
use imageproc::region_labelling::{Connectivity, connected_components};
use imageproc::template_matching::{MatchTemplateMethod, match_template};
use std::cmp::{max, min};
use imageproc::contrast::{threshold, ThresholdType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
/// Result of a slider match or comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SlideResult {
    /// Centre of the located region as `[x, y]` (duplicates the two fields below).
    pub target: [i32; 2],
    /// X coordinate of the centre, in background-image pixels.
    pub target_x: i32,
    /// Y coordinate of the centre, in background-image pixels.
    pub target_y: i32,
    /// Match confidence: the normalized-correlation peak for template matching;
    /// 1.0 (found) or 0.0 (not found) in comparison mode.
    pub confidence: f64,
}
pub struct Slide;
impl Slide {
    /// Creates a new (stateless) slide-matching engine.
    pub fn new() -> Self {
        Self
    }

    /// Mirrors Python `slide_match`: locate where the slider piece fits on the
    /// background image.
    ///
    /// * `simple_target` — when `true`, match directly on the grayscale images;
    ///   when `false`, extract Canny edges first (more robust on noisy images).
    ///
    /// Returns the centre of the best match plus the normalized-correlation
    /// confidence, or an error if the underlying matching step fails.
    pub fn slide_match(
        &self,
        target_image: &DynamicImage,
        background_image: &DynamicImage,
        simple_target: bool,
    ) -> Result<SlideResult> {
        // Convert both images to HWC RGB ndarrays before matching.
        let target_array = image_to_ndarray(target_image);
        let background_array = image_to_ndarray(background_image);
        self.perform_slide_match(target_array.view(), background_array.view(), simple_target)
            .map_err(|e| anyhow!("滑块匹配失败: {}", e))
    }

    /// Mirrors Python `slide_comparison`: compare a background containing a
    /// notch against the original background and locate the difference spot.
    pub fn slide_comparison(
        &self,
        target_image: &DynamicImage,
        background_image: &DynamicImage,
    ) -> Result<SlideResult> {
        // 1. Convert to ndarray (HWC RGB).
        let target_array = image_to_ndarray(target_image);
        let background_array = image_to_ndarray(background_image);
        // 2. Run the comparison (mirrors `_perform_slide_comparison`).
        self.perform_slide_comparison(target_array.view(), background_array.view())
            .map_err(|e| anyhow!("滑块比较执行失败: {}", e))
    }

    /// Mirrors Python `_perform_slide_comparison`: diff the two images,
    /// binarize, denoise, then locate the largest connected white region
    /// (the notch) and return its centre. Both inputs must have the same
    /// HxWx3 shape.
    pub fn perform_slide_comparison(
        &self,
        target: ArrayView3<u8>,
        background: ArrayView3<u8>,
    ) -> Result<SlideResult> {
        let (h, w, _) = target.dim();

        // 1. Per-pixel absolute difference, then grayscale
        //    (cv2.absdiff + cv2.cvtColor; rgb_to_gray applies the standard
        //    0.299R + 0.587G + 0.114B weights).
        let diff_array = abs_diff(&target, &background);
        let gray_array = rgb_to_gray(diff_array.view());
        let gray_buffer = ndarray_to_luma8(gray_array.view());

        // 2. Binarize (cv2.threshold(..., 30, 255, cv2.THRESH_BINARY)).
        let binary = threshold(&gray_buffer, 30, ThresholdType::Binary);

        // 3. Morphological denoising (cv2.morphologyEx):
        //    close fills small black holes inside the notch,
        //    open removes isolated white speckles in the background.
        let norm = Norm::LInf; // LInf + radius 1 == a 3x3 rectangular kernel
        let radius = 1u8; // 1 => 3x3 neighbourhood, 2 => 5x5, ...
        let cleaned = open(&close(&binary, norm, radius), norm, radius);

        // 4. Largest connected component (findContours + max-area contour).
        //    connected_components labels each white region with a distinct id.
        let labelled = connected_components(&cleaned, Connectivity::Eight, Luma([0u8]));
        let mut areas = std::collections::HashMap::new();
        let mut max_label = 0;
        let mut max_area = 0;
        for pixel in labelled.pixels() {
            let label = pixel.0[0];
            if label == 0 {
                continue; // skip background
            }
            let count = areas.entry(label).or_insert(0);
            *count += 1;
            if *count > max_area {
                max_area = *count;
                max_label = label;
            }
        }
        if max_label == 0 {
            // Images are (effectively) identical: nothing to report.
            return Ok(SlideResult {
                target: [0, 0],
                target_x: 0,
                target_y: 0,
                confidence: 0.0,
            });
        }

        // 5. Bounding box of the largest region (cv2.boundingRect).
        let mut min_x = w as u32;
        let mut max_x = 0;
        let mut min_y = h as u32;
        let mut max_y = 0;
        for (x, y, pixel) in labelled.enumerate_pixels() {
            if pixel.0[0] == max_label {
                min_x = min(min_x, x);
                max_x = max(max_x, x);
                min_y = min(min_y, y);
                max_y = max(max_y, y);
            }
        }

        // 6. Centre of the box. cv2.boundingRect reports an INCLUSIVE width of
        //    (max - min + 1); add 1 before halving so the centre matches the
        //    Python `x + w // 2` exactly (the previous `max - min` was off by
        //    one pixel for even-width regions).
        let rect_w = max_x - min_x + 1;
        let rect_h = max_y - min_y + 1;
        let center_x = (min_x + rect_w / 2) as i32;
        let center_y = (min_y + rect_h / 2) as i32;
        Ok(SlideResult {
            target: [center_x, center_y],
            target_x: center_x,
            target_y: center_y,
            confidence: 1.0, // comparison mode: found == fully confident
        })
    }

    /// Mirrors Python `_perform_slide_match`: grayscale both inputs, then
    /// dispatch to the simple or edge-based matcher.
    fn perform_slide_match(
        &self,
        target: ArrayView3<u8>,
        background: ArrayView3<u8>,
        simple_target: bool,
    ) -> Result<SlideResult> {
        // 1. Grayscale both images once, up front.
        let target_gray = rgb_to_gray(target);
        let background_gray = rgb_to_gray(background);
        if simple_target {
            // 2a. Simple mode: match directly on the grayscale images.
            self.simple_template_match(target_gray.view(), background_gray.view())
        } else {
            // 2b. Edge mode: extract edges first, then match.
            self.edge_based_match(target_gray.view(), background_gray.view())
        }
    }

    /// Mirrors Python `_simple_template_match`: normalized cross-correlation
    /// template matching on the raw grayscale images — the equivalent of
    /// cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED) + cv2.minMaxLoc.
    fn simple_template_match(
        &self,
        target: ArrayView2<u8>,
        background: ArrayView2<u8>,
    ) -> Result<SlideResult> {
        // 1. Convert ndarrays to the ImageBuffer format imageproc expects.
        let t_buf = ndarray_to_luma8(target);
        let b_buf = ndarray_to_luma8(background);

        // 2. Slide the target over the background and score every position.
        let result = match_template(
            &b_buf,
            &t_buf,
            MatchTemplateMethod::CrossCorrelationNormalized,
        );

        // 3. Best-scoring top-left corner (cv2.minMaxLoc).
        let (max_val, max_loc) = min_max_loc(&result);

        // 4. Report the centre of the matched region (same as the Python code).
        let (th, tw) = target.dim();
        let center_x = max_loc.0 as i32 + (tw as i32 / 2);
        let center_y = max_loc.1 as i32 + (th as i32 / 2);
        Ok(SlideResult {
            target: [center_x, center_y],
            target_x: center_x,
            target_y: center_y,
            confidence: max_val as f64,
        })
    }

    /// Mirrors Python `_edge_based_match`: Canny edge extraction on both
    /// images followed by normalized cross-correlation template matching.
    /// More robust than the simple matcher when textures differ.
    pub fn edge_based_match(
        &self,
        target: ArrayView2<u8>,
        background: ArrayView2<u8>,
    ) -> Result<SlideResult> {
        // 1. Convert ndarrays to ImageBuffers for canny/match_template.
        let t_buf = ndarray_to_luma8(target);
        let b_buf = ndarray_to_luma8(background);

        // 2. Edge detection (cv2.Canny(50, 150)) — black background, white edges.
        let target_edges = canny(&t_buf, 50.0, 150.0);
        let background_edges = canny(&b_buf, 50.0, 150.0);

        // 3. Template matching on the edge maps
        //    (cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED)).
        let result = match_template(
            &background_edges,
            &target_edges,
            MatchTemplateMethod::CrossCorrelationNormalized,
        );

        // 4. Best match position (cv2.minMaxLoc).
        let (max_val, max_loc) = min_max_loc(&result);

        // 5. Centre of the matched region (aligned with the Python logic).
        let (th, tw) = target.dim();
        let center_x = max_loc.0 as i32 + (tw as i32 / 2);
        let center_y = max_loc.1 as i32 + (th as i32 / 2);
        Ok(SlideResult {
            target: [center_x, center_y],
            target_x: center_x,
            target_y: center_y,
            confidence: max_val as f64,
        })
    }
}