feat: 重构模型初始化逻辑

- 重构 DdddOcr。
- 新增 DdddOcrBuilder。
- 其他优化。
This commit is contained in:
2026-05-05 22:18:12 +08:00
parent 1c366b7165
commit cfeb68ad04
8 changed files with 434 additions and 153 deletions

View File

@@ -1,169 +1,95 @@
pub mod base;
mod charset;
mod det_model;
mod image_io;
mod image_processor;
mod model;
mod model_loader;
mod ocr_model;
mod utils;
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use anyhow::{Context, Result};
use image::{DynamicImage, imageops::FilterType};
use tract_onnx::prelude::*;
use anyhow::Result;
use image::DynamicImage;
// 关键点:直接使用 tract 重导出的 ndarray
use tract_onnx::prelude::tract_ndarray::s;
pub struct DdddOcr {
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
use crate::det_model::Det;
use crate::model_loader::ModelSession;
use crate::ocr_model::Ocr;
use crate::charset::get_default_charset;
/// Selects which model `DdddOcrBuilder::build` loads, together with the data
/// needed to load it.
///
/// `Ocr` and `CustomOcr` carry the same payload and are both turned into an
/// `Ocr` session by `build`; they differ only in who supplied the values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModeType {
    /// Default OCR (uses the built-in model path).
    Ocr {
        /// Path to the model file; the builder defaults this to
        /// `models/common.onnx`.
        path: String,
        /// Charset handed to `Ocr::new` alongside the model path.
        charset: Vec<String>,
    },
    /// Detection model.
    Det {
        /// Path to the model file; the builder sets this to
        /// `models/common_det.onnx`.
        path: String,
    },
    /// Custom OCR (path supplied by the user).
    CustomOcr {
        /// Path to the user-provided model file.
        path: String,
        /// Charset handed to `Ocr::new` alongside the model path.
        charset: Vec<String>,
    },
}
impl DdddOcr {
pub fn new<P>(model_path: P) -> Result<Self>
where
P: AsRef<std::path::Path>,
{
let session = onnx()
.model_for_path(model_path)
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
.into_optimized()?
.into_runnable()?;
Ok(Self { session })
/// Builder that records which model mode to load before a `DdddOcr` is
/// constructed.
pub struct DdddOcrBuilder {
    /// Currently selected mode; the mode-switching methods overwrite it
    /// wholesale rather than merging with the previous value.
    mode: ModeType,
}
impl DdddOcrBuilder {
/// Creates a builder preset to the default OCR mode: the built-in model
/// path `models/common.onnx` paired with the default charset.
pub fn new() -> Self {
    let mode = ModeType::Ocr {
        path: String::from("models/common.onnx"),
        charset: get_default_charset(),
    };
    Self { mode }
}
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
let tensor = self.preprocess_image(img, false)?;
// let result = self.session.run(tvec!(tensor.into()))?;
// 3. 解析结果
// let output = result[0].to_array_view::<i64>()?;
let output = self.inference(tensor)?;
let output2 = self.process_text_output(&output)?;
Ok(Self::ctc_decode_indices(&output2))
/// Switches the builder to detection mode.
///
/// Whatever mode was selected before is discarded and replaced by the
/// built-in detection model path.
pub fn det(mut self) -> Self {
    let path = String::from("models/common_det.onnx");
    self.mode = ModeType::Det { path };
    self
}
/// 对应 Python 的 _preprocess_image
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
let _ = if png_fix && img.color().has_alpha() {
png_rgba_white_preprocess(img)
} else {
img.clone()
/// Selects a custom OCR model located at `path`, decoded with `charset`.
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
    // Overwrite the enum outright: any previously chosen Ocr or Det mode is dropped.
    let custom = ModeType::CustomOcr { path, charset };
    self.mode = custom;
    self
}
/// 核心初始化逻辑
pub fn build(self) -> Result<DdddOcr> {
let session: Box<dyn ModelSession> = match self.mode {
ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?),
ModeType::Det { path } => Box::new(Det::new(path)?),
ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?),
};
let h = 64u32;
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
let gray_img = convert_to_grayscale(img);
let resized = resize_image(&gray_img, w, h);
// resized.save("debug_preprocessed.png").unwrap();
// 1. 预处理:转灰度 -> Resize -> 归一化
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
// 使用 tract_ndarray 构造,避免版本冲突
let array =
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
(pixel / 255.0 - 0.5) / 0.5
});
let tensor = Tensor::from(array);
Ok(tensor)
}
/// Runs the model on a single input tensor and returns its first output.
///
/// Counterpart of the Python `_inference`.
///
/// # Errors
///
/// Returns an error if the model run fails, or if the run succeeds but
/// yields no output at all (previously this case panicked).
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
    // `run` returns one value per model output; only the first one is used.
    let outputs = self
        .session
        .run(tvec!(tensor.into()))
        .context("执行模型推理失败")?;
    // Take the first output without indexing so an empty result becomes an
    // error instead of a panic.
    let first = outputs
        .into_iter()
        .next()
        .context("模型推理未返回任何输出")?;
    Ok(first.into_tensor())
}
/// Core parsing logic: converts the model's output tensor, whatever its rank
/// or element type, into a sequence of character indices.
///
/// * `I64` tensors are returned as-is, element by element in iteration order.
/// * `F32` tensors are viewed as a `(steps, classes)` matrix and each row is
///   reduced to the index of its largest value.
///
/// # Errors
///
/// Returns an error for any other element type, for an `F32` tensor whose
/// rank is neither 2 nor 3, or when the tensor view / reshape fails.
///
/// Note: the shape and datum type are printed to stdout on every call.
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
    let shape = raw_tensor.shape();
    println!("模型输出shape数据: {:?}", shape);
    let datum_type = raw_tensor.datum_type();
    println!("模型输出datum_type数据: {:?}", datum_type);
    match raw_tensor.datum_type() {
        // Case 1: huashi666-style models emit i64 indices directly
        // (typically the Argmax has already been done inside the model).
        DatumType::I64 => {
            let view = raw_tensor.to_array_view::<i64>()?;
            Ok(view.iter().cloned().collect())
        }
        // Case 2: the original sml2h3 models emit an F32 probability matrix.
        DatumType::F32 => {
            let view = raw_tensor.to_array_view::<f32>()?;
            let (steps, classes, data_view) = match shape.len() {
                3 => {
                    if shape[1] == 1 {
                        // Shape: [Steps, 1, Classes] -> the original logic.
                        (shape[0], shape[2], view.into_dyn())
                    } else if shape[0] == 1 {
                        // Shape: [1, Steps, Classes] -> another common export layout.
                        (shape[1], shape[2], view.into_dyn())
                    } else {
                        // Default to the first batch: [Batch, Steps, Classes].
                        // The slice mirrors Python's output[0, :, :]; other
                        // batch entries are ignored.
                        let sliced = view.slice(s![0, .., ..]);
                        (shape[1], shape[2], sliced.into_dyn())
                    }
                }
                2 => {
                    // Shape: [Steps, Classes] -> the batch dimension is already gone.
                    (shape[0], shape[1], view.into_dyn())
                }
                _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
            };
            // Collapse whichever layout was selected into a 2-D (steps, classes) view.
            let array_2d = data_view.to_shape((steps, classes))?;
            // Argmax over every row (index of the most probable character).
            // Incomparable pairs (e.g. NaN) are treated as equal, and an
            // empty row falls back to index 0.
            let indices = array_2d
                .outer_iter()
                .map(|row| {
                    row.iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(idx, _)| idx as i64)
                        .unwrap_or(0)
                })
                .collect();
            Ok(indices)
        }
        _ => Err(anyhow::anyhow!(
            "不支持的模型输出数据类型: {:?}",
            raw_tensor.datum_type()
        )),
    }
}
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
println!("indices模型输出原始数据: {:?}", predicted_indices);
use crate::charset::CHARSET_BETA;
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
let mut res = String::new();
let mut prev_idx: i64 = -1;
for &idx in predicted_indices {
// 1. 跳过连续重复的索引
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
if idx != prev_idx && idx != 0 {
if let Ok(u_idx) = usize::try_from(idx) {
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
res.push_str(char_str);
}
}
}
prev_idx = idx;
}
println!("最终识别出的验证码是: {}", res);
res
Ok(DdddOcr { session })
}
}
/// Public handle around a loaded model session.
pub struct DdddOcr {
    /// Type-erased session (any `ModelSession` implementor) that every
    /// prediction is delegated to.
    session: Box<dyn ModelSession>,
}
impl DdddOcr {
    /// Runs the loaded session on `img` and returns the string it produces.
    ///
    /// This is a thin wrapper: it forwards to the session's `predict` with
    /// `false` as the second argument and propagates any error unchanged.
    ///
    /// NOTE(review): the `false` flag presumably disables the PNG
    /// transparency fix - confirm against `ModelSession::predict`.
    pub fn classification(&self, img: &DynamicImage) -> Result<String> {
        self.session.predict(img, false)
    }
}
#[cfg(test)]
mod tests {
@@ -179,4 +105,4 @@ mod tests {
// let result = dddd.ctc_decode_indices(&input);
// assert_eq!(result, "AABB");
}
}
}