2 Commits

Author SHA1 Message Date
1c366b7165 feat: 重构 CTC 解码逻辑
- 重构 ctc_decode 为关联函数并优化内存分配。
- 增加 单元测试和集成测试
2026-05-01 21:54:33 +08:00
642fed5d9f feat: 实现 DdddOcr 核心推理流水线与图像预处理
- 封装 `preprocess_image` 方法,实现 PNG 透明背景修复、灰度化、比例缩放及 NCHW 张量转换。
- 提取 `inference` 逻辑,支持通过 tract-onnx 执行模型推理。
- 实现 `extract_indices` 解析输出张量,支持 I64 索引直接读取与 F32 概率矩阵的 Argmax 处理。
- 完善 `decode_ctc` 解码算法,支持标准 CTC 贪婪搜索与字符集映射。
- 重构 `classification` 主入口,将预处理、推理、解析、解码逻辑解耦,提升代码可维护性。
2026-04-30 17:54:08 +08:00
10 changed files with 302 additions and 65 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@@ -5,6 +5,7 @@ edition = "2024"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
tract-onnx = { version = "0.21.1" } tract-onnx = { version = "0.21.10" }
anyhow = "1.0.102" anyhow = "1.0.102"
image = "0.25.10" image = "0.25.10"
base64 = "0.22.1"

5
examples/simple_usage.rs Normal file
View File

@@ -0,0 +1,5 @@
fn main() {
let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap();
let img = image::open("samples/code3.png").unwrap();
println!("Result: {}", ocr.classification(&img).unwrap());
}

62
src/image_io.rs Normal file
View File

@@ -0,0 +1,62 @@
use anyhow::{Context, Result};
use base64::{Engine as _, engine::general_purpose};
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
use std::path::{Path, PathBuf};
use tract_onnx::prelude::tract_ndarray::Array3;
/// 定义支持的输入类型枚举
pub enum ImageInput {
Bytes(Vec<u8>),
Array(Array3<u8>),
Path(PathBuf),
Base64(String),
DynamicImage(DynamicImage),
}
/// 模拟 Python 的 load_image_from_input
#[allow(dead_code)]
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> {
match input {
ImageInput::DynamicImage(img) => Ok(img),
_ => todo!("后续补充"),
}
}
/// 对应 Python 的 png_rgba_black_preprocess
/// 将带有透明通道的图片转换为白色背景的 RGB 图片
#[allow(dead_code)]
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回
if !img.color().has_alpha() {
return img.clone();
}
let (width, height) = img.dimensions();
// 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255)
let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8]));
// 3. 获取原图的 RGBA 视图
let rgba_img = img.to_rgba8();
// 4. 遍历像素并手动进行 Alpha 混合
// 对应 Python 的 image.paste(img, ..., mask=img)
for (x, y, pixel) in rgba_img.enumerate_pixels() {
let alpha = pixel[3] as f32 / 255.0;
if alpha >= 1.0 {
// 完全不透明,直接覆盖
background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]]));
} else if alpha > 0.0 {
// 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha))
let bg_pixel = background.get_pixel(x, y);
let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8;
let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8;
let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8;
background.put_pixel(x, y, Rgb([r, g, b]));
}
// alpha == 0 的情况不需要处理,因为背景已经是白色了
}
DynamicImage::ImageRgb8(background)
}

27
src/image_processor.rs Normal file
View File

@@ -0,0 +1,27 @@
use image::{DynamicImage, GrayImage, imageops::FilterType};
use anyhow::Result;
/// 对应 Python 的 convert_to_grayscale
/// 将图像转换为灰度图 (L模式)
pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage {
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
image.to_luma8()
}
/// 对应 Python 的 resize_image
/// 调整图像尺寸。当前版本仅实现 keep_aspect_ratio=false
pub fn resize_image(
image: &GrayImage,
target_width: u32,
target_height: u32,
// resample 参数我们直接使用 FilterTypeLanczos3 是最接近 Python LANCZOS 的
) -> GrayImage {
// 使用 resize 算法进行精确缩放
image::imageops::resize(
image,
target_width,
target_height,
FilterType::Lanczos3
)
}

182
src/lib.rs Normal file
View File

@@ -0,0 +1,182 @@
mod charset;
mod image_io;
mod image_processor;
mod model;
mod utils;
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use anyhow::{Context, Result};
use image::{DynamicImage, imageops::FilterType};
use tract_onnx::prelude::*;
// 关键点:直接使用 tract 重导出的 ndarray
use tract_onnx::prelude::tract_ndarray::s;
pub struct DdddOcr {
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
}
impl DdddOcr {
pub fn new<P>(model_path: P) -> Result<Self>
where
P: AsRef<std::path::Path>,
{
let session = onnx()
.model_for_path(model_path)
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
.into_optimized()?
.into_runnable()?;
Ok(Self { session })
}
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
let tensor = self.preprocess_image(img, false)?;
// let result = self.session.run(tvec!(tensor.into()))?;
// 3. 解析结果
// let output = result[0].to_array_view::<i64>()?;
let output = self.inference(tensor)?;
let output2 = self.process_text_output(&output)?;
Ok(Self::ctc_decode_indices(&output2))
}
/// 对应 Python 的 _preprocess_image
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
let _ = if png_fix && img.color().has_alpha() {
png_rgba_white_preprocess(img)
} else {
img.clone()
};
let h = 64u32;
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
let gray_img = convert_to_grayscale(img);
let resized = resize_image(&gray_img, w, h);
// resized.save("debug_preprocessed.png").unwrap();
// 1. 预处理:转灰度 -> Resize -> 归一化
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
// 使用 tract_ndarray 构造,避免版本冲突
let array =
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
(pixel / 255.0 - 0.5) / 0.5
});
let tensor = Tensor::from(array);
Ok(tensor)
}
/// 对应 Python 的 _inference
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
// let result = self.session.run(tvec!(tensor.into()))?;
let mut result = self
.session
.run(tvec!(tensor.into()))
.context("执行模型推理失败")?;
println!("模型输出原始数据: {:?}", result);
Ok(result.remove(0).into_tensor())
}
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
let shape = raw_tensor.shape();
println!("模型输出shape数据: {:?}", shape);
let datum_type = raw_tensor.datum_type();
println!("模型输出datum_type数据: {:?}", datum_type);
match raw_tensor.datum_type() {
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
DatumType::I64 => {
let view = raw_tensor.to_array_view::<i64>()?;
Ok(view.iter().cloned().collect())
}
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
DatumType::F32 => {
let view = raw_tensor.to_array_view::<f32>()?;
let (steps, classes, data_view) = match shape.len() {
3 => {
if shape[1] == 1 {
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
(shape[0], shape[2], view.into_dyn())
} else if shape[0] == 1 {
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
(shape[1], shape[2], view.into_dyn())
} else {
// 默认取第一个 batch: [Batch, Steps, Classes]
// 使用 slice 对应 Python 的 output[0, :, :]
let sliced = view.slice(s![0, .., ..]);
(shape[1], shape[2], sliced.into_dyn())
}
}
2 => {
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
(shape[0], shape[1], view.into_dyn())
}
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
};
let array_2d = data_view.to_shape((steps, classes))?;
//
// 对每一行执行 Argmax (寻找概率最大的字符索引)
let indices = array_2d
.outer_iter()
.map(|row| {
row.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(idx, _)| idx as i64)
.unwrap_or(0)
})
.collect();
Ok(indices)
}
_ => Err(anyhow::anyhow!(
"不支持的模型输出数据类型: {:?}",
raw_tensor.datum_type()
)),
}
}
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
println!("indices模型输出原始数据: {:?}", predicted_indices);
use crate::charset::CHARSET_BETA;
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
let mut res = String::new();
let mut prev_idx: i64 = -1;
for &idx in predicted_indices {
// 1. 跳过连续重复的索引
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
if idx != prev_idx && idx != 0 {
if let Ok(u_idx) = usize::try_from(idx) {
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
res.push_str(char_str);
}
}
}
prev_idx = idx;
}
println!("最终识别出的验证码是: {}", res);
res
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ctc_decode_indices() {
// 模拟一个 DdddOcr 实例(如果 decode 不依赖 session可以设为相关函数
// 这里假设你的 decode_ctc 是公开或内部可访问的
let input = vec![1, 1, 0, 1, 2, 2, 0, 2];
// 逻辑:[1, 1] -> 1, [0] -> 跳过, [1] -> 1, [2, 2] -> 2, [0] -> 跳过, [2] -> 2
// 预期结果索引应该是 [1, 1, 2, 2] 对应的字符
// 具体的断言取决于你的 CHARSET_BETA
// let result = dddd.ctc_decode_indices(&input);
// assert_eq!(result, "AABB");
}
}

View File

@@ -1,64 +0,0 @@
mod charset;
use anyhow::Result;
use charset::CHARSET_BETA;
use image::{imageops::FilterType, open};
use tract_onnx::prelude::*;
// 编译时读取字典文件
fn main() -> Result<()> {
// 1. 加载并优化模型 (假设模型文件在根目录)
let model = onnx()
.model_for_path("model/common.onnx")? // 这里替换成你提取的 ddddocr 模型名
.into_optimized()?
.into_runnable()?;
// 2. 加载并处理图片 (需要缩放到模型要求的尺寸,例如 64x30)
let img = open("samples/code3.png")?;
let h = 64u32;
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
// 1. 缩放并转灰度
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
let array =
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
(pixel / 255.0 - 0.5) / 0.5
});
let tensor = Tensor::from(array);
// 4. 运行推理
let result = model.run(tvec!(tensor.into()))?;
// 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor
// 简化逻辑:
// let tensor: Tensor = tract_ndarray::Array4::<f32>::zeros((1, 1, 30, 64)).into();
// 3. 运行推理
// let result = model.run(tvec!(tensor.into()))?;
println!("模型输出原始数据: {:?}", result);
let output = result[0].to_array_view::<i64>()?;
let indices: Vec<i64> = output.iter().cloned().collect();
// 2. 将视图转为切片并调用函数
let code = decode_ctc(&indices);
println!("indices模型输出原始数据: {:?}", indices);
println!("最终识别出的验证码是: {}", code);
Ok(())
}
fn decode_ctc(indices: &[i64]) -> String {
let mut res = String::new();
let mut last_idx: i64 = -1;
for &idx in indices {
// idx == 0 通常是 CTC 的 blank 占位符
if idx != 0 && idx != last_idx {
if let Some(&char_str) = CHARSET_BETA.get(idx as usize) {
res.push_str(char_str);
}
}
last_idx = idx;
}
res
}

0
src/model.rs Normal file
View File

0
src/utils.rs Normal file
View File

16
tests/ocr_test.rs Normal file
View File

@@ -0,0 +1,16 @@
use ddddocr_rs::DdddOcr; // 假设你的包名是这个
#[test]
fn test_full_classification() {
// 1. 初始化模型
let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败");
// 2. 加载测试图片
let img = image::open("samples/code3.png").expect("测试图片不存在");
// 3. 执行识别
let result = ocr.classification(&img).expect("识别过程出错");
println!("识别结果: {}", result);
assert!(!result.is_empty());
}