291 lines
11 KiB
Rust
291 lines
11 KiB
Rust
use anyhow::{Context, Result, anyhow};
|
||
use image::{DynamicImage, GenericImageView};
|
||
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
|
||
use imageproc::template_matching::{match_template, MatchTemplateMethod};
|
||
use image::{ImageBuffer, Luma};
|
||
use crate::image_io::image_to_ndarray;
|
||
use crate::cv2::rgb_to_gray;
|
||
use imageproc::edges::canny;
|
||
use imageproc::distance_transform::Norm;
|
||
use imageproc::morphology::{close, open};
|
||
use imageproc::region_labelling::{connected_components, Connectivity};
|
||
use std::cmp::{max, min};
|
||
|
||
/// Outcome of a slider-captcha localisation.
///
/// Every constructor in this file stores the same point twice: once as the
/// `target` pair and once as the separate `target_x` / `target_y` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct SlideResult {
    /// Located point as `[x, y]`.
    pub target: [i32; 2],
    /// X coordinate of the located point.
    pub target_x: i32,
    /// Y coordinate of the located point.
    pub target_y: i32,
    /// Score attached to the result by the routine that produced it
    /// (a template-matching score, or a fixed 1.0 / 0.0 in comparison mode).
    pub confidence: f64,
}
|
||
|
||
/// Stateless handle that groups the slider-captcha matching routines.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Slide;
|
||
|
||
impl Slide {
|
||
pub fn new() -> Self {
|
||
Self
|
||
}
|
||
|
||
/// 对应 Python: slide_match
|
||
pub fn slide_match(
|
||
&self,
|
||
target_pil: &DynamicImage,
|
||
background_pil: &DynamicImage,
|
||
simple_target: bool,
|
||
) -> Result<SlideResult> {
|
||
let target_array = image_to_ndarray(target_pil);
|
||
let background_array = image_to_ndarray(background_pil);
|
||
|
||
self.perform_slide_match(target_array.view(), background_array.view(),simple_target)
|
||
.map_err(|e| anyhow!("滑块匹配失败: {}", e))
|
||
}
|
||
/// 对应 Python: slide_comparison
|
||
/// 用于比较带坑位的图片与原始背景图,定位差异点
|
||
pub fn slide_comparison(
|
||
&self,
|
||
target_pil: &DynamicImage,
|
||
background_pil: &DynamicImage,
|
||
) -> Result<SlideResult> {
|
||
// 1. 转换为 ndarray (HWC RGB)
|
||
let target_array = image_to_ndarray(target_pil);
|
||
let background_array = image_to_ndarray(background_pil);
|
||
|
||
// 2. 执行比较逻辑 (对应 _perform_slide_comparison)
|
||
self.perform_slide_comparison(target_array.view(), background_array.view())
|
||
.map_err(|e| anyhow!("滑块比较执行失败: {}", e))
|
||
}
|
||
/// 对应 Python: _perform_slide_comparison
|
||
pub fn perform_slide_comparison(
|
||
&self,
|
||
target: ArrayView3<u8>,
|
||
background: ArrayView3<u8>,
|
||
) -> Result<SlideResult> {
|
||
let (h, w, _) = target.dim();
|
||
|
||
// 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor)
|
||
// 使用 OpenCV 标准权重公式:0.299R + 0.587G + 0.114B
|
||
let mut diff_buffer = ImageBuffer::new(w as u32, h as u32);
|
||
for y in 0..h {
|
||
for x in 0..w {
|
||
let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32;
|
||
let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32;
|
||
let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32;
|
||
|
||
let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8;
|
||
diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff]));
|
||
}
|
||
}
|
||
|
||
// 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY))
|
||
let mut binary = ImageBuffer::new(w as u32, h as u32);
|
||
for (x, y, pixel) in diff_buffer.enumerate_pixels() {
|
||
let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 };
|
||
binary.put_pixel(x, y, Luma([val]));
|
||
}
|
||
|
||
// 3. 形态学操作去噪 (对应 cv2.morphologyEx)
|
||
// 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞
|
||
// 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点
|
||
let norm = Norm::LInf; // 对应 3x3 的矩形内核
|
||
let radius = 1u8; // 1 表示 3x3 的范围,2 表示 5x5 的范围
|
||
let closed = close(&binary, norm, radius);
|
||
let cleaned = open(&closed, norm, radius);
|
||
|
||
// 4. 寻找最大连通区域 (对应 findContours + max area)
|
||
// connected_components 会给每个独立的白色区域打上不同的标签 (ID)
|
||
let background_label = Luma([0u8]);
|
||
let labelled = connected_components(&cleaned, Connectivity::Eight, background_label);
|
||
|
||
// 统计每个标签出现的频率(即面积)
|
||
let mut max_label = 0;
|
||
let mut max_area = 0;
|
||
let mut areas = std::collections::HashMap::new();
|
||
|
||
for pixel in labelled.pixels() {
|
||
let label = pixel.0[0];
|
||
if label == 0 { continue; } // 跳过背景
|
||
let count = areas.entry(label).or_insert(0);
|
||
*count += 1;
|
||
if *count > max_area {
|
||
max_area = *count;
|
||
max_label = label;
|
||
}
|
||
}
|
||
|
||
if max_label == 0 {
|
||
return Ok(SlideResult { target: [0, 0], target_x: 0, target_y: 0, confidence: 0.0 });
|
||
}
|
||
|
||
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
|
||
let mut min_x = w as u32;
|
||
let mut max_x = 0;
|
||
let mut min_y = h as u32;
|
||
let mut max_y = 0;
|
||
|
||
for (x, y, pixel) in labelled.enumerate_pixels() {
|
||
if pixel.0[0] == max_label {
|
||
min_x = min(min_x, x);
|
||
max_x = max(max_x, x);
|
||
min_y = min(min_y, y);
|
||
max_y = max(max_y, y);
|
||
}
|
||
}
|
||
|
||
// 6. 计算中心点
|
||
let rect_w = max_x - min_x;
|
||
let rect_h = max_y - min_y;
|
||
let center_x = (min_x + rect_w / 2) as i32;
|
||
let center_y = (min_y + rect_h / 2) as i32;
|
||
|
||
Ok(SlideResult {
|
||
target: [center_x, center_y],
|
||
target_x: center_x,
|
||
target_y: center_y,
|
||
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
|
||
})
|
||
}
|
||
|
||
/// 对应 Python: _perform_slide_match
|
||
// 在 SlideEngine 中修改此入口进行测试
|
||
fn perform_slide_match(
|
||
&self,
|
||
target: ArrayView3<u8>,
|
||
background: ArrayView3<u8>,
|
||
simple_target: bool, // 增加这个参数
|
||
) -> Result<SlideResult> {
|
||
// 1. 统一灰度化
|
||
let target_gray = rgb_to_gray(target);
|
||
let background_gray = rgb_to_gray(background);
|
||
|
||
if simple_target {
|
||
// 2a. 简单模式:直接在灰度图上匹配
|
||
self.simple_template_match(target_gray.view(), background_gray.view())
|
||
} else {
|
||
// 2b. 复杂模式:先提取边缘,再匹配
|
||
|
||
self.edge_based_match(target_gray.view(), background_gray.view())
|
||
}
|
||
}
|
||
/// 对应 Python: _simple_template_match
|
||
/// 使用 SAD (Sum of Absolute Differences) 算法
|
||
/// 核心模板匹配:SAD + 有效像素过滤
|
||
fn simple_template_match(
|
||
&self,
|
||
target: ArrayView2<u8>,
|
||
background: ArrayView2<u8>,
|
||
) -> Result<SlideResult> {
|
||
// 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换)
|
||
let (th, tw) = target.dim();
|
||
let (bh, bw) = background.dim();
|
||
|
||
// 转换逻辑 (假设你已经有方法转回 ImageBuffer)
|
||
let t_buf = self.ndarray_to_luma8(target);
|
||
let b_buf = self.ndarray_to_luma8(background);
|
||
t_buf.save("debug_rust_target.png").unwrap();
|
||
|
||
|
||
// 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED)
|
||
let result = match_template(&b_buf, &t_buf, MatchTemplateMethod::CrossCorrelationNormalized);
|
||
// save_rust_result(&result, "debug_rust_target2.png");
|
||
// 3. 寻找最大值 (等价于 cv2.minMaxLoc)
|
||
let mut max_val: f32 = -1.0;
|
||
let mut max_loc = (0, 0);
|
||
|
||
for (x, y, score) in result.enumerate_pixels() {
|
||
let s = score.0[0];
|
||
// 这里的 x, y 是左上角坐标
|
||
if s > max_val {
|
||
max_val = s;
|
||
max_loc = (x, y);
|
||
}
|
||
}
|
||
|
||
// 4. 计算中心点 (与 Python 逻辑完全一致)
|
||
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
|
||
let center_y = max_loc.1 as i32 + (th as i32 / 2);
|
||
// println!("Rust Target Width (tw): {}", tw);
|
||
// println!("Rust Best Max Loc X: {}", max_loc.0);
|
||
// println!("Rust Final Center X: {}", center_x);
|
||
Ok(SlideResult {
|
||
target: [center_x, center_y],
|
||
target_x: center_x,
|
||
target_y: center_y,
|
||
confidence: max_val as f64,
|
||
})
|
||
}
|
||
|
||
fn ndarray_to_luma8(&self, array: ArrayView2<u8>) -> ImageBuffer<Luma<u8>, Vec<u8>> {
|
||
let (height, width) = array.dim();
|
||
let mut buffer = ImageBuffer::new(width as u32, height as u32);
|
||
for y in 0..height {
|
||
for x in 0..width {
|
||
buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]]));
|
||
}
|
||
}
|
||
buffer
|
||
}
|
||
/// 对应 Python: _edge_based_match
|
||
/// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match)
|
||
pub fn edge_based_match(
|
||
&self,
|
||
target: ArrayView2<u8>,
|
||
background: ArrayView2<u8>,
|
||
) -> Result<SlideResult> {
|
||
// 1. 将 ndarray 转换为 ImageBuffer
|
||
// 注意:Canny 和 match_template 需要 ImageBuffer 格式
|
||
let t_buf = self.ndarray_to_luma8(target);
|
||
let b_buf = self.ndarray_to_luma8(background);
|
||
|
||
// 2. 边缘检测 (完全对齐 cv2.Canny(50, 150))
|
||
// 这步会生成黑底白线的二值化边缘图
|
||
let target_edges = canny(&t_buf, 50.0, 150.0);
|
||
let background_edges = canny(&b_buf, 50.0, 150.0);
|
||
|
||
// target_edges.save("debug_target_edges.png").ok();
|
||
// background_edges.save("debug_bg_edges.png").ok();
|
||
|
||
// 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
|
||
// 在边缘图上计算归一化互相关系数
|
||
let result_map = match_template(
|
||
&background_edges,
|
||
&target_edges,
|
||
MatchTemplateMethod::CrossCorrelationNormalized
|
||
);
|
||
|
||
// 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc)
|
||
let mut max_val: f32 = -1.0;
|
||
let mut max_loc = (0, 0);
|
||
|
||
// 遍历匹配得分图
|
||
for (x, y, score) in result_map.enumerate_pixels() {
|
||
let s = score.0[0];
|
||
|
||
// 可以在此处加入你之前验证过的起始位过滤
|
||
// if x < 15 { continue; }
|
||
|
||
if s > max_val {
|
||
max_val = s;
|
||
max_loc = (x, y);
|
||
}
|
||
}
|
||
|
||
// 5. 计算中心位置 (对齐 Python 逻辑)
|
||
// target_w, target_h 来自输入数组的维度
|
||
let (th, tw) = target.dim();
|
||
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
|
||
let center_y = max_loc.1 as i32 + (th as i32 / 2);
|
||
|
||
// 打印调试信息,方便与 Python 对比
|
||
// println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc);
|
||
println!("-Rust Target Width (tw): {}", tw);
|
||
println!("-Rust Best Max Loc X: {}", max_loc.0);
|
||
println!("-Rust Final Center X: {}", center_x);
|
||
Ok(SlideResult {
|
||
target: [center_x, center_y],
|
||
target_x: center_x,
|
||
target_y: center_y,
|
||
confidence: max_val as f64,
|
||
})
|
||
}
|
||
|
||
}
|