From cfeb68ad0447140534ef6db44dd9a3d903d6bcd7 Mon Sep 17 00:00:00 2001 From: CNWei Date: Tue, 5 May 2026 22:18:12 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构 DdddOcr。 - 新增 DdddOcrBuilder。 - 其他优化 --- examples/simple_usage.rs | 2 +- src/base.rs | 40 +++++++ src/charset.rs | 3 + src/det_model.rs | 24 ++++ src/lib.rs | 226 ++++++++++++----------------------- src/model_loader.rs | 40 +++++++ src/ocr_model.rs | 248 +++++++++++++++++++++++++++++++++++++++ tests/ocr_test.rs | 4 +- 8 files changed, 434 insertions(+), 153 deletions(-) create mode 100644 src/base.rs create mode 100644 src/det_model.rs create mode 100644 src/model_loader.rs create mode 100644 src/ocr_model.rs diff --git a/examples/simple_usage.rs b/examples/simple_usage.rs index c5031f8..decfde6 100644 --- a/examples/simple_usage.rs +++ b/examples/simple_usage.rs @@ -1,5 +1,5 @@ fn main() { - let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap(); + let ocr = ddddocr_rs::DdddOcrBuilder::new().build().unwrap(); let img = image::open("samples/code3.png").unwrap(); println!("Result: {}", ocr.classification(&img).unwrap()); } \ No newline at end of file diff --git a/src/base.rs b/src/base.rs new file mode 100644 index 0000000..e9c8b2c --- /dev/null +++ b/src/base.rs @@ -0,0 +1,40 @@ +pub trait ModelArgs { + // 获取模型路径 + fn model_path(&self) -> &str; + + // 获取字符集(由于 Det 没有,所以返回 Option) + fn charset(&self) -> Option<&str>; +} + +pub struct HasCharset { + pub charset: String, +} // 给 Ocr 和 Custom 用 +pub struct NoCharset; // 给 Det 用 + +pub struct Model { + pub path: String, + pub metadata: T, +} +// 针对有字符集的模型 (Ocr / Custom) +impl ModelArgs for Model { + fn model_path(&self) -> &str { + &self.path + } + fn charset(&self) -> Option<&str> { + Some(&self.metadata.charset) + } +} + +// 针对没有字符集的模型 (Det) +impl ModelArgs for Model { + fn model_path(&self) -> &str { + &self.path + } + fn charset(&self) -> Option<&str> { + None + } +} + +pub type OcrModel = Model; +pub type DetModel = Model; +pub type CustomModel = Model; // Ocr 和 Custom 逻辑一致,可以复用 \ No newline at end of file diff --git a/src/charset.rs b/src/charset.rs index 750e2fc..fc44cef 100644 --- a/src/charset.rs +++ b/src/charset.rs @@ -514,3 +514,6 @@ pub const CHARSET_BETA: &[&str] = &[ "谬", "溝", "言", "哽", "婿", "猿", "跗", "獴", "俜", "呙", "弗", "凿", "窭", "铌", "友", "唉", "怫", "荘", ]; +pub fn get_default_charset() -> Vec { + CHARSET_BETA.iter().map(|&s| s.to_string()).collect() +} \ No newline at end of file diff --git a/src/det_model.rs b/src/det_model.rs new file mode 100644 index 0000000..3dd7dbd --- /dev/null +++ b/src/det_model.rs @@ -0,0 +1,24 @@ +use image::DynamicImage; +use crate::model_loader::{ModelLoader, ModelSession, ModelType}; +use tract_onnx::prelude::{Graph, RunnableModel, TypedFact, TypedOp}; +use crate::ocr_model::Ocr; + +pub struct Det { + session: RunnableModel, Graph>>, +} +impl ModelSession for Det { + fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { + // OCR 识别逻辑 + CTC 解码 + Ok("ocr result".to_string()) + } + + fn get_model_type(&self) -> ModelType { + todo!() + } +} +impl Det { + pub fn new(model_path: String) -> Result { + let session = ModelLoader::load_model(&model_path)?.session; + Ok(Self { session }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 0445875..87504ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,169 +1,95 @@ +pub mod base; mod charset; +mod det_model; mod image_io; mod image_processor; mod model; +mod model_loader; +mod ocr_model; mod utils; -use crate::image_io::png_rgba_white_preprocess; -use crate::image_processor::{convert_to_grayscale, resize_image}; -use anyhow::{Context, Result}; -use image::{DynamicImage, imageops::FilterType}; -use tract_onnx::prelude::*; +use anyhow::Result; +use image::DynamicImage; + // 关键点:直接使用 tract 重导出的 ndarray -use tract_onnx::prelude::tract_ndarray::s; -pub struct DdddOcr { - session: RunnableModel, Graph>>, +use crate::det_model::Det; +use crate::model_loader::ModelSession; +use crate::ocr_model::Ocr; +use crate::charset::get_default_charset; +pub enum ModeType { + /// 默认 OCR (使用内置路径) + Ocr { + path: String, + charset: Vec, + }, + Det { + path: String, + }, + /// 自定义 OCR (路径由用户提供) + CustomOcr { + path: String, + charset: Vec, + }, } -impl DdddOcr { - pub fn new

(model_path: P) -> Result - where - P: AsRef, - { - let session = onnx() - .model_for_path(model_path) - .with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")? - .into_optimized()? - .into_runnable()?; - Ok(Self { session }) +pub struct DdddOcrBuilder { + mode: ModeType, +} + +impl DdddOcrBuilder { + pub fn new() -> Self { + Self { + mode: ModeType::Ocr { + path: "models/common.onnx".to_string(), + charset: get_default_charset(), + }, + } } - pub fn classification(&self, img: &DynamicImage) -> Result { - let tensor = self.preprocess_image(img, false)?; - - // let result = self.session.run(tvec!(tensor.into()))?; - // 3. 解析结果 - // let output = result[0].to_array_view::()?; - let output = self.inference(tensor)?; - let output2 = self.process_text_output(&output)?; - Ok(Self::ctc_decode_indices(&output2)) + /// 切换为检测模式 + pub fn det(mut self) -> Self { + self.mode = ModeType::Det { + path: "models/common_det.onnx".to_string(), + }; + self } - /// 对应 Python 的 _preprocess_image - /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 - fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result { - // A. 修复 PNG 透明背景 (内部逻辑你之前已实现) - let _ = if png_fix && img.color().has_alpha() { - png_rgba_white_preprocess(img) - } else { - img.clone() + + /// 设置自定义 OCR 路径 + pub fn custom_ocr(mut self, path: String, charset: Vec) -> Self { + // 直接重写枚举,替换掉之前的 Ocr 或 Det + self.mode = ModeType::CustomOcr { path, charset }; + self + } + + /// 核心初始化逻辑 + pub fn build(self) -> Result { + let session: Box = match self.mode { + ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?), + ModeType::Det { path } => Box::new(Det::new(path)?), + ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?), }; - let h = 64u32; - let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; - let gray_img = convert_to_grayscale(img); - let resized = resize_image(&gray_img, w, h); - // resized.save("debug_preprocessed.png").unwrap(); - // 1. 预处理:转灰度 -> Resize -> 归一化 - // let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); - - // 使用 tract_ndarray 构造,避免版本冲突 - let array = - tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| { - let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32; - (pixel / 255.0 - 0.5) / 0.5 - }); - - let tensor = Tensor::from(array); - - Ok(tensor) - } - /// 对应 Python 的 _inference - fn inference(&self, tensor: Tensor) -> Result { - // tract 的 run 会返回一个 Vec,我们通常只需要第一个输出 - // let result = self.session.run(tvec!(tensor.into()))?; - let mut result = self - .session - .run(tvec!(tensor.into())) - .context("执行模型推理失败")?; - println!("模型输出原始数据: {:?}", result); - Ok(result.remove(0).into_tensor()) - } - /// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列 - fn process_text_output(&self, raw_tensor: &Tensor) -> Result> { - let shape = raw_tensor.shape(); - println!("模型输出shape数据: {:?}", shape); - let datum_type = raw_tensor.datum_type(); - println!("模型输出datum_type数据: {:?}", datum_type); - - match raw_tensor.datum_type() { - // 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax) - DatumType::I64 => { - let view = raw_tensor.to_array_view::()?; - Ok(view.iter().cloned().collect()) - } - - // 情况 2: sml2h3 原版模型,输出 F32 概率矩阵 - DatumType::F32 => { - let view = raw_tensor.to_array_view::()?; - let (steps, classes, data_view) = match shape.len() { - 3 => { - if shape[1] == 1 { - // 形状: [Steps, 1, Classes] -> 你的原有逻辑 - (shape[0], shape[2], view.into_dyn()) - } else if shape[0] == 1 { - // 形状: [1, Steps, Classes] -> 另一种常见导出格式 - (shape[1], shape[2], view.into_dyn()) - } else { - // 默认取第一个 batch: [Batch, Steps, Classes] - // 使用 slice 对应 Python 的 output[0, :, :] - let sliced = view.slice(s![0, .., ..]); - (shape[1], shape[2], sliced.into_dyn()) - } - } - 2 => { - // 形状: [Steps, Classes] -> 已经剥离了 Batch 维度 - (shape[0], shape[1], view.into_dyn()) - } - _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)), - }; - let array_2d = data_view.to_shape((steps, classes))?; - // - // 对每一行执行 Argmax (寻找概率最大的字符索引) - let indices = array_2d - .outer_iter() - .map(|row| { - row.iter() - .enumerate() - .max_by(|(_, a), (_, b)| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(idx, _)| idx as i64) - .unwrap_or(0) - }) - .collect(); - Ok(indices) - } - _ => Err(anyhow::anyhow!( - "不支持的模型输出数据类型: {:?}", - raw_tensor.datum_type() - )), - } - } - fn ctc_decode_indices(predicted_indices: &[i64]) -> String { - println!("indices模型输出原始数据: {:?}", predicted_indices); - - use crate::charset::CHARSET_BETA; - // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) - let mut res = String::new(); - let mut prev_idx: i64 = -1; - - for &idx in predicted_indices { - // 1. 跳过连续重复的索引 - // 2. 跳过 blank 字符 (假设索引 0 是 blank) - if idx != prev_idx && idx != 0 { - if let Ok(u_idx) = usize::try_from(idx) { - if let Some(&char_str) = CHARSET_BETA.get(u_idx) { - res.push_str(char_str); - } - } - } - prev_idx = idx; - } - println!("最终识别出的验证码是: {}", res); - res + Ok(DdddOcr { session }) } } +pub struct DdddOcr { + session: Box, +} +impl DdddOcr { + pub fn classification(&self, img: &DynamicImage) -> Result { + self.session.predict(img, false) + + // let tensor = self.preprocess_image(img, false)?; + // + // // let result = self.session.run(tvec!(tensor.into()))?; + // // 3. 解析结果 + // // let output = result[0].to_array_view::()?; + // let output = self.inference(tensor)?; + // let output2 = self.process_text_output(&output)?; + // Ok(Self::ctc_decode_indices(&output2)) + } +} #[cfg(test)] mod tests { @@ -179,4 +105,4 @@ mod tests { // let result = dddd.ctc_decode_indices(&input); // assert_eq!(result, "AABB"); } -} \ No newline at end of file +} diff --git a/src/model_loader.rs b/src/model_loader.rs new file mode 100644 index 0000000..f8f5c64 --- /dev/null +++ b/src/model_loader.rs @@ -0,0 +1,40 @@ +use anyhow::Context; +use image::DynamicImage; +use tract_onnx::onnx; +use tract_onnx::prelude::*; +// 关键点:直接使用 tract 重导出的 ndarray +use crate::image_io::png_rgba_white_preprocess; +use crate::image_processor::{convert_to_grayscale, resize_image}; +use std::collections::HashMap; +use tract_onnx::prelude::tract_ndarray::s; + +/// OCR 模型:包含路径和字符集 + +pub enum ModelType { + Ocr, + Det, + Custom, +} +// 定义统一的 trait +pub trait ModelSession { + fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result; + fn get_model_type(&self) -> ModelType; +} + +pub struct ModelLoader { + pub session: RunnableModel, Graph>>, +} + +impl ModelLoader { + pub fn load_model

(model_path: P) -> anyhow::Result + where + P: AsRef, + { + let session = onnx() + .model_for_path(model_path) + .with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")? + .into_optimized()? + .into_runnable()?; + Ok(Self { session }) + } +} diff --git a/src/ocr_model.rs b/src/ocr_model.rs new file mode 100644 index 0000000..38a8ac8 --- /dev/null +++ b/src/ocr_model.rs @@ -0,0 +1,248 @@ +use crate::image_io::png_rgba_white_preprocess; +use crate::image_processor::{convert_to_grayscale, resize_image}; +use crate::model_loader::{ModelLoader, ModelSession, ModelType}; +use anyhow::Context; +use image::DynamicImage; +use tract_onnx::prelude::tract_ndarray::s; +use tract_onnx::prelude::{ + DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec, +}; +use crate::base::ModelArgs; + + +// 颜色过滤的自定义范围:(低值RGB, 高值RGB) +pub type ColorRange = ((u8, u8, u8), (u8, u8, u8)); + +// 字符集范围类型 +#[derive(Debug, Clone)] +pub enum CharsetRange { + All, // 所有字符 + Digit, // 数字 + Letter, // 字母 + Alphanumeric, // 字母数字 + Single(String), // 单字符串 + Multiple(Vec), // 多个字符串 + Range(char, char), // 字符范围 + Custom(Vec), // 自定义字符列表 +} +#[derive(Debug, Clone)] +pub struct PredictArgs{ + /// 是否修复PNG格式问题 + pub png_fix: bool, + /// 是否返回概率信息 + pub probability: bool, + /// 颜色过滤:保留的颜色列表 + pub color_filter_colors: Option>, + /// 颜色过滤:自定义RGB范围 + pub color_filter_custom_ranges: Option>, + /// 字符集范围 + pub charset_range: Option, +} + +impl Default for PredictArgs { + fn default() -> Self { + Self { + png_fix: false, + probability: false, + color_filter_colors: None, + color_filter_custom_ranges: None, + charset_range: None, + } + } +} + +impl PredictArgs { + pub fn new() -> Self { + Self::default() + } + + // Builder 模式方法 + pub fn png_fix(mut self, enabled: bool) -> Self { + self.png_fix = enabled; + self + } + + pub fn probability(mut self, enabled: bool) -> Self { + self.probability = enabled; + self + } + + pub fn color_filter_colors(mut self, colors: Vec) -> Self { + self.color_filter_colors = Some(colors); + self + } + + pub fn color_filter_custom_ranges(mut self, ranges: Vec) -> Self { + self.color_filter_custom_ranges = Some(ranges); + self + } + + pub fn charset_range(mut self, range: CharsetRange) -> Self { + self.charset_range = Some(range); + self + } + + // 便捷构造方法 + pub fn quick() -> Self { + Self::default() + } + + pub fn with_probability() -> Self { + Self::default().probability(true) + } + + pub fn with_png_fix() -> Self { + Self::default().png_fix(true) + } +} +pub struct Ocr { + session: RunnableModel, Graph>>, + charset: Vec, +} +impl ModelSession for Ocr { + fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { + let tensor = self.preprocess_image(image, png_fix)?; + // + // let result = self.session.run(tvec!(tensor.into()))?; + // // 3. 解析结果 + // // let output = result[0].to_array_view::()?; + let output = self.inference(tensor)?; + let output2 = self.process_text_output(&output)?; + Ok(Self::ctc_decode_indices(&output2)) + // Ok("ocr result".to_string()) + } + + fn get_model_type(&self) -> ModelType { + ModelType::Ocr + } +} +impl Ocr { + pub fn new(model_path: String, charset: Vec) -> Result { + let session = ModelLoader::load_model(&model_path)?.session; + Ok(Self { session, charset }) + } + /// 对应 Python 的 _preprocess_image + /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 + fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result { + // A. 修复 PNG 透明背景 (内部逻辑你之前已实现) + let _ = if png_fix && img.color().has_alpha() { + png_rgba_white_preprocess(img) + } else { + img.clone() + }; + + let h = 64u32; + let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; + let gray_img = convert_to_grayscale(img); + let resized = resize_image(&gray_img, w, h); + // resized.save("debug_preprocessed.png").unwrap(); + // 1. 预处理:转灰度 -> Resize -> 归一化 + // let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); + + // 使用 tract_ndarray 构造,避免版本冲突 + let array = + tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| { + let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32; + (pixel / 255.0 - 0.5) / 0.5 + }); + + let tensor = Tensor::from(array); + + Ok(tensor) + } + /// 对应 Python 的 _inference + fn inference(&self, tensor: Tensor) -> anyhow::Result { + // tract 的 run 会返回一个 Vec,我们通常只需要第一个输出 + // let result = self.session.run(tvec!(tensor.into()))?; + let mut result = self + .session + .run(tvec!(tensor.into())) + .context("执行模型推理失败")?; + println!("模型输出原始数据: {:?}", result); + Ok(result.remove(0).into_tensor()) + } + /// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列 + fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result> { + let shape = raw_tensor.shape(); + println!("模型输出shape数据: {:?}", shape); + let datum_type = raw_tensor.datum_type(); + println!("模型输出datum_type数据: {:?}", datum_type); + + match raw_tensor.datum_type() { + // 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax) + DatumType::I64 => { + let view = raw_tensor.to_array_view::()?; + Ok(view.iter().cloned().collect()) + } + + // 情况 2: sml2h3 原版模型,输出 F32 概率矩阵 + DatumType::F32 => { + let view = raw_tensor.to_array_view::()?; + let (steps, classes, data_view) = match shape.len() { + 3 => { + if shape[1] == 1 { + // 形状: [Steps, 1, Classes] -> 你的原有逻辑 + (shape[0], shape[2], view.into_dyn()) + } else if shape[0] == 1 { + // 形状: [1, Steps, Classes] -> 另一种常见导出格式 + (shape[1], shape[2], view.into_dyn()) + } else { + // 默认取第一个 batch: [Batch, Steps, Classes] + // 使用 slice 对应 Python 的 output[0, :, :] + let sliced = view.slice(s![0, .., ..]); + (shape[1], shape[2], sliced.into_dyn()) + } + } + 2 => { + // 形状: [Steps, Classes] -> 已经剥离了 Batch 维度 + (shape[0], shape[1], view.into_dyn()) + } + _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)), + }; + let array_2d = data_view.to_shape((steps, classes))?; + // + // 对每一行执行 Argmax (寻找概率最大的字符索引) + let indices = array_2d + .outer_iter() + .map(|row| { + row.iter() + .enumerate() + .max_by(|(_, a), (_, b)| { + a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(idx, _)| idx as i64) + .unwrap_or(0) + }) + .collect(); + Ok(indices) + } + _ => Err(anyhow::anyhow!( + "不支持的模型输出数据类型: {:?}", + raw_tensor.datum_type() + )), + } + } + fn ctc_decode_indices(predicted_indices: &[i64]) -> String { + println!("indices模型输出原始数据: {:?}", predicted_indices); + + use crate::charset::CHARSET_BETA; + // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) + let mut res = String::new(); + let mut prev_idx: i64 = -1; + + for &idx in predicted_indices { + // 1. 跳过连续重复的索引 + // 2. 跳过 blank 字符 (假设索引 0 是 blank) + if idx != prev_idx && idx != 0 { + if let Ok(u_idx) = usize::try_from(idx) { + if let Some(&char_str) = CHARSET_BETA.get(u_idx) { + res.push_str(char_str); + } + } + } + prev_idx = idx; + } + println!("最终识别出的验证码是: {}", res); + res + } +} diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index be75af2..776827d 100644 --- a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -1,9 +1,9 @@ -use ddddocr_rs::DdddOcr; // 假设你的包名是这个 +use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 #[test] fn test_full_classification() { // 1. 初始化模型 - let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败"); + let ocr = DdddOcrBuilder::new().build().expect("模型加载失败"); // 2. 加载测试图片 let img = image::open("samples/code3.png").expect("测试图片不存在");