Files
ddddocr-rs/src/ocr_model.rs
CNWei 8fcfa2096e refactor: 移除 OpenCV 依赖并实现纯 Rust 图像处理流水线
- 替换 opencv 为 image 库以简化交叉编译
- 修正 nms 逻辑中的 ArrayView 借用问题
- 增加 save_debug_image 方法用于可视化检测框
- 更新 Cargo.toml 依赖项
2026-05-06 17:37:38 +08:00

252 lines
9.3 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::base::ModelArgs;
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use anyhow::Context;
use image::DynamicImage;
use tract_onnx::prelude::tract_ndarray::s;
use tract_onnx::prelude::{
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
};
/// Custom colour-filter range: `(low RGB, high RGB)`.
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));
/// Restriction on which characters a prediction may contain.
///
/// Equality and hashing are derived so options can be compared (e.g. in tests)
/// or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CharsetRange {
    /// All characters.
    All,
    /// Digits.
    Digit,
    /// Letters.
    Letter,
    /// Letters and digits.
    Alphanumeric,
    /// A single string (presumably its characters form the allowed set — confirm with the consumer).
    Single(String),
    /// Several strings.
    Multiple(Vec<String>),
    /// A character range `(start, end)`; NOTE(review): inclusiveness is decided by the consumer — confirm.
    Range(char, char),
    /// An explicit list of allowed characters.
    Custom(Vec<char>),
}
/// Options controlling a single prediction call.
///
/// Start from [`PredictArgs::new`] / [`PredictArgs::default`] and chain the
/// builder setters. Every flag defaults to `false` and every optional filter
/// defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct PredictArgs {
    /// Whether to apply the PNG (transparent background) fix.
    pub png_fix: bool,
    /// Whether to return probability information.
    pub probability: bool,
    /// Colour filter: list of colours to keep.
    pub color_filter_colors: Option<Vec<String>>,
    /// Colour filter: custom RGB ranges to keep.
    pub color_filter_custom_ranges: Option<Vec<ColorRange>>,
    /// Charset range restricting the recognised characters.
    pub charset_range: Option<CharsetRange>,
}

impl PredictArgs {
    /// Creates the default options (identical to [`PredictArgs::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    // Builder setters: each consumes `self` and returns the updated value, so the
    // result must be used (hence `#[must_use]`).

    /// Sets [`PredictArgs::png_fix`].
    #[must_use]
    pub fn png_fix(mut self, enabled: bool) -> Self {
        self.png_fix = enabled;
        self
    }

    /// Sets [`PredictArgs::probability`].
    #[must_use]
    pub fn probability(mut self, enabled: bool) -> Self {
        self.probability = enabled;
        self
    }

    /// Sets the list of colours to keep.
    #[must_use]
    pub fn color_filter_colors(mut self, colors: Vec<String>) -> Self {
        self.color_filter_colors = Some(colors);
        self
    }

    /// Sets the custom RGB ranges to keep.
    #[must_use]
    pub fn color_filter_custom_ranges(mut self, ranges: Vec<ColorRange>) -> Self {
        self.color_filter_custom_ranges = Some(ranges);
        self
    }

    /// Sets the charset range restriction.
    #[must_use]
    pub fn charset_range(mut self, range: CharsetRange) -> Self {
        self.charset_range = Some(range);
        self
    }

    // Convenience constructors.

    /// Default options; same as [`PredictArgs::new`].
    #[must_use]
    pub fn quick() -> Self {
        Self::default()
    }

    /// Default options with `probability` enabled.
    #[must_use]
    pub fn with_probability() -> Self {
        Self::default().probability(true)
    }

    /// Default options with `png_fix` enabled.
    #[must_use]
    pub fn with_png_fix() -> Self {
        Self::default().png_fix(true)
    }
}
/// Text recogniser: a runnable tract ONNX session paired with the charset used to
/// turn predicted class indices back into text.
pub struct Ocr {
    /// Runnable model session, obtained from `ModelLoader::load_model` in `Ocr::new`.
    session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    /// Index -> text lookup table; decoding treats index 0 as the CTC blank and never emits it.
    charset: Vec<String>,
}
impl ModelSession for Ocr {
    /// Status text reported for a successfully loaded OCR model.
    fn desc(&self) -> String {
        String::from("Ocr Model 加载成功")
    }

    /// Model type of this session.
    ///
    /// NOTE(review): still an unimplemented stub — calling this panics via `todo!()`.
    /// It should return the OCR variant of `ModelType` once that is confirmed.
    fn get_model_type(&self) -> ModelType {
        todo!()
    }
}
impl Ocr {
    /// Loads the OCR model at `model_path` and pairs it with `charset`, the table
    /// used to map predicted class indices back to text.
    ///
    /// # Errors
    /// Propagates any error returned by `ModelLoader::load_model`.
    pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
        let session = ModelLoader::load_model(&model_path)?.session;
        Ok(Self { session, charset })
    }

    /// Recognises the text in `image`.
    ///
    /// Pipeline: preprocess -> inference -> per-step class indices -> CTC decode.
    /// When `png_fix` is true and the image has an alpha channel, it is first passed
    /// through `png_rgba_white_preprocess`.
    ///
    /// # Errors
    /// Fails if the image has a zero dimension, if inference fails or yields no
    /// output, or if the output dtype/rank is not supported.
    pub fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
        let input = self.preprocess_image(image, png_fix)?;
        let raw_output = self.inference(input)?;
        let indices = self.process_text_output(&raw_output)?;
        Ok(self.ctc_decode_indices(&indices))
    }

    /// Counterpart of Python's `_preprocess_image`.
    ///
    /// Steps: optional PNG alpha fix -> grayscale -> aspect-preserving resize to a
    /// height of 64 px -> normalise each pixel to `(p / 255 - 0.5) / 0.5` -> tensor
    /// of shape `(1, 1, 64, width)`.
    fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
        // The fixed image must outlive the borrow below. The unfixed path just
        // reuses the caller's image instead of cloning it. (Previously the fixed
        // image was computed and then discarded, so `png_fix` had no effect.)
        let fixed;
        let img: &DynamicImage = if png_fix && img.color().has_alpha() {
            fixed = png_rgba_white_preprocess(img);
            &fixed
        } else {
            img
        };

        let (src_w, src_h) = (img.width(), img.height());
        if src_w == 0 || src_h == 0 {
            return Err(anyhow::anyhow!("输入图片尺寸无效: {}x{}", src_w, src_h));
        }

        let h = 64u32;
        // Scale the width by the same factor as the height (truncating). Clamp it
        // to at least 1 px so very tall, narrow images still yield a usable tensor.
        let w = ((src_w as f32 * (h as f32 / src_h as f32)) as u32).max(1);

        let gray_img = convert_to_grayscale(img);
        let resized = resize_image(&gray_img, w, h);

        let array = tract_ndarray::Array4::from_shape_fn(
            (1, 1, h as usize, w as usize),
            |(_, _, y, x)| {
                let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
                (pixel / 255.0 - 0.5) / 0.5
            },
        );
        Ok(Tensor::from(array))
    }

    /// Counterpart of Python's `_inference`: runs the model and returns its first output.
    ///
    /// # Errors
    /// Fails if tract reports an inference error or the model produces no outputs.
    fn inference(&self, tensor: Tensor) -> anyhow::Result<Tensor> {
        let outputs = self
            .session
            .run(tvec!(tensor.into()))
            .context("执行模型推理失败")?;
        // Only the first output is used. A model with no outputs is reported as an
        // error instead of panicking on `remove(0)`.
        let first = outputs
            .into_iter()
            .next()
            .context("模型推理没有返回任何输出")?;
        Ok(first.into_tensor())
    }

    /// Converts the raw model output into one class index per time step.
    ///
    /// * `i64` output: the values are taken as indices as-is.
    /// * `f32` output: accepted shapes are `[steps, 1, classes]`, `[1, steps, classes]`,
    ///   `[batch, steps, classes]` (the first batch entry is used) and `[steps, classes]`.
    ///   Each step's index is the argmax over classes.
    ///
    /// # Errors
    /// Fails for any other dtype or rank.
    fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result<Vec<i64>> {
        let shape = raw_tensor.shape();
        match raw_tensor.datum_type() {
            // Case 1: models that already emit i64 indices (e.g. the argmax is done in-graph).
            DatumType::I64 => {
                let view = raw_tensor.to_array_view::<i64>()?;
                Ok(view.iter().copied().collect())
            }
            // Case 2: the original sml2h3 models emit an f32 score matrix.
            DatumType::F32 => {
                let view = raw_tensor.to_array_view::<f32>()?;
                let (steps, classes, data_view) = match shape.len() {
                    3 => {
                        if shape[1] == 1 {
                            // [Steps, 1, Classes]
                            (shape[0], shape[2], view.into_dyn())
                        } else if shape[0] == 1 {
                            // [1, Steps, Classes]
                            (shape[1], shape[2], view.into_dyn())
                        } else {
                            // [Batch, Steps, Classes]: first batch entry, like Python's
                            // `output[0, :, :]`.
                            let sliced = view.slice(s![0, .., ..]);
                            (shape[1], shape[2], sliced.into_dyn())
                        }
                    }
                    // [Steps, Classes]: batch dimension already removed.
                    2 => (shape[0], shape[1], view.into_dyn()),
                    _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
                };
                let scores = data_view.to_shape((steps, classes))?;
                let indices = scores
                    .outer_iter()
                    .map(|row| Self::argmax_first(row.iter().copied()))
                    .collect();
                Ok(indices)
            }
            other => Err(anyhow::anyhow!("不支持的模型输出数据类型: {:?}", other)),
        }
    }

    /// Returns the index of the largest value, or 0 for an empty input.
    ///
    /// On ties the FIRST index wins, which is the same tie-break as NumPy's `argmax`.
    /// `Iterator::max_by` would return the last one instead. A NaN never replaces
    /// an earlier value.
    fn argmax_first(values: impl IntoIterator<Item = f32>) -> i64 {
        let mut best: Option<(usize, f32)> = None;
        for (idx, value) in values.into_iter().enumerate() {
            // Strict `>` keeps the earliest maximum.
            if best.map_or(true, |(_, best_value)| value > best_value) {
                best = Some((idx, value));
            }
        }
        best.map_or(0, |(idx, _)| idx as i64)
    }

    /// Greedy CTC decoding, the counterpart of Python's `_ctc_decode_indices`.
    ///
    /// Consecutive repeated indices are collapsed and index 0 (the blank) is dropped.
    /// Each remaining index is mapped through `charset`. Negative indices are skipped
    /// silently; indices past the end of `charset` are skipped with a warning on stderr.
    fn ctc_decode_indices(&self, predicted_indices: &[i64]) -> String {
        let mut res = String::new();
        let mut prev_idx: Option<i64> = None;
        for &idx in predicted_indices {
            let is_repeat = prev_idx == Some(idx);
            prev_idx = Some(idx);
            if is_repeat || idx == 0 {
                continue;
            }
            if let Ok(u_idx) = usize::try_from(idx) {
                match self.charset.get(u_idx) {
                    Some(char_str) => res.push_str(char_str),
                    // Guard: the model predicted an index outside the charset.
                    None => eprintln!("警告: 预测索引 {} 超出字符集范围", u_idx),
                }
            }
        }
        res
    }
}