2 Commits

Author SHA1 Message Date
0923d92150 feat: 优化 slide.rs 2026-05-11 22:54:05 +08:00
0df9022411 feat: 优化 项目目录结构 2026-05-10 20:52:42 +08:00
14 changed files with 156 additions and 134 deletions

View File

@@ -1,14 +1,7 @@
pub mod base;
mod charset;
mod det_model;
mod image_io;
mod image_processor;
mod model;
mod model_loader;
mod ocr_model;
mod utils;
pub mod slide_model;
mod cv2;
pub mod models;
pub mod utils;
use anyhow::Result;
use image::DynamicImage;
@@ -16,9 +9,9 @@ use std::fmt::{Display, Formatter};
// 关键点:直接使用 tract 重导出的 ndarray
use crate::charset::get_default_charset;
use crate::det_model::Det;
use crate::model_loader::ModelSession;
use crate::ocr_model::Ocr;
use models::det::Det;
use models::loader::ModelSession;
use models::ocr::Ocr;
pub enum ModelSpec {
/// 默认 OCR (使用内置路径)
OcrModel,
@@ -31,7 +24,7 @@ pub enum ModelSpec {
}
impl ModelSpec {
// 将默认路径定义为内部关联常量
const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx";
const DEFAULT_OCR_PATH: &'static str = "models/common.onnx";
const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
}
pub enum Runtime {
@@ -74,7 +67,10 @@ impl DdddOcrBuilder {
/// 核心初始化逻辑
pub fn build(self) -> Result<DdddOcr> {
let runtime = match self.mode {
ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?),
ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(
ModelSpec::DEFAULT_OCR_PATH.into(),
get_default_charset(),
)?),
ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
};
@@ -110,7 +106,6 @@ impl DdddOcr {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ctc_decode_indices() {
// 模拟一个 DdddOcr 实例(如果 decode 不依赖 session可以设为相关函数

View File

View File

@@ -1,4 +1,4 @@
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use crate::models::loader::{ModelLoader, ModelSession, ModelType};
use anyhow::{Context, Result};
use image::{DynamicImage, GenericImageView, imageops::FilterType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, Axis, prelude::*, s};
@@ -219,8 +219,8 @@ impl Det {
}
/// 6. get_bbox (完全解耦 OpenCV)
pub fn get_bbox(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
// 使用 image crate 解码
let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode image")?;
// 使用 image crate 解码
let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode utils")?;
let (orig_w, orig_h) = dynamic_img.dimensions();
let (input_tensor, ratio) = self.preproc(&dynamic_img, (416, 416))?;

View File

@@ -3,8 +3,8 @@ use image::DynamicImage;
use tract_onnx::onnx;
use tract_onnx::prelude::*;
// 关键点:直接使用 tract 重导出的 ndarray
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use crate::utils::image_io::png_rgba_white_preprocess;
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
use std::collections::HashMap;
use tract_onnx::prelude::tract_ndarray::s;

5
src/models/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod base;
pub mod loader;
pub mod ocr;
pub mod det;
pub mod slide;

View File

@@ -1,7 +1,7 @@
use crate::base::ModelArgs;
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use crate::models::base::ModelArgs;
use crate::utils::image_io::png_rgba_white_preprocess;
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
use crate::models::loader::{ModelLoader, ModelSession, ModelType};
use anyhow::Context;
use image::DynamicImage;
use tract_onnx::prelude::tract_ndarray::s;

View File

@@ -1,15 +1,16 @@
use crate::cv2::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff};
use crate::image_io::image_to_ndarray;
use crate::utils::cv_ops;
use crate::utils::cv_ops::{abs_diff, min_max_loc, ndarray_to_luma8, rgb_to_gray};
use crate::utils::image_io::image_to_ndarray;
use anyhow::{Context, Result, anyhow};
use image::{DynamicImage, GenericImageView};
use image::{ImageBuffer, Luma};
use imageproc::contrast::{ThresholdType, threshold};
use imageproc::distance_transform::Norm;
use imageproc::edges::canny;
use imageproc::morphology::{close, open};
use imageproc::region_labelling::{Connectivity, connected_components};
use imageproc::template_matching::{MatchTemplateMethod, match_template};
use std::cmp::{max, min};
use imageproc::contrast::{threshold, ThresholdType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
pub struct SlideResult {
@@ -78,17 +79,12 @@ impl Slide {
// 1. 计算差异数组 (复用 cv2::absdiff)
let diff_array = abs_diff(&target, &background);
// 2. 转换为灰度数组 (复用你的 cv2::rgb_to_gray)
// 2. 转换为灰度数组 (复用你的 cv2.cvtColor)
let gray_array = rgb_to_gray(diff_array.view());
// 3. 转为 ImageBuffer 以使用 imageproc 的高级功能
let gray_buffer = ndarray_to_luma8(gray_array.view());
// 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY))
// let mut binary = ImageBuffer::new(w as u32, h as u32);
// for (x, y, pixel) in diff_buffer.enumerate_pixels() {
// let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 };
// binary.put_pixel(x, y, Luma([val]));
// }
let binary = threshold(&gray_buffer, 30, ThresholdType::Binary);
// 3. 形态学操作去噪 (对应 cv2.morphologyEx)
// 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞
@@ -98,65 +94,32 @@ impl Slide {
let closed = close(&binary, norm, radius);
let cleaned = open(&closed, norm, radius);
// 4. 寻找最大连通区域 (对应 findContours + max area)
// connected_components 会给每个独立的白色区域打上不同的标签 (ID)
let background_label = Luma([0u8]);
let labelled = connected_components(&cleaned, Connectivity::Eight, background_label);
// 统计每个标签出现的频率(即面积)
let mut max_label = 0;
let mut max_area = 0;
let mut areas = std::collections::HashMap::new();
// // 统计每个标签出现的频率(即面积)
// 4. 寻找最大连通区域 (对应 findContours + max area)
if let Some(max_label) = cv_ops::find_contours_and_max(&labelled) {
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
let (x, y, w, h) = cv_ops::bounding_rect(&labelled, max_label);
// 6. 计算中心点 (调用之前封装的 calculate_center)
let (center_x, center_y) = cv_ops::calculate_center((x, y), w as usize, h as usize);
for pixel in labelled.pixels() {
let label = pixel.0[0];
if label == 0 {
continue;
} // 跳过背景
let count = areas.entry(label).or_insert(0);
*count += 1;
if *count > max_area {
max_area = *count;
max_label = label;
}
}
if max_label == 0 {
return Ok(SlideResult {
Ok(SlideResult {
target: [center_x, center_y],
target_x: center_x,
target_y: center_y,
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
})
} else {
Ok(SlideResult {
target: [0, 0],
target_x: 0,
target_y: 0,
confidence: 0.0,
});
})
}
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
let mut min_x = w as u32;
let mut max_x = 0;
let mut min_y = h as u32;
let mut max_y = 0;
for (x, y, pixel) in labelled.enumerate_pixels() {
if pixel.0[0] == max_label {
min_x = min(min_x, x);
max_x = max(max_x, x);
min_y = min(min_y, y);
max_y = max(max_y, y);
}
}
// 6. 计算中心点
let rect_w = max_x - min_x;
let rect_h = max_y - min_y;
let center_x = (min_x + rect_w / 2) as i32;
let center_y = (min_y + rect_h / 2) as i32;
Ok(SlideResult {
target: [center_x, center_y],
target_x: center_x,
target_y: center_y,
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
})
}
/// 对应 Python: _perform_slide_match
@@ -210,8 +173,8 @@ impl Slide {
// 4. 计算中心点 (与 Python 逻辑完全一致)
let (th, tw) = target.dim();
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
let center_y = max_loc.1 as i32 + (th as i32 / 2);
let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize);
// println!("Rust Target Width (tw): {}", tw);
// println!("Rust Best Max Loc X: {}", max_loc.0);
// println!("Rust Final Center X: {}", center_x);
@@ -256,8 +219,7 @@ impl Slide {
// 5. 计算中心位置 (对齐 Python 逻辑)
// target_w, target_h 来自输入数组的维度
let (th, tw) = target.dim();
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
let center_y = max_loc.1 as i32 + (th as i32 / 2);
let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize);
// 打印调试信息,方便与 Python 对比
// println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc);
@@ -271,6 +233,4 @@ impl Slide {
confidence: max_val as f64,
})
}
}

View File

View File

@@ -1,3 +1,4 @@
use std::cmp::{max, min};
use image::{ImageBuffer, Luma};
use tract_onnx::prelude::tract_ndarray::{azip, Array2, Array3, ArrayView2, ArrayView3};
@@ -45,6 +46,55 @@ pub fn min_max_loc(result_map: &ImageBuffer<Luma<f32>, Vec<f32>>) -> (f32, (u32,
}
(max_val, max_loc)
}
/// Emulates cv2 `findContours` + max-area selection on a labelled image.
///
/// `labelled` is the output of `connected_components`: each pixel carries the
/// `u32` label of its component, with `0` reserved for background.
/// Returns the label of the largest (by pixel count) non-background component,
/// or `None` when the image contains only background.
pub fn find_contours_and_max(labelled: &ImageBuffer<Luma<u32>, Vec<u32>>) -> Option<u32> {
    // Running area tally per label, plus the current best (label, area) pair.
    let mut areas = std::collections::HashMap::new();
    let mut best: Option<(u32, u32)> = None;
    for px in labelled.pixels() {
        let label = px.0[0];
        if label == 0 {
            // Skip background pixels.
            continue;
        }
        let area = areas.entry(label).or_insert(0u32);
        *area += 1;
        // Promote this label only when its area strictly exceeds the current
        // best, matching the original scan-order tie-breaking.
        match best {
            Some((_, a)) if *area <= a => {}
            _ => best = Some((label, *area)),
        }
    }
    best.map(|(label, _)| label)
}
/// Bounding box of all pixels carrying `max_label` (cv2 `boundingRect` analogue).
///
/// Returns `(x, y, w, h)` where `(x, y)` is the top-left corner and
/// `w`/`h` span from the min to the max matching coordinate
/// (NOTE(review): this is `max - min`, one less than cv2's inclusive
/// width/height — kept as-is because `calculate_center` relies on it).
///
/// If `max_label` does not occur in the image at all, returns `(0, 0, 0, 0)`.
/// The previous implementation underflowed `max_x - min_x` in that case
/// (panic in debug builds, wrap-around in release builds).
pub fn bounding_rect(
    labelled: &ImageBuffer<Luma<u32>, Vec<u32>>,
    max_label: u32,
) -> (u32, u32, u32, u32) {
    let mut min_x = labelled.width();
    let mut max_x = 0;
    let mut min_y = labelled.height();
    let mut max_y = 0;
    let mut found = false;
    for (x, y, pixel) in labelled.enumerate_pixels() {
        if pixel.0[0] == max_label {
            found = true;
            min_x = min(min_x, x);
            max_x = max(max_x, x);
            min_y = min(min_y, y);
            max_y = max(max_y, y);
        }
    }
    if !found {
        // Guard against u32 underflow when the label is absent.
        return (0, 0, 0, 0);
    }
    let w = max_x - min_x;
    let h = max_y - min_y;
    (min_x, min_y, w, h)
}
/// Center point of a `tw` x `th` rectangle whose top-left corner is `max_loc`.
///
/// Mirrors the Python logic `center = loc + size // 2` using integer division,
/// so odd sizes round the half-extent down.
pub fn calculate_center(max_loc: (u32, u32), tw: usize, th: usize) -> (i32, i32) {
    let (left, top) = max_loc;
    let half_w = tw as i32 / 2;
    let half_h = th as i32 / 2;
    (left as i32 + half_w, top as i32 + half_h)
}
pub fn ndarray_to_luma8(array: ArrayView2<u8>) -> ImageBuffer<Luma<u8>, Vec<u8>> {
let (height, width) = array.dim();
let mut buffer = ImageBuffer::new(width as u32, height as u32);

View File

@@ -24,7 +24,7 @@ pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
match img_input {
// 2. 处理字节流 (Bytes)
ImageInput::Bytes(bytes) => {
image::load_from_memory(&bytes).context("Failed to load image from bytes")
image::load_from_memory(&bytes).context("Failed to load utils from bytes")
}
// 1. 已经是 DynamicImage
ImageInput::DynamicImage(i) => Ok(i),
@@ -34,11 +34,11 @@ pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
// 4. 处理 Base64 字符串
ImageInput::Base64(b) => base64_to_image(&b),
// 3. 处理文件路径 (Path)
ImageInput::Path(p) => image::open(p).context("Failed to open image from path"),
ImageInput::Path(p) => image::open(p).context("Failed to open utils from path"),
}
}
fn base64_to_image(b64_str: &str) -> Result<DynamicImage> {
// 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64,"
// 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64,"
let clean_b64 = if let Some(pos) = b64_str.find(",") {
&b64_str[pos + 1..]
} else {
@@ -49,7 +49,7 @@ fn base64_to_image(b64_str: &str) -> Result<DynamicImage> {
.decode(clean_b64.trim())
.map_err(|e| anyhow!("Base64 decode error: {}", e))?;
image::load_from_memory(&bytes).context("Failed to load image from decoded base64")
image::load_from_memory(&bytes).context("Failed to load utils from decoded base64")
}
/// 读取图片文件并转换为 base64 编码字符串
@@ -58,7 +58,7 @@ pub fn get_img_base64<P: AsRef<Path>>(image_path: P) -> Result<String> {
// 1. 读取文件原始字节流
// 使用 AsRef<Path> 泛型可以让函数同时支持 String, &str, PathBuf 等类型
let image_data = fs::read(&image_path)
.with_context(|| format!("Failed to read image file: {:?}", image_path.as_ref()))?;
.with_context(|| format!("Failed to read utils file: {:?}", image_path.as_ref()))?;
// 2. 进行 Base64 编码
// 使用 STANDARD 引擎对齐 Python 的 base64.b64encode
@@ -83,7 +83,7 @@ fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
let (h, w) = (shape[0], shape[1]);
ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
.map(DynamicImage::ImageLuma8)
.ok_or_else(|| anyhow!("Failed to create Luma image from 2D array"))
.ok_or_else(|| anyhow!("Failed to create Luma utils from 2D array"))
}
// 对应 Python: len(array.shape) == 3 (H, W, C)
@@ -131,7 +131,7 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
let rgba_img = img.to_rgba8();
// 4. 遍历像素并手动进行 Alpha 混合
// 对应 Python 的 image.paste(img, ..., mask=img)
// 对应 Python 的 image.paste(img, ..., mask=img)
// 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销
for (x, y, bg_pixel) in background.enumerate_pixels_mut() {
// 安全性说明x, y 源自 background 尺寸,与 rgba_img 一致get_pixel 是安全的
@@ -162,8 +162,8 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
DynamicImage::ImageRgb8(background)
}
pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result<Array3<u8>> {
// 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray
// Rust image 库通过 to_rgb8, to_luma8 等方法实现转换
// 1. 模式转换 (对应 image.convert(target_mode)),此函数暂时保留,看后续优化是否需要替代 image_to_ndarray
// Rust image 库通过 to_rgb8, to_luma8 等方法实现转换
let (width, height) = image.dimensions();
let (channels, raw) = match mode {
@@ -225,7 +225,7 @@ pub fn image_to_ndarray(img: &DynamicImage) -> Array3<u8> {
// 3. 构造数组 (通道数改为 3)
Array3::from_shape_vec((height as usize, width as usize, 3), raw_data)
.expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图
.expect("Failed to construct ndarray from utils") // 建议显式报错,而不是返回全黑图
}
#[allow(dead_code)]

View File

@@ -4,7 +4,7 @@ use anyhow::Result;
/// 对应 Python 的 convert_to_grayscale
/// 将图像转换为灰度图 (L模式)
pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage {
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
image.to_luma8()
}

3
src/utils/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod image_io;
pub mod image_processor;
pub mod cv_ops;

View File

@@ -1,23 +1,24 @@
use ddddocr_rs::models::slide::Slide;
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
use image::Rgb;
use std::fs;
use std::path::Path;
use image::Rgb;
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
use ddddocr_rs::slide_model::Slide;
fn load_image<P: AsRef<Path>>(path: P) -> anyhow::Result<image::DynamicImage> {
// 1. 先将泛型转为具体的 &Path 引用
let path_ref = path.as_ref();
// 2. 调用 open 时传入引用(image::open 支持 AsRef<Path>
image::open(path_ref)
.map_err(|e| {
// 3. 此时 path_ref 依然有效,可以安全地在闭包中使用
anyhow::anyhow!("无法加载图片 {:?}: {}", path_ref, e)
})
// 2. 调用 open 时传入引用(image::open 支持 AsRef<Path>
image::open(path_ref).map_err(|e| {
// 3. 此时 path_ref 依然有效,可以安全地在闭包中使用
anyhow::anyhow!("无法加载图片 {:?}: {}", path_ref, e)
})
}
/// 将检测结果绘制在图像上并保存
fn save_debug_image( image_bytes: &[u8], bboxes: &Vec<Vec<i32>>, output_path: &str) -> anyhow::Result<()> {
fn save_debug_image(
image_bytes: &[u8],
bboxes: &Vec<Vec<i32>>,
output_path: &str,
) -> anyhow::Result<()> {
let dynamic_img = image::load_from_memory(image_bytes)?;
let mut img = dynamic_img.to_rgb8();
let (width, height) = img.dimensions();
@@ -35,16 +36,24 @@ fn save_debug_image( image_bytes: &[u8], bboxes: &Vec<Vec<i32>>, output_path: &s
img.put_pixel(x, y1, red);
img.put_pixel(x, y2, red);
// 如果要加粗,多画一行
if y1 + 1 < height { img.put_pixel(x, y1 + 1, red); }
if y2.saturating_sub(1) > 0 { img.put_pixel(x, y2 - 1, red); }
if y1 + 1 < height {
img.put_pixel(x, y1 + 1, red);
}
if y2.saturating_sub(1) > 0 {
img.put_pixel(x, y2 - 1, red);
}
}
// 绘制纵向线条
for y in y1..=y2 {
img.put_pixel(x1, y, red);
img.put_pixel(x2, y, red);
// 如果要加粗,多画一列
if x1 + 1 < width { img.put_pixel(x1 + 1, y, red); }
if x2.saturating_sub(1) > 0 { img.put_pixel(x2 - 1, y, red); }
if x1 + 1 < width {
img.put_pixel(x1 + 1, y, red);
}
if x2.saturating_sub(1) > 0 {
img.put_pixel(x2 - 1, y, red);
}
}
}
@@ -66,43 +75,44 @@ fn test_full_classification() {
assert!(!result.is_empty());
}
#[test]
fn test_det_load()->anyhow::Result<()>{
fn test_det_load() -> anyhow::Result<()> {
let det = DdddOcrBuilder::new().det().build()?;
let image_path = "samples/det1.png";
let image_bytes = fs::read(image_path)
.map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?;
let image_bytes =
fs::read(image_path).map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?;
println!("图片读取成功,字节大小: {}", image_bytes.len());
let bboxes =det.detection(&image_bytes)?;
println!(":?{}",det);
let bboxes = det.detection(&image_bytes)?;
println!(":?{}", det);
println!("检测到的目标数量: {}", bboxes.len());
if bboxes.is_empty() {
println!("未检测到任何目标。");
} else {
save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?;
for (i, bbox) in bboxes.iter().enumerate() {
println!("目标 [{}]: x1={}, y1={}, x2={}, y2={}", i, bbox[0], bbox[1], bbox[2], bbox[3]);
println!(
"目标 [{}]: x1={}, y1={}, x2={}, y2={}",
i, bbox[0], bbox[1], bbox[2], bbox[3]
);
}
}
Ok(())
}
#[test]
fn test_real_slide_match() {
let engine = Slide::new();
// 1. 加载你准备好的测试图
// 假设图片放在项目根目录下的 assets 文件夹
let target_img = load_image("samples/hua.png")
.expect("请确保 samples/hua.png 存在");
let bg_img = load_image("samples/huatu.png")
.expect("请确保 samples/huatu.png 存在");
let target_img = load_image("samples/hua.png").expect("请确保 samples/hua.png 存在");
let bg_img = load_image("samples/huatu.png").expect("请确保 samples/huatu.png 存在");
// 2. 执行匹配
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
let start = std::time::Instant::now();
let result = engine.slide_match(&target_img, &bg_img, true)
let result = engine
.slide_match(&target_img, &bg_img, false)
.expect("Slide match 执行失败");
let duration = start.elapsed();
@@ -126,15 +136,14 @@ fn test_real_slide_comparison() {
// 1. 加载你准备好的测试图
// 假设图片放在项目根目录下的 assets 文件夹
let target_img = load_image("samples/ken.jpg")
.expect("请确保 samples/ken.jpg 存在");
let bg_img = load_image("samples/kenyuan.jpg")
.expect("请确保 samples/kenyuan.jpg 存在");
let target_img = load_image("samples/ken.jpg").expect("请确保 samples/ken.jpg 存在");
let bg_img = load_image("samples/kenyuan.jpg").expect("请确保 samples/kenyuan.jpg 存在");
// 2. 执行匹配
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
let start = std::time::Instant::now();
let result = engine.slide_comparison(&target_img, &bg_img)
let result = engine
.slide_comparison(&target_img, &bg_img)
.expect("Slide match 执行失败");
let duration = start.elapsed();