feat: 重构 CTC 解码逻辑
- 将 decode_ctc 重构为关联函数 ctc_decode_indices,并优化内存分配。
- 增加单元测试和集成测试。
This commit is contained in:
5
examples/simple_usage.rs
Normal file
5
examples/simple_usage.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
fn main() {
|
||||||
|
let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap();
|
||||||
|
let img = image::open("samples/code3.png").unwrap();
|
||||||
|
println!("Result: {}", ocr.classification(&img).unwrap());
|
||||||
|
}
|
||||||
133
src/lib.rs
133
src/lib.rs
@@ -1,18 +1,16 @@
|
|||||||
mod model;
|
|
||||||
mod utils;
|
|
||||||
|
|
||||||
mod charset;
|
mod charset;
|
||||||
mod image_io;
|
mod image_io;
|
||||||
mod image_processor;
|
mod image_processor;
|
||||||
|
mod model;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
use crate::image_io::png_rgba_white_preprocess;
|
||||||
|
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use image::{DynamicImage, imageops::FilterType};
|
use image::{DynamicImage, imageops::FilterType};
|
||||||
use tract_onnx::prelude::*;
|
use tract_onnx::prelude::*;
|
||||||
// 关键点:直接使用 tract 重导出的 ndarray
|
// 关键点:直接使用 tract 重导出的 ndarray
|
||||||
use crate::image_io::png_rgba_white_preprocess;
|
use tract_onnx::prelude::tract_ndarray::s;
|
||||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
|
||||||
use tract_onnx::prelude::tract_itertools::Itertools;
|
|
||||||
|
|
||||||
pub struct DdddOcr {
|
pub struct DdddOcr {
|
||||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||||
}
|
}
|
||||||
@@ -37,14 +35,14 @@ impl DdddOcr {
|
|||||||
// 3. 解析结果
|
// 3. 解析结果
|
||||||
// let output = result[0].to_array_view::<i64>()?;
|
// let output = result[0].to_array_view::<i64>()?;
|
||||||
let output = self.inference(tensor)?;
|
let output = self.inference(tensor)?;
|
||||||
let output2 = self.extract_indices(&output)?;
|
let output2 = self.process_text_output(&output)?;
|
||||||
Ok(self.decode_ctc(&output2))
|
Ok(Self::ctc_decode_indices(&output2))
|
||||||
}
|
}
|
||||||
/// 对应 Python 的 _preprocess_image
|
/// 对应 Python 的 _preprocess_image
|
||||||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
||||||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||||
let processed_img = if png_fix && img.color().has_alpha() {
|
let _ = if png_fix && img.color().has_alpha() {
|
||||||
png_rgba_white_preprocess(img)
|
png_rgba_white_preprocess(img)
|
||||||
} else {
|
} else {
|
||||||
img.clone()
|
img.clone()
|
||||||
@@ -54,6 +52,7 @@ impl DdddOcr {
|
|||||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||||
let gray_img = convert_to_grayscale(img);
|
let gray_img = convert_to_grayscale(img);
|
||||||
let resized = resize_image(&gray_img, w, h);
|
let resized = resize_image(&gray_img, w, h);
|
||||||
|
// resized.save("debug_preprocessed.png").unwrap();
|
||||||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||||
|
|
||||||
@@ -76,12 +75,15 @@ impl DdddOcr {
|
|||||||
.session
|
.session
|
||||||
.run(tvec!(tensor.into()))
|
.run(tvec!(tensor.into()))
|
||||||
.context("执行模型推理失败")?;
|
.context("执行模型推理失败")?;
|
||||||
|
println!("模型输出原始数据: {:?}", result);
|
||||||
Ok(result.remove(0).into_tensor())
|
Ok(result.remove(0).into_tensor())
|
||||||
}
|
}
|
||||||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||||
fn extract_indices(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
||||||
let shape = raw_tensor.shape();
|
let shape = raw_tensor.shape();
|
||||||
|
println!("模型输出shape数据: {:?}", shape);
|
||||||
|
let datum_type = raw_tensor.datum_type();
|
||||||
|
println!("模型输出datum_type数据: {:?}", datum_type);
|
||||||
|
|
||||||
match raw_tensor.datum_type() {
|
match raw_tensor.datum_type() {
|
||||||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||||
@@ -93,32 +95,43 @@ impl DdddOcr {
|
|||||||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||||
DatumType::F32 => {
|
DatumType::F32 => {
|
||||||
let view = raw_tensor.to_array_view::<f32>()?;
|
let view = raw_tensor.to_array_view::<f32>()?;
|
||||||
|
let (steps, classes, data_view) = match shape.len() {
|
||||||
// 处理典型的 CTC 输出形状 [TimeSteps, Batch:1, Classes]
|
3 => {
|
||||||
if shape.len() == 3 {
|
if shape[1] == 1 {
|
||||||
let steps = shape[0];
|
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
||||||
let classes = shape[2];
|
(shape[0], shape[2], view.into_dyn())
|
||||||
|
} else if shape[0] == 1 {
|
||||||
// 将一维视图重新整理为二维 [steps, classes]
|
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
||||||
let array_2d = view.to_shape((steps, classes))?;
|
(shape[1], shape[2], view.into_dyn())
|
||||||
|
} else {
|
||||||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
// 默认取第一个 batch: [Batch, Steps, Classes]
|
||||||
let indices = array_2d
|
// 使用 slice 对应 Python 的 output[0, :, :]
|
||||||
.outer_iter()
|
let sliced = view.slice(s![0, .., ..]);
|
||||||
.map(|row| {
|
(shape[1], shape[2], sliced.into_dyn())
|
||||||
row.iter()
|
}
|
||||||
.enumerate()
|
}
|
||||||
.max_by(|(_, a), (_, b)| {
|
2 => {
|
||||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
||||||
})
|
(shape[0], shape[1], view.into_dyn())
|
||||||
.map(|(idx, _)| idx as i64)
|
}
|
||||||
.unwrap_or(0)
|
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
||||||
})
|
};
|
||||||
.collect();
|
let array_2d = data_view.to_shape((steps, classes))?;
|
||||||
Ok(indices)
|
//
|
||||||
} else {
|
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||||
Err(anyhow::anyhow!("不支持的 F32 输出形状: {:?}", shape))
|
let indices = array_2d
|
||||||
}
|
.outer_iter()
|
||||||
|
.map(|row| {
|
||||||
|
row.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| {
|
||||||
|
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.map(|(idx, _)| idx as i64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(indices)
|
||||||
}
|
}
|
||||||
_ => Err(anyhow::anyhow!(
|
_ => Err(anyhow::anyhow!(
|
||||||
"不支持的模型输出数据类型: {:?}",
|
"不支持的模型输出数据类型: {:?}",
|
||||||
@@ -126,20 +139,44 @@ impl DdddOcr {
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn decode_ctc(&self, indices: &[i64]) -> String {
|
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
|
||||||
use crate::charset::CHARSET_BETA;
|
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||||||
let mut res = String::new();
|
|
||||||
let mut last_idx: i64 = -1;
|
|
||||||
|
|
||||||
for &idx in indices {
|
use crate::charset::CHARSET_BETA;
|
||||||
// ddddocr 的 blank 通常是 0
|
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
||||||
if idx != 0 && idx != last_idx {
|
let mut res = String::new();
|
||||||
if let Some(&char_str) = CHARSET_BETA.get(idx as usize) {
|
let mut prev_idx: i64 = -1;
|
||||||
res.push_str(char_str);
|
|
||||||
|
for &idx in predicted_indices {
|
||||||
|
// 1. 跳过连续重复的索引
|
||||||
|
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
||||||
|
if idx != prev_idx && idx != 0 {
|
||||||
|
if let Ok(u_idx) = usize::try_from(idx) {
|
||||||
|
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
|
||||||
|
res.push_str(char_str);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
last_idx = idx;
|
prev_idx = idx;
|
||||||
}
|
}
|
||||||
|
println!("最终识别出的验证码是: {}", res);
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::charset::CHARSET_BETA;

    /// Checks the CTC collapse rules of `DdddOcr::ctc_decode_indices`:
    /// consecutive repeats are merged, index 0 (blank) is dropped, and a blank
    /// between two equal indices keeps both occurrences.
    #[test]
    fn test_ctc_decode_indices() {
        let input = vec![1, 1, 0, 1, 2, 2, 0, 2];
        // [1, 1] -> 1, [0] -> skipped, [1] -> 1, [2, 2] -> 2, [0] -> skipped, [2] -> 2
        // so the surviving index sequence is [1, 1, 2, 2].
        // The expected text is built from the charset itself so the test does
        // not depend on which characters sit at positions 1 and 2.
        // NOTE(review): assumes CHARSET_BETA has at least three entries — confirm.
        let first = CHARSET_BETA.get(1).expect("CHARSET_BETA has no entry at index 1");
        let second = CHARSET_BETA.get(2).expect("CHARSET_BETA has no entry at index 2");
        let expected = format!("{}{}{}{}", first, first, second, second);

        assert_eq!(DdddOcr::ctc_decode_indices(&input), expected);

        // Edge cases: no input, and input made only of blanks, decode to nothing.
        assert_eq!(DdddOcr::ctc_decode_indices(&[]), "");
        assert_eq!(DdddOcr::ctc_decode_indices(&[0, 0, 0]), "");
    }
}
|
||||||
104
src/main.rs
104
src/main.rs
@@ -1,104 +0,0 @@
|
|||||||
mod charset;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use charset::CHARSET_BETA;
|
|
||||||
use image::{imageops::FilterType, open};
|
|
||||||
use tract_onnx::prelude::*;
|
|
||||||
// 编译时读取字典文件
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
// 1. 加载并优化模型 (假设模型文件在根目录)
|
|
||||||
let model = onnx()
|
|
||||||
.model_for_path("model/common_huashi666_i64.onnx")? // 这里替换成你提取的 ddddocr 模型名
|
|
||||||
.into_optimized()?
|
|
||||||
.into_runnable()?;
|
|
||||||
|
|
||||||
// 2. 加载并处理图片 (需要缩放到模型要求的尺寸,例如 64x30)
|
|
||||||
let img = open("samples/code3.png")?;
|
|
||||||
|
|
||||||
let h = 64u32;
|
|
||||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
|
||||||
|
|
||||||
// 1. 缩放并转灰度
|
|
||||||
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
|
||||||
let array =
|
|
||||||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
|
||||||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
|
||||||
(pixel / 255.0 - 0.5) / 0.5
|
|
||||||
});
|
|
||||||
|
|
||||||
let tensor = Tensor::from(array);
|
|
||||||
|
|
||||||
// 4. 运行推理
|
|
||||||
let result = model.run(tvec!(tensor.into()))?;
|
|
||||||
// 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor
|
|
||||||
// 简化逻辑:
|
|
||||||
// let tensor: Tensor = tract_ndarray::Array4::<f32>::zeros((1, 1, 30, 64)).into();
|
|
||||||
let raw_tensor = &result[0];
|
|
||||||
// 3. 运行推理
|
|
||||||
// let result = model.run(tvec!(tensor.into()))?;
|
|
||||||
println!("模型输出原始数据: {:?}", result);
|
|
||||||
let shape = result[0].shape();
|
|
||||||
println!("模型输出shape数据: {:?}", shape);
|
|
||||||
let datum_type = result[0].datum_type();
|
|
||||||
println!("模型输出datum_type数据: {:?}", datum_type);
|
|
||||||
|
|
||||||
let predicted_indices: Vec<i64> = match raw_tensor.datum_type() {
|
|
||||||
// 情况 1: huashi666 式模型,直接输出 i64 索引
|
|
||||||
DatumType::I64 => {
|
|
||||||
raw_tensor.to_array_view::<i64>()?.iter().cloned().collect()
|
|
||||||
}
|
|
||||||
// 情况 2: sml2h3 原版模型,输出 F32 概率
|
|
||||||
DatumType::F32 => {
|
|
||||||
let view = raw_tensor.to_array_view::<f32>()?;
|
|
||||||
|
|
||||||
// 模仿 Python 的维度判断逻辑
|
|
||||||
if shape.len() == 3 {
|
|
||||||
// 假设形状是 [21, 1, 8210]
|
|
||||||
let steps = shape[0];
|
|
||||||
let classes = shape[2];
|
|
||||||
let array_2d = view.to_shape((
|
|
||||||
(steps, classes),
|
|
||||||
tract_onnx::prelude::tract_ndarray::Order::RowMajor
|
|
||||||
))?;
|
|
||||||
|
|
||||||
array_2d.outer_iter()
|
|
||||||
.map(|row| {
|
|
||||||
row.iter().enumerate()
|
|
||||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
|
||||||
.map(|(idx, _)| idx as i64).unwrap()
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
// 其他形状处理...
|
|
||||||
vec![]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => return Err(anyhow!("不支持的输出类型")),
|
|
||||||
};
|
|
||||||
|
|
||||||
// let output = result[0].to_array_view::<i64>()?;
|
|
||||||
// println!("模型输出原始数据2: {:?}", output);
|
|
||||||
// let indices: Vec<i64> = output.iter().cloned().collect();
|
|
||||||
|
|
||||||
// 2. 将视图转为切片并调用函数
|
|
||||||
let code = decode_ctc(&predicted_indices);
|
|
||||||
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
|
||||||
println!("最终识别出的验证码是: {}", code);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
// common_huashi666_i64
|
|
||||||
fn decode_ctc(indices: &[i64]) -> String {
|
|
||||||
let mut res = String::new();
|
|
||||||
let mut last_idx: i64 = -1;
|
|
||||||
|
|
||||||
for &idx in indices {
|
|
||||||
// idx == 0 通常是 CTC 的 blank 占位符
|
|
||||||
if idx != 0 && idx != last_idx {
|
|
||||||
if let Some(&char_str) = CHARSET_BETA.get(idx as usize) {
|
|
||||||
res.push_str(char_str);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
last_idx = idx;
|
|
||||||
}
|
|
||||||
res
|
|
||||||
}
|
|
||||||
16
tests/ocr_test.rs
Normal file
16
tests/ocr_test.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use ddddocr_rs::DdddOcr; // 假设你的包名是这个
|
||||||
|
|
||||||
|
/// End-to-end check: loading the model and classifying the sample image must
/// succeed and produce non-empty text.
///
/// Requires `model/common.onnx` and `samples/code3.png` to exist relative to
/// the directory the test is run from.
#[test]
fn test_full_classification() {
    // Model first, then the fixture image.
    let engine = DdddOcr::new("model/common.onnx").expect("模型加载失败");
    let sample = image::open("samples/code3.png").expect("测试图片不存在");

    // Run recognition on the loaded image.
    let result = engine.classification(&sample).expect("识别过程出错");
    println!("识别结果: {}", result);

    // Only non-emptiness is checked; the exact text is not pinned here.
    assert!(!result.is_empty());
}
|
||||||
Reference in New Issue
Block a user