feat: 实现 DdddOcr 核心推理流水线与图像预处理
- 封装 `preprocess_image` 方法,实现 PNG 透明背景修复、灰度化、比例缩放及 NCHW 张量转换。 - 提取 `inference` 逻辑,支持通过 tract-onnx 执行模型推理。 - 实现 `extract_indices` 解析输出张量,支持 I64 索引直接读取与 F32 概率矩阵的 Argmax 处理。 - 完善 `decode_ctc` 解码算法,支持标准 CTC 贪婪搜索与字符集映射。 - 重构 `classification` 主入口,将预处理、推理、解析、解码逻辑解耦,提升代码可维护性。
This commit is contained in:
145
src/lib.rs
Normal file
145
src/lib.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
mod model;
|
||||
mod utils;
|
||||
|
||||
mod charset;
|
||||
mod image_io;
|
||||
mod image_processor;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use image::{DynamicImage, imageops::FilterType};
|
||||
use tract_onnx::prelude::*;
|
||||
// 关键点:直接使用 tract 重导出的 ndarray
|
||||
use crate::image_io::png_rgba_white_preprocess;
|
||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||
use tract_onnx::prelude::tract_itertools::Itertools;
|
||||
|
||||
pub struct DdddOcr {
|
||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
}
|
||||
|
||||
impl DdddOcr {
|
||||
pub fn new<P>(model_path: P) -> Result<Self>
|
||||
where
|
||||
P: AsRef<std::path::Path>,
|
||||
{
|
||||
let session = onnx()
|
||||
.model_for_path(model_path)
|
||||
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||
.into_optimized()?
|
||||
.into_runnable()?;
|
||||
Ok(Self { session })
|
||||
}
|
||||
|
||||
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||
let tensor = self.preprocess_image(img, false)?;
|
||||
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
// 3. 解析结果
|
||||
// let output = result[0].to_array_view::<i64>()?;
|
||||
let output = self.inference(tensor)?;
|
||||
let output2 = self.extract_indices(&output)?;
|
||||
Ok(self.decode_ctc(&output2))
|
||||
}
|
||||
/// 对应 Python 的 _preprocess_image
|
||||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
||||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||
let processed_img = if png_fix && img.color().has_alpha() {
|
||||
png_rgba_white_preprocess(img)
|
||||
} else {
|
||||
img.clone()
|
||||
};
|
||||
|
||||
let h = 64u32;
|
||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||
let gray_img = convert_to_grayscale(img);
|
||||
let resized = resize_image(&gray_img, w, h);
|
||||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||
|
||||
// 使用 tract_ndarray 构造,避免版本冲突
|
||||
let array =
|
||||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||||
(pixel / 255.0 - 0.5) / 0.5
|
||||
});
|
||||
|
||||
let tensor = Tensor::from(array);
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
/// 对应 Python 的 _inference
|
||||
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
|
||||
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
let mut result = self
|
||||
.session
|
||||
.run(tvec!(tensor.into()))
|
||||
.context("执行模型推理失败")?;
|
||||
|
||||
Ok(result.remove(0).into_tensor())
|
||||
}
|
||||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||
fn extract_indices(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
||||
let shape = raw_tensor.shape();
|
||||
|
||||
match raw_tensor.datum_type() {
|
||||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||
DatumType::I64 => {
|
||||
let view = raw_tensor.to_array_view::<i64>()?;
|
||||
Ok(view.iter().cloned().collect())
|
||||
}
|
||||
|
||||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||
DatumType::F32 => {
|
||||
let view = raw_tensor.to_array_view::<f32>()?;
|
||||
|
||||
// 处理典型的 CTC 输出形状 [TimeSteps, Batch:1, Classes]
|
||||
if shape.len() == 3 {
|
||||
let steps = shape[0];
|
||||
let classes = shape[2];
|
||||
|
||||
// 将一维视图重新整理为二维 [steps, classes]
|
||||
let array_2d = view.to_shape((steps, classes))?;
|
||||
|
||||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||
let indices = array_2d
|
||||
.outer_iter()
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(idx, _)| idx as i64)
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.collect();
|
||||
Ok(indices)
|
||||
} else {
|
||||
Err(anyhow::anyhow!("不支持的 F32 输出形状: {:?}", shape))
|
||||
}
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(
|
||||
"不支持的模型输出数据类型: {:?}",
|
||||
raw_tensor.datum_type()
|
||||
)),
|
||||
}
|
||||
}
|
||||
fn decode_ctc(&self, indices: &[i64]) -> String {
|
||||
use crate::charset::CHARSET_BETA;
|
||||
let mut res = String::new();
|
||||
let mut last_idx: i64 = -1;
|
||||
|
||||
for &idx in indices {
|
||||
// ddddocr 的 blank 通常是 0
|
||||
if idx != 0 && idx != last_idx {
|
||||
if let Some(&char_str) = CHARSET_BETA.get(idx as usize) {
|
||||
res.push_str(char_str);
|
||||
}
|
||||
}
|
||||
last_idx = idx;
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user