From 0df90224115fa1b188fc71c71ee8bd63eac71e2f Mon Sep 17 00:00:00 2001 From: CNWei Date: Sun, 10 May 2026 20:52:42 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E7=9B=AE=E5=BD=95=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/lib.rs | 27 +++++++++-------------- src/model.rs | 0 src/{ => models}/base.rs | 0 src/{det_model.rs => models/det.rs} | 6 ++--- src/{model_loader.rs => models/loader.rs} | 4 ++-- src/models/mod.rs | 5 +++++ src/{ocr_model.rs => models/ocr.rs} | 8 +++---- src/{slide_model.rs => models/slide.rs} | 4 ++-- src/utils.rs | 0 src/{cv2.rs => utils/cv_ops.rs} | 0 src/{ => utils}/image_io.rs | 20 ++++++++--------- src/{ => utils}/image_processor.rs | 2 +- src/utils/mod.rs | 3 +++ tests/ocr_test.rs | 6 ++--- 14 files changed, 44 insertions(+), 41 deletions(-) delete mode 100644 src/model.rs rename src/{ => models}/base.rs (100%) rename src/{det_model.rs => models/det.rs} (98%) rename src/{model_loader.rs => models/loader.rs} (88%) create mode 100644 src/models/mod.rs rename src/{ocr_model.rs => models/ocr.rs} (97%) rename src/{slide_model.rs => models/slide.rs} (98%) delete mode 100644 src/utils.rs rename src/{cv2.rs => utils/cv_ops.rs} (100%) rename src/{ => utils}/image_io.rs (93%) rename src/{ => utils}/image_processor.rs (92%) create mode 100644 src/utils/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 9b69d0e..391554a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,7 @@ -pub mod base; mod charset; -mod det_model; -mod image_io; -mod image_processor; -mod model; -mod model_loader; -mod ocr_model; -mod utils; -pub mod slide_model; -mod cv2; + +pub mod models; +pub mod utils; use anyhow::Result; use image::DynamicImage; @@ -16,9 +9,9 @@ use std::fmt::{Display, Formatter}; // 关键点:直接使用 tract 重导出的 ndarray use crate::charset::get_default_charset; -use crate::det_model::Det; -use crate::model_loader::ModelSession; -use 
crate::ocr_model::Ocr; +use models::det::Det; +use models::loader::ModelSession; +use models::ocr::Ocr; pub enum ModelSpec { /// 默认 OCR (使用内置路径) OcrModel, @@ -31,7 +24,7 @@ pub enum ModelSpec { } impl ModelSpec { // 将默认路径定义为内部关联常量 - const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx"; + const DEFAULT_OCR_PATH: &'static str = "models/common.onnx"; const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx"; } pub enum Runtime { @@ -74,7 +67,10 @@ impl DdddOcrBuilder { /// 核心初始化逻辑 pub fn build(self) -> Result { let runtime = match self.mode { - ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?), + ModelSpec::OcrModel => Runtime::Ocr(Ocr::new( + ModelSpec::DEFAULT_OCR_PATH.into(), + get_default_charset(), + )?), ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?), ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?), }; @@ -110,7 +106,6 @@ impl DdddOcr { #[cfg(test)] mod tests { - use super::*; #[test] fn test_ctc_decode_indices() { // 模拟一个 DdddOcr 实例(如果 decode 不依赖 session,可以设为相关函数) diff --git a/src/model.rs b/src/model.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/base.rs b/src/models/base.rs similarity index 100% rename from src/base.rs rename to src/models/base.rs diff --git a/src/det_model.rs b/src/models/det.rs similarity index 98% rename from src/det_model.rs rename to src/models/det.rs index ff1a831..43a9cdc 100644 --- a/src/det_model.rs +++ b/src/models/det.rs @@ -1,4 +1,4 @@ -use crate::model_loader::{ModelLoader, ModelSession, ModelType}; +use crate::models::loader::{ModelLoader, ModelSession, ModelType}; use anyhow::{Context, Result}; use image::{DynamicImage, GenericImageView, imageops::FilterType}; use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, Axis, prelude::*, s}; @@ -219,8 +219,8 @@ impl Det { } /// 6. 
get_bbox (完全解耦 OpenCV) pub fn get_bbox(&self, image_bytes: &[u8]) -> Result>> { - // 使用 image crate 解码 - let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode image")?; + // 使用 image crate 解码 + let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode utils")?; let (orig_w, orig_h) = dynamic_img.dimensions(); let (input_tensor, ratio) = self.preproc(&dynamic_img, (416, 416))?; diff --git a/src/model_loader.rs b/src/models/loader.rs similarity index 88% rename from src/model_loader.rs rename to src/models/loader.rs index e912c9d..ae1f1a8 100644 --- a/src/model_loader.rs +++ b/src/models/loader.rs @@ -3,8 +3,8 @@ use image::DynamicImage; use tract_onnx::onnx; use tract_onnx::prelude::*; // 关键点:直接使用 tract 重导出的 ndarray -use crate::image_io::png_rgba_white_preprocess; -use crate::image_processor::{convert_to_grayscale, resize_image}; +use crate::utils::image_io::png_rgba_white_preprocess; +use crate::utils::image_processor::{convert_to_grayscale, resize_image}; use std::collections::HashMap; use tract_onnx::prelude::tract_ndarray::s; diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..981358c --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod loader; +pub mod ocr; +pub mod det; +pub mod slide; \ No newline at end of file diff --git a/src/ocr_model.rs b/src/models/ocr.rs similarity index 97% rename from src/ocr_model.rs rename to src/models/ocr.rs index 9f993dc..e9c39d2 100644 --- a/src/ocr_model.rs +++ b/src/models/ocr.rs @@ -1,7 +1,7 @@ -use crate::base::ModelArgs; -use crate::image_io::png_rgba_white_preprocess; -use crate::image_processor::{convert_to_grayscale, resize_image}; -use crate::model_loader::{ModelLoader, ModelSession, ModelType}; +use crate::models::base::ModelArgs; +use crate::utils::image_io::png_rgba_white_preprocess; +use crate::utils::image_processor::{convert_to_grayscale, resize_image}; +use crate::models::loader::{ModelLoader, 
ModelSession, ModelType}; use anyhow::Context; use image::DynamicImage; use tract_onnx::prelude::tract_ndarray::s; diff --git a/src/slide_model.rs b/src/models/slide.rs similarity index 98% rename from src/slide_model.rs rename to src/models/slide.rs index 344c8b2..9afc854 100644 --- a/src/slide_model.rs +++ b/src/models/slide.rs @@ -1,5 +1,5 @@ -use crate::cv2::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff}; -use crate::image_io::image_to_ndarray; +use crate::utils::cv_ops::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff}; +use crate::utils::image_io::image_to_ndarray; use anyhow::{Context, Result, anyhow}; use image::{DynamicImage, GenericImageView}; use image::{ImageBuffer, Luma}; diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/cv2.rs b/src/utils/cv_ops.rs similarity index 100% rename from src/cv2.rs rename to src/utils/cv_ops.rs diff --git a/src/image_io.rs b/src/utils/image_io.rs similarity index 93% rename from src/image_io.rs rename to src/utils/image_io.rs index 178771a..c784500 100644 --- a/src/image_io.rs +++ b/src/utils/image_io.rs @@ -24,7 +24,7 @@ pub fn load_image_from_input(img_input: ImageInput) -> Result { match img_input { // 2. 处理字节流 (Bytes) ImageInput::Bytes(bytes) => { - image::load_from_memory(&bytes).context("Failed to load image from bytes") + image::load_from_memory(&bytes).context("Failed to load utils from bytes") } // 1. 已经是 DynamicImage ImageInput::DynamicImage(i) => Ok(i), @@ -34,11 +34,11 @@ pub fn load_image_from_input(img_input: ImageInput) -> Result { // 4. 处理 Base64 字符串 ImageInput::Base64(b) => base64_to_image(&b), // 3. 
处理文件路径 (Path) - ImageInput::Path(p) => image::open(p).context("Failed to open image from path"), + ImageInput::Path(p) => image::open(p).context("Failed to open utils from path"), } } fn base64_to_image(b64_str: &str) -> Result { - // 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64," + // 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64," let clean_b64 = if let Some(pos) = b64_str.find(",") { &b64_str[pos + 1..] } else { @@ -49,7 +49,7 @@ fn base64_to_image(b64_str: &str) -> Result { .decode(clean_b64.trim()) .map_err(|e| anyhow!("Base64 decode error: {}", e))?; - image::load_from_memory(&bytes).context("Failed to load image from decoded base64") + image::load_from_memory(&bytes).context("Failed to load utils from decoded base64") } /// 读取图片文件并转换为 base64 编码字符串 @@ -58,7 +58,7 @@ pub fn get_img_base64>(image_path: P) -> Result { // 1. 读取文件原始字节流 // 使用 AsRef 泛型可以让函数同时支持 String, &str, PathBuf 等类型 let image_data = fs::read(&image_path) - .with_context(|| format!("Failed to read image file: {:?}", image_path.as_ref()))?; + .with_context(|| format!("Failed to read utils file: {:?}", image_path.as_ref()))?; // 2. 进行 Base64 编码 // 使用 STANDARD 引擎对齐 Python 的 base64.b64encode @@ -83,7 +83,7 @@ fn numpy_to_pil_image(array: ArrayViewD) -> Result { let (h, w) = (shape[0], shape[1]); ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) .map(DynamicImage::ImageLuma8) - .ok_or_else(|| anyhow!("Failed to create Luma image from 2D array")) + .ok_or_else(|| anyhow!("Failed to create Luma utils from 2D array")) } // 对应 Python: len(array.shape) == 3 (H, W, C) @@ -131,7 +131,7 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { let rgba_img = img.to_rgba8(); // 4. 
遍历像素并手动进行 Alpha 混合 - // 对应 Python 的 image.paste(img, ..., mask=img) + // 对应 Python 的 image.paste(img, ..., mask=img) // 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销 for (x, y, bg_pixel) in background.enumerate_pixels_mut() { // 安全性说明:x, y 源自 background 尺寸,与 rgba_img 一致,get_pixel 是安全的 @@ -162,8 +162,8 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { DynamicImage::ImageRgb8(background) } pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result> { - // 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray - // Rust image 库通过 to_rgb8, to_luma8 等方法实现转换 + // 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray + // Rust image 库通过 to_rgb8, to_luma8 等方法实现转换 let (width, height) = image.dimensions(); let (channels, raw) = match mode { @@ -225,7 +225,7 @@ pub fn image_to_ndarray(img: &DynamicImage) -> Array3 { // 3. 构造数组 (通道数改为 3) Array3::from_shape_vec((height as usize, width as usize, 3), raw_data) - .expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图 + .expect("Failed to construct ndarray from utils") // 建议显式报错,而不是返回全黑图 } #[allow(dead_code)] diff --git a/src/image_processor.rs b/src/utils/image_processor.rs similarity index 92% rename from src/image_processor.rs rename to src/utils/image_processor.rs index b500e7b..cd39fae 100644 --- a/src/image_processor.rs +++ b/src/utils/image_processor.rs @@ -4,7 +4,7 @@ use anyhow::Result; /// 对应 Python 的 convert_to_grayscale /// 将图像转换为灰度图 (L模式) pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage { - // Rust image 库的 to_luma8 会根据标准的亮度公式进行转换 + // Rust image 库的 to_luma8 会根据标准的亮度公式进行转换 image.to_luma8() } diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..0d04ccf --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,3 @@ +pub mod image_io; +pub mod image_processor; +pub mod cv_ops; \ No newline at end of file diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index 0d64d5e..9e61d48 100644 --- 
a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -2,12 +2,12 @@ use std::fs; use std::path::Path; use image::Rgb; use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 -use ddddocr_rs::slide_model::Slide; +use ddddocr_rs::models::slide::Slide; fn load_image>(path: P) -> anyhow::Result { // 1. 先将泛型转为具体的 &Path 引用 let path_ref = path.as_ref(); - // 2. 调用 open 时传入引用(image::open 支持 AsRef) + // 2. 调用 open 时传入引用(image::open 支持 AsRef) image::open(path_ref) .map_err(|e| { // 3. 此时 path_ref 依然有效,可以安全地在闭包中使用 @@ -102,7 +102,7 @@ fn test_real_slide_match() { // 2. 执行匹配 // 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false let start = std::time::Instant::now(); - let result = engine.slide_match(&target_img, &bg_img, true) + let result = engine.slide_match(&target_img, &bg_img, false) .expect("Slide match 执行失败"); let duration = start.elapsed();