From 642fed5d9f73d5a4779bd27c58af447f447e8cda Mon Sep 17 00:00:00 2001 From: CNWei Date: Thu, 30 Apr 2026 17:54:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20DdddOcr=20?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E6=8E=A8=E7=90=86=E6=B5=81=E6=B0=B4=E7=BA=BF?= =?UTF-8?q?=E4=B8=8E=E5=9B=BE=E5=83=8F=E9=A2=84=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 封装 `preprocess_image` 方法,实现 PNG 透明背景修复、灰度化、比例缩放及 NCHW 张量转换。 - 提取 `inference` 逻辑,支持通过 tract-onnx 执行模型推理。 - 实现 `extract_indices` 解析输出张量,支持 I64 索引直接读取与 F32 概率矩阵的 Argmax 处理。 - 完善 `decode_ctc` 解码算法,支持标准 CTC 贪婪搜索与字符集映射。 - 重构 `classification` 主入口,将预处理、推理、解析、解码逻辑解耦,提升代码可维护性。 --- .idea/.gitignore | 8 +++ Cargo.toml | 3 +- src/image_io.rs | 62 ++++++++++++++++++ src/image_processor.rs | 27 ++++++++ src/lib.rs | 145 +++++++++++++++++++++++++++++++++++++++++ src/main.rs | 56 +++++++++++++--- src/model.rs | 0 src/utils.rs | 0 8 files changed, 292 insertions(+), 9 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 src/image_io.rs create mode 100644 src/image_processor.rs create mode 100644 src/lib.rs create mode 100644 src/model.rs create mode 100644 src/utils.rs diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/Cargo.toml b/Cargo.toml index 055e59b..53883ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" license = "MIT OR Apache-2.0" [dependencies] -tract-onnx = { version = "0.21.1" } +tract-onnx = { version = "0.21.10" } anyhow = "1.0.102" image = "0.25.10" +base64 = "0.22.1" diff --git a/src/image_io.rs b/src/image_io.rs new file mode 100644 index 0000000..3a69b7a --- /dev/null +++ b/src/image_io.rs @@ -0,0 +1,62 @@ +use anyhow::{Context, Result}; +use base64::{Engine as _, engine::general_purpose}; +use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage}; +use std::path::{Path, PathBuf}; +use tract_onnx::prelude::tract_ndarray::Array3; + +/// 定义支持的输入类型枚举 +pub enum ImageInput { + Bytes(Vec), + Array(Array3), + Path(PathBuf), + Base64(String), + DynamicImage(DynamicImage), +} + +/// 模拟 Python 的 load_image_from_input +#[allow(dead_code)] +pub fn load_image_from_input(input: ImageInput) -> Result { + match input { + ImageInput::DynamicImage(img) => Ok(img), + _ => todo!("后续补充"), + } +} + +/// 对应 Python 的 png_rgba_black_preprocess +/// 将带有透明通道的图片转换为白色背景的 RGB 图片 +#[allow(dead_code)] +pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { + // 1. 检查是否包含透明通道,如果没有,直接克隆并返回 + if !img.color().has_alpha() { + return img.clone(); + } + + let (width, height) = img.dimensions(); + + // 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255) + let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8])); + + // 3. 获取原图的 RGBA 视图 + let rgba_img = img.to_rgba8(); + + // 4. 遍历像素并手动进行 Alpha 混合 + // 对应 Python 的 image.paste(img, ..., mask=img) + for (x, y, pixel) in rgba_img.enumerate_pixels() { + let alpha = pixel[3] as f32 / 255.0; + + if alpha >= 1.0 { + // 完全不透明,直接覆盖 + background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]])); + } else if alpha > 0.0 { + // 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha)) + let bg_pixel = background.get_pixel(x, y); + let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8; + let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8; + let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8; + background.put_pixel(x, y, Rgb([r, g, b])); + } + // alpha == 0 的情况不需要处理,因为背景已经是白色了 + } + + DynamicImage::ImageRgb8(background) +} diff --git a/src/image_processor.rs b/src/image_processor.rs new file mode 100644 index 0000000..b500e7b --- /dev/null +++ b/src/image_processor.rs @@ -0,0 +1,27 @@ +use image::{DynamicImage, GrayImage, imageops::FilterType}; +use anyhow::Result; + +/// 对应 Python 的 convert_to_grayscale +/// 将图像转换为灰度图 (L模式) +pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage { + // Rust image 库的 to_luma8 会根据标准的亮度公式进行转换 + image.to_luma8() +} + +/// 对应 Python 的 resize_image +/// 调整图像尺寸。当前版本仅实现 keep_aspect_ratio=false +pub fn resize_image( + image: &GrayImage, + target_width: u32, + target_height: u32, + // resample 参数我们直接使用 FilterType,Lanczos3 是最接近 Python LANCZOS 的 +) -> GrayImage { + // 使用 resize 算法进行精确缩放 + image::imageops::resize( + image, + target_width, + target_height, + FilterType::Lanczos3 + ) +} + diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..de28ea2 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,145 @@ +mod model; +mod utils; + +mod charset; +mod image_io; +mod image_processor; + +use anyhow::{Context, Result}; +use image::{DynamicImage, imageops::FilterType}; +use tract_onnx::prelude::*; +// 关键点:直接使用 tract 重导出的 ndarray +use crate::image_io::png_rgba_white_preprocess; +use crate::image_processor::{convert_to_grayscale, resize_image}; +use tract_onnx::prelude::tract_itertools::Itertools; + +pub struct DdddOcr { + session: RunnableModel, Graph>>, +} + +impl DdddOcr { + pub fn new

(model_path: P) -> Result + where + P: AsRef, + { + let session = onnx() + .model_for_path(model_path) + .with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")? + .into_optimized()? + .into_runnable()?; + Ok(Self { session }) + } + + pub fn classification(&self, img: &DynamicImage) -> Result { + let tensor = self.preprocess_image(img, false)?; + + // let result = self.session.run(tvec!(tensor.into()))?; + // 3. 解析结果 + // let output = result[0].to_array_view::()?; + let output = self.inference(tensor)?; + let output2 = self.extract_indices(&output)?; + Ok(self.decode_ctc(&output2)) + } + /// 对应 Python 的 _preprocess_image + /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 + fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result { + // A. 修复 PNG 透明背景 (内部逻辑你之前已实现) + let processed_img = if png_fix && img.color().has_alpha() { + png_rgba_white_preprocess(img) + } else { + img.clone() + }; + + let h = 64u32; + let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; + let gray_img = convert_to_grayscale(img); + let resized = resize_image(&gray_img, w, h); + // 1. 预处理:转灰度 -> Resize -> 归一化 + // let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); + + // 使用 tract_ndarray 构造,避免版本冲突 + let array = + tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| { + let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32; + (pixel / 255.0 - 0.5) / 0.5 + }); + + let tensor = Tensor::from(array); + + Ok(tensor) + } + /// 对应 Python 的 _inference + fn inference(&self, tensor: Tensor) -> Result { + // tract 的 run 会返回一个 Vec,我们通常只需要第一个输出 + // let result = self.session.run(tvec!(tensor.into()))?; + let mut result = self + .session + .run(tvec!(tensor.into())) + .context("执行模型推理失败")?; + + Ok(result.remove(0).into_tensor()) + } + /// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列 + fn extract_indices(&self, raw_tensor: &Tensor) -> Result> { + let shape = raw_tensor.shape(); + + match raw_tensor.datum_type() { + // 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax) + DatumType::I64 => { + let view = raw_tensor.to_array_view::()?; + Ok(view.iter().cloned().collect()) + } + + // 情况 2: sml2h3 原版模型,输出 F32 概率矩阵 + DatumType::F32 => { + let view = raw_tensor.to_array_view::()?; + + // 处理典型的 CTC 输出形状 [TimeSteps, Batch:1, Classes] + if shape.len() == 3 { + let steps = shape[0]; + let classes = shape[2]; + + // 将一维视图重新整理为二维 [steps, classes] + let array_2d = view.to_shape((steps, classes))?; + + // 对每一行执行 Argmax (寻找概率最大的字符索引) + let indices = array_2d + .outer_iter() + .map(|row| { + row.iter() + .enumerate() + .max_by(|(_, a), (_, b)| { + a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(idx, _)| idx as i64) + .unwrap_or(0) + }) + .collect(); + Ok(indices) + } else { + Err(anyhow::anyhow!("不支持的 F32 输出形状: {:?}", shape)) + } + } + _ => Err(anyhow::anyhow!( + "不支持的模型输出数据类型: {:?}", + raw_tensor.datum_type() + )), + } + } + fn decode_ctc(&self, indices: &[i64]) -> String { + use crate::charset::CHARSET_BETA; + let mut res = String::new(); + let mut last_idx: i64 = -1; + + for &idx in indices { + // ddddocr 的 blank 通常是 0 + if idx != 0 && idx != last_idx { + if let Some(&char_str) = CHARSET_BETA.get(idx as usize) { + res.push_str(char_str); + } + } + last_idx = idx; + } + res + } +} diff --git a/src/main.rs b/src/main.rs index 0afda9b..41e0f1a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ mod charset; -use anyhow::Result; +use anyhow::{anyhow, Result}; use charset::CHARSET_BETA; use image::{imageops::FilterType, open}; use tract_onnx::prelude::*; @@ -8,7 +8,7 @@ use tract_onnx::prelude::*; fn main() -> Result<()> { // 1. 加载并优化模型 (假设模型文件在根目录) let model = onnx() - .model_for_path("model/common.onnx")? // 这里替换成你提取的 ddddocr 模型名 + .model_for_path("model/common_huashi666_i64.onnx")? // 这里替换成你提取的 ddddocr 模型名 .into_optimized()? .into_runnable()?; @@ -33,20 +33,60 @@ fn main() -> Result<()> { // 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor // 简化逻辑: // let tensor: Tensor = tract_ndarray::Array4::::zeros((1, 1, 30, 64)).into(); - + let raw_tensor = &result[0]; // 3. 运行推理 // let result = model.run(tvec!(tensor.into()))?; println!("模型输出原始数据: {:?}", result); - let output = result[0].to_array_view::()?; - let indices: Vec = output.iter().cloned().collect(); + let shape = result[0].shape(); + println!("模型输出shape数据: {:?}", shape); + let datum_type = result[0].datum_type(); + println!("模型输出datum_type数据: {:?}", datum_type); + + let predicted_indices: Vec = match raw_tensor.datum_type() { + // 情况 1: huashi666 式模型,直接输出 i64 索引 + DatumType::I64 => { + raw_tensor.to_array_view::()?.iter().cloned().collect() + } + // 情况 2: sml2h3 原版模型,输出 F32 概率 + DatumType::F32 => { + let view = raw_tensor.to_array_view::()?; + + // 模仿 Python 的维度判断逻辑 + if shape.len() == 3 { + // 假设形状是 [21, 1, 8210] + let steps = shape[0]; + let classes = shape[2]; + let array_2d = view.to_shape(( + (steps, classes), + tract_onnx::prelude::tract_ndarray::Order::RowMajor + ))?; + + array_2d.outer_iter() + .map(|row| { + row.iter().enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(idx, _)| idx as i64).unwrap() + }) + .collect() + } else { + // 其他形状处理... + vec![] + } + } + _ => return Err(anyhow!("不支持的输出类型")), + }; + + // let output = result[0].to_array_view::()?; + // println!("模型输出原始数据2: {:?}", output); + // let indices: Vec = output.iter().cloned().collect(); // 2. 将视图转为切片并调用函数 - let code = decode_ctc(&indices); - println!("indices模型输出原始数据: {:?}", indices); + let code = decode_ctc(&predicted_indices); + println!("indices模型输出原始数据: {:?}", predicted_indices); println!("最终识别出的验证码是: {}", code); Ok(()) } - +// common_huashi666_i64 fn decode_ctc(indices: &[i64]) -> String { let mut res = String::new(); let mut last_idx: i64 = -1; diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..e69de29