diff --git a/src/models/slide.rs b/src/models/slide.rs index 9afc854..f4f91b9 100644 --- a/src/models/slide.rs +++ b/src/models/slide.rs @@ -1,15 +1,16 @@ -use crate::utils::cv_ops::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff}; +use crate::utils::cv_ops; +use crate::utils::cv_ops::{abs_diff, min_max_loc, ndarray_to_luma8, rgb_to_gray}; use crate::utils::image_io::image_to_ndarray; use anyhow::{Context, Result, anyhow}; use image::{DynamicImage, GenericImageView}; use image::{ImageBuffer, Luma}; +use imageproc::contrast::{ThresholdType, threshold}; use imageproc::distance_transform::Norm; use imageproc::edges::canny; use imageproc::morphology::{close, open}; use imageproc::region_labelling::{Connectivity, connected_components}; use imageproc::template_matching::{MatchTemplateMethod, match_template}; use std::cmp::{max, min}; -use imageproc::contrast::{threshold, ThresholdType}; use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s}; pub struct SlideResult { @@ -78,17 +79,12 @@ impl Slide { // 1. 计算差异数组 (复用 cv2::absdiff) let diff_array = abs_diff(&target, &background); - // 2. 转换为灰度数组 (复用你的 cv2::rgb_to_gray) + // 2. 转换为灰度数组 (复用你的 cv2.cvtColor) let gray_array = rgb_to_gray(diff_array.view()); // 3. 转为 ImageBuffer 以使用 imageproc 的高级功能 let gray_buffer = ndarray_to_luma8(gray_array.view()); // 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY)) - // let mut binary = ImageBuffer::new(w as u32, h as u32); - // for (x, y, pixel) in diff_buffer.enumerate_pixels() { - // let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 }; - // binary.put_pixel(x, y, Luma([val])); - // } let binary = threshold(&gray_buffer, 30, ThresholdType::Binary); // 3. 形态学操作去噪 (对应 cv2.morphologyEx) // 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞 @@ -98,65 +94,32 @@ impl Slide { let closed = close(&binary, norm, radius); let cleaned = open(&closed, norm, radius); - // 4. 寻找最大连通区域 (对应 findContours + max area) // connected_components 会给每个独立的白色区域打上不同的标签 (ID) let background_label = Luma([0u8]); let labelled = connected_components(&cleaned, Connectivity::Eight, background_label); - // 统计每个标签出现的频率(即面积) - let mut max_label = 0; - let mut max_area = 0; - let mut areas = std::collections::HashMap::new(); + // // 统计每个标签出现的频率(即面积) + // 4. 寻找最大连通区域 (对应 findContours + max area) + if let Some(max_label) = cv_ops::find_contours_and_max(&labelled) { + // 5. 计算最大区域的边界框 (对应 cv2.boundingRect) + let (x, y, w, h) = cv_ops::bounding_rect(&labelled, max_label); + // 6. 计算中心点 (调用之前封装的 calculate_center) + let (center_x, center_y) = cv_ops::calculate_center((x, y), w as usize, h as usize); - for pixel in labelled.pixels() { - let label = pixel.0[0]; - if label == 0 { - continue; - } // 跳过背景 - let count = areas.entry(label).or_insert(0); - *count += 1; - if *count > max_area { - max_area = *count; - max_label = label; - } - } - - if max_label == 0 { - return Ok(SlideResult { + Ok(SlideResult { + target: [center_x, center_y], + target_x: center_x, + target_y: center_y, + confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0 + }) + } else { + Ok(SlideResult { target: [0, 0], target_x: 0, target_y: 0, confidence: 0.0, - }); + }) } - - // 5. 计算最大区域的边界框 (对应 cv2.boundingRect) - let mut min_x = w as u32; - let mut max_x = 0; - let mut min_y = h as u32; - let mut max_y = 0; - - for (x, y, pixel) in labelled.enumerate_pixels() { - if pixel.0[0] == max_label { - min_x = min(min_x, x); - max_x = max(max_x, x); - min_y = min(min_y, y); - max_y = max(max_y, y); - } - } - - // 6. 计算中心点 - let rect_w = max_x - min_x; - let rect_h = max_y - min_y; - let center_x = (min_x + rect_w / 2) as i32; - let center_y = (min_y + rect_h / 2) as i32; - - Ok(SlideResult { - target: [center_x, center_y], - target_x: center_x, - target_y: center_y, - confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0 - }) } /// 对应 Python: _perform_slide_match @@ -210,8 +173,8 @@ impl Slide { // 4. 计算中心点 (与 Python 逻辑完全一致) let (th, tw) = target.dim(); - let center_x = max_loc.0 as i32 + (tw as i32 / 2); - let center_y = max_loc.1 as i32 + (th as i32 / 2); + + let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize); // println!("Rust Target Width (tw): {}", tw); // println!("Rust Best Max Loc X: {}", max_loc.0); // println!("Rust Final Center X: {}", center_x); @@ -256,8 +219,7 @@ impl Slide { // 5. 计算中心位置 (对齐 Python 逻辑) // target_w, target_h 来自输入数组的维度 let (th, tw) = target.dim(); - let center_x = max_loc.0 as i32 + (tw as i32 / 2); - let center_y = max_loc.1 as i32 + (th as i32 / 2); + let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize); // 打印调试信息,方便与 Python 对比 // println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc); @@ -271,6 +233,4 @@ impl Slide { confidence: max_val as f64, }) } - - } diff --git a/src/utils/cv_ops.rs b/src/utils/cv_ops.rs index 15aa3d9..7e22c55 100644 --- a/src/utils/cv_ops.rs +++ b/src/utils/cv_ops.rs @@ -1,3 +1,4 @@ +use std::cmp::{max, min}; use image::{ImageBuffer, Luma}; use tract_onnx::prelude::tract_ndarray::{azip, Array2, Array3, ArrayView2, ArrayView3}; @@ -45,6 +46,55 @@ pub fn min_max_loc(result_map: &ImageBuffer, Vec>) -> (f32, (u32, } (max_val, max_loc) } + +/// 1. 模拟 findContours 并获取最大面积区域的 Label +/// 返回 Option,如果找不到任何区域则返回 None +pub fn find_contours_and_max(labelled: &ImageBuffer, Vec>) -> Option { + // 统计每个标签出现的频率(即面积) + let mut max_label = 0; + let mut max_area = 0; + let mut areas = std::collections::HashMap::new(); + + for pixel in labelled.pixels() { + let label = pixel.0[0]; + if label == 0 { + continue; + } // 跳过背景 + let count = areas.entry(label).or_insert(0); + *count += 1; + if *count > max_area { + max_area = *count; + max_label = label; + } + } + if max_label == 0 { None } else { Some(max_label) } +} +pub fn bounding_rect(labelled: &ImageBuffer, Vec>,max_label: u32) -> (u32, u32, u32, u32) { + // 5. 计算最大区域的边界框 (对应 cv2.boundingRect) + let mut min_x = labelled.width(); + let mut max_x = 0; + let mut min_y = labelled.height(); + let mut max_y = 0; + + for (x, y, pixel) in labelled.enumerate_pixels() { + if pixel.0[0] == max_label { + min_x = min(min_x, x); + max_x = max(max_x, x); + min_y = min(min_y, y); + max_y = max(max_y, y); + } + } + + + let w = max_x - min_x; + let h = max_y - min_y; + (min_x, min_y, w, h) +} +pub fn calculate_center(max_loc: (u32, u32), tw: usize, th: usize) -> (i32, i32) { + let center_x = max_loc.0 as i32 + (tw as i32 / 2); + let center_y = max_loc.1 as i32 + (th as i32 / 2); + (center_x, center_y) +} pub fn ndarray_to_luma8(array: ArrayView2) -> ImageBuffer, Vec> { let (height, width) = array.dim(); let mut buffer = ImageBuffer::new(width as u32, height as u32); diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index 9e61d48..928eee6 100644 --- a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -1,23 +1,24 @@ +use ddddocr_rs::models::slide::Slide; +use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 +use image::Rgb; use std::fs; use std::path::Path; -use image::Rgb; -use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 -use ddddocr_rs::models::slide::Slide; fn load_image>(path: P) -> anyhow::Result { // 1. 先将泛型转为具体的 &Path 引用 let path_ref = path.as_ref(); // 2. 调用 open 时传入引用(utils::open 支持 AsRef) - image::open(path_ref) - .map_err(|e| { - // 3. 此时 path_ref 依然有效,可以安全地在闭包中使用 - anyhow::anyhow!("无法加载图片 {:?}: {}", path_ref, e) - }) + image::open(path_ref).map_err(|e| { + // 3. 此时 path_ref 依然有效,可以安全地在闭包中使用 + anyhow::anyhow!("无法加载图片 {:?}: {}", path_ref, e) + }) } /// 将检测结果绘制在图像上并保存 -fn save_debug_image( image_bytes: &[u8], bboxes: &Vec>, output_path: &str) -> anyhow::Result<()> { - - +fn save_debug_image( + image_bytes: &[u8], + bboxes: &Vec>, + output_path: &str, +) -> anyhow::Result<()> { let dynamic_img = image::load_from_memory(image_bytes)?; let mut img = dynamic_img.to_rgb8(); let (width, height) = img.dimensions(); @@ -35,16 +36,24 @@ fn save_debug_image( image_bytes: &[u8], bboxes: &Vec>, output_path: &s img.put_pixel(x, y1, red); img.put_pixel(x, y2, red); // 如果要加粗,多画一行 - if y1 + 1 < height { img.put_pixel(x, y1 + 1, red); } - if y2.saturating_sub(1) > 0 { img.put_pixel(x, y2 - 1, red); } + if y1 + 1 < height { + img.put_pixel(x, y1 + 1, red); + } + if y2.saturating_sub(1) > 0 { + img.put_pixel(x, y2 - 1, red); + } } // 绘制纵向线条 for y in y1..=y2 { img.put_pixel(x1, y, red); img.put_pixel(x2, y, red); // 如果要加粗,多画一列 - if x1 + 1 < width { img.put_pixel(x1 + 1, y, red); } - if x2.saturating_sub(1) > 0 { img.put_pixel(x2 - 1, y, red); } + if x1 + 1 < width { + img.put_pixel(x1 + 1, y, red); + } + if x2.saturating_sub(1) > 0 { + img.put_pixel(x2 - 1, y, red); + } } } @@ -66,43 +75,44 @@ fn test_full_classification() { assert!(!result.is_empty()); } #[test] -fn test_det_load()->anyhow::Result<()>{ +fn test_det_load() -> anyhow::Result<()> { let det = DdddOcrBuilder::new().det().build()?; let image_path = "samples/det1.png"; - let image_bytes = fs::read(image_path) - .map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?; + let image_bytes = + fs::read(image_path).map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?; println!("图片读取成功,字节大小: {}", image_bytes.len()); - let bboxes =det.detection(&image_bytes)?; - println!(":?{}",det); + let bboxes = det.detection(&image_bytes)?; + println!(":?{}", det); println!("检测到的目标数量: {}", bboxes.len()); if bboxes.is_empty() { println!("未检测到任何目标。"); } else { save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?; for (i, bbox) in bboxes.iter().enumerate() { - println!("目标 [{}]: x1={}, y1={}, x2={}, y2={}", i, bbox[0], bbox[1], bbox[2], bbox[3]); + println!( + "目标 [{}]: x1={}, y1={}, x2={}, y2={}", + i, bbox[0], bbox[1], bbox[2], bbox[3] + ); } } Ok(()) } - #[test] fn test_real_slide_match() { let engine = Slide::new(); // 1. 加载你准备好的测试图 // 假设图片放在项目根目录下的 assets 文件夹 - let target_img = load_image("samples/hua.png") - .expect("请确保 samples/hua.png 存在"); - let bg_img = load_image("samples/huatu.png") - .expect("请确保 samples/huatu.png 存在"); + let target_img = load_image("samples/hua.png").expect("请确保 samples/hua.png 存在"); + let bg_img = load_image("samples/huatu.png").expect("请确保 samples/huatu.png 存在"); // 2. 执行匹配 // 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false let start = std::time::Instant::now(); - let result = engine.slide_match(&target_img, &bg_img, false) + let result = engine + .slide_match(&target_img, &bg_img, false) .expect("Slide match 执行失败"); let duration = start.elapsed(); @@ -126,15 +136,14 @@ fn test_real_slide_comparison() { // 1. 加载你准备好的测试图 // 假设图片放在项目根目录下的 assets 文件夹 - let target_img = load_image("samples/ken.jpg") - .expect("请确保 samples/ken.jpg 存在"); - let bg_img = load_image("samples/kenyuan.jpg") - .expect("请确保 samples/kenyuan.jpg 存在"); + let target_img = load_image("samples/ken.jpg").expect("请确保 samples/ken.jpg 存在"); + let bg_img = load_image("samples/kenyuan.jpg").expect("请确保 samples/kenyuan.jpg 存在"); // 2. 执行匹配 // 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false let start = std::time::Instant::now(); - let result = engine.slide_comparison(&target_img, &bg_img) + let result = engine + .slide_comparison(&target_img, &bg_img) .expect("Slide match 执行失败"); let duration = start.elapsed(); @@ -150,4 +159,4 @@ fn test_real_slide_comparison() { assert_eq!(result.target_x, 171); assert_eq!(result.target_y, 90); assert!(result.confidence > 0.0); -} \ No newline at end of file +}