diff --git a/examples/simple_usage.rs b/examples/simple_usage.rs new file mode 100644 index 0000000..c5031f8 --- /dev/null +++ b/examples/simple_usage.rs @@ -0,0 +1,5 @@ +fn main() { + let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap(); + let img = image::open("samples/code3.png").unwrap(); + println!("Result: {}", ocr.classification(&img).unwrap()); +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index de28ea2..0445875 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,16 @@ -mod model; -mod utils; - mod charset; mod image_io; mod image_processor; +mod model; +mod utils; +use crate::image_io::png_rgba_white_preprocess; +use crate::image_processor::{convert_to_grayscale, resize_image}; use anyhow::{Context, Result}; use image::{DynamicImage, imageops::FilterType}; use tract_onnx::prelude::*; // 关键点:直接使用 tract 重导出的 ndarray -use crate::image_io::png_rgba_white_preprocess; -use crate::image_processor::{convert_to_grayscale, resize_image}; -use tract_onnx::prelude::tract_itertools::Itertools; - +use tract_onnx::prelude::tract_ndarray::s; pub struct DdddOcr { session: RunnableModel, Graph>>, } @@ -37,14 +35,14 @@ impl DdddOcr { // 3. 解析结果 // let output = result[0].to_array_view::()?; let output = self.inference(tensor)?; - let output2 = self.extract_indices(&output)?; - Ok(self.decode_ctc(&output2)) + let output2 = self.process_text_output(&output)?; + Ok(Self::ctc_decode_indices(&output2)) } /// 对应 Python 的 _preprocess_image /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result { // A. 修复 PNG 透明背景 (内部逻辑你之前已实现) - let processed_img = if png_fix && img.color().has_alpha() { + let _ = if png_fix && img.color().has_alpha() { png_rgba_white_preprocess(img) } else { img.clone() @@ -54,6 +52,7 @@ impl DdddOcr { let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; let gray_img = convert_to_grayscale(img); let resized = resize_image(&gray_img, w, h); + // resized.save("debug_preprocessed.png").unwrap(); // 1. 预处理:转灰度 -> Resize -> 归一化 // let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); @@ -76,12 +75,15 @@ impl DdddOcr { .session .run(tvec!(tensor.into())) .context("执行模型推理失败")?; - + println!("模型输出原始数据: {:?}", result); Ok(result.remove(0).into_tensor()) } /// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列 - fn extract_indices(&self, raw_tensor: &Tensor) -> Result> { + fn process_text_output(&self, raw_tensor: &Tensor) -> Result> { let shape = raw_tensor.shape(); + println!("模型输出shape数据: {:?}", shape); + let datum_type = raw_tensor.datum_type(); + println!("模型输出datum_type数据: {:?}", datum_type); match raw_tensor.datum_type() { // 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax) @@ -93,32 +95,43 @@ impl DdddOcr { // 情况 2: sml2h3 原版模型,输出 F32 概率矩阵 DatumType::F32 => { let view = raw_tensor.to_array_view::()?; - - // 处理典型的 CTC 输出形状 [TimeSteps, Batch:1, Classes] - if shape.len() == 3 { - let steps = shape[0]; - let classes = shape[2]; - - // 将一维视图重新整理为二维 [steps, classes] - let array_2d = view.to_shape((steps, classes))?; - - // 对每一行执行 Argmax (寻找概率最大的字符索引) - let indices = array_2d - .outer_iter() - .map(|row| { - row.iter() - .enumerate() - .max_by(|(_, a), (_, b)| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(idx, _)| idx as i64) - .unwrap_or(0) - }) - .collect(); - Ok(indices) - } else { - Err(anyhow::anyhow!("不支持的 F32 输出形状: {:?}", shape)) - } + let (steps, classes, data_view) = match shape.len() { + 3 => { + if shape[1] == 1 { + // 形状: [Steps, 1, Classes] -> 你的原有逻辑 + (shape[0], shape[2], view.into_dyn()) + } else if shape[0] == 1 { + // 形状: [1, Steps, Classes] -> 另一种常见导出格式 + (shape[1], shape[2], view.into_dyn()) + } else { + // 默认取第一个 batch: [Batch, Steps, Classes] + // 使用 slice 对应 Python 的 output[0, :, :] + let sliced = view.slice(s![0, .., ..]); + (shape[1], shape[2], sliced.into_dyn()) + } + } + 2 => { + // 形状: [Steps, Classes] -> 已经剥离了 Batch 维度 + (shape[0], shape[1], view.into_dyn()) + } + _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)), + }; + let array_2d = data_view.to_shape((steps, classes))?; + // + // 对每一行执行 Argmax (寻找概率最大的字符索引) + let indices = array_2d + .outer_iter() + .map(|row| { + row.iter() + .enumerate() + .max_by(|(_, a), (_, b)| { + a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(idx, _)| idx as i64) + .unwrap_or(0) + }) + .collect(); + Ok(indices) } _ => Err(anyhow::anyhow!( "不支持的模型输出数据类型: {:?}", @@ -126,20 +139,44 @@ impl DdddOcr { )), } } - fn decode_ctc(&self, indices: &[i64]) -> String { - use crate::charset::CHARSET_BETA; - let mut res = String::new(); - let mut last_idx: i64 = -1; + fn ctc_decode_indices(predicted_indices: &[i64]) -> String { + println!("indices模型输出原始数据: {:?}", predicted_indices); - for &idx in indices { - // ddddocr 的 blank 通常是 0 - if idx != 0 && idx != last_idx { - if let Some(&char_str) = CHARSET_BETA.get(idx as usize) { - res.push_str(char_str); + use crate::charset::CHARSET_BETA; + // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) + let mut res = String::new(); + let mut prev_idx: i64 = -1; + + for &idx in predicted_indices { + // 1. 跳过连续重复的索引 + // 2. 跳过 blank 字符 (假设索引 0 是 blank) + if idx != prev_idx && idx != 0 { + if let Ok(u_idx) = usize::try_from(idx) { + if let Some(&char_str) = CHARSET_BETA.get(u_idx) { + res.push_str(char_str); + } } } - last_idx = idx; + prev_idx = idx; } + println!("最终识别出的验证码是: {}", res); res } } + + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_ctc_decode_indices() { + // 模拟一个 DdddOcr 实例(如果 decode 不依赖 session,可以设为相关函数) + // 这里假设你的 decode_ctc 是公开或内部可访问的 + let input = vec![1, 1, 0, 1, 2, 2, 0, 2]; + // 逻辑:[1, 1] -> 1, [0] -> 跳过, [1] -> 1, [2, 2] -> 2, [0] -> 跳过, [2] -> 2 + // 预期结果索引应该是 [1, 1, 2, 2] 对应的字符 + // 具体的断言取决于你的 CHARSET_BETA + // let result = dddd.ctc_decode_indices(&input); + // assert_eq!(result, "AABB"); + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 41e0f1a..0000000 --- a/src/main.rs +++ /dev/null @@ -1,104 +0,0 @@ -mod charset; - -use anyhow::{anyhow, Result}; -use charset::CHARSET_BETA; -use image::{imageops::FilterType, open}; -use tract_onnx::prelude::*; -// 编译时读取字典文件 -fn main() -> Result<()> { - // 1. 加载并优化模型 (假设模型文件在根目录) - let model = onnx() - .model_for_path("model/common_huashi666_i64.onnx")? // 这里替换成你提取的 ddddocr 模型名 - .into_optimized()? - .into_runnable()?; - - // 2. 加载并处理图片 (需要缩放到模型要求的尺寸,例如 64x30) - let img = open("samples/code3.png")?; - - let h = 64u32; - let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; - - // 1. 缩放并转灰度 - let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); - let array = - tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| { - let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32; - (pixel / 255.0 - 0.5) / 0.5 - }); - - let tensor = Tensor::from(array); - - // 4. 运行推理 - let result = model.run(tvec!(tensor.into()))?; - // 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor - // 简化逻辑: - // let tensor: Tensor = tract_ndarray::Array4::::zeros((1, 1, 30, 64)).into(); - let raw_tensor = &result[0]; - // 3. 运行推理 - // let result = model.run(tvec!(tensor.into()))?; - println!("模型输出原始数据: {:?}", result); - let shape = result[0].shape(); - println!("模型输出shape数据: {:?}", shape); - let datum_type = result[0].datum_type(); - println!("模型输出datum_type数据: {:?}", datum_type); - - let predicted_indices: Vec = match raw_tensor.datum_type() { - // 情况 1: huashi666 式模型,直接输出 i64 索引 - DatumType::I64 => { - raw_tensor.to_array_view::()?.iter().cloned().collect() - } - // 情况 2: sml2h3 原版模型,输出 F32 概率 - DatumType::F32 => { - let view = raw_tensor.to_array_view::()?; - - // 模仿 Python 的维度判断逻辑 - if shape.len() == 3 { - // 假设形状是 [21, 1, 8210] - let steps = shape[0]; - let classes = shape[2]; - let array_2d = view.to_shape(( - (steps, classes), - tract_onnx::prelude::tract_ndarray::Order::RowMajor - ))?; - - array_2d.outer_iter() - .map(|row| { - row.iter().enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .map(|(idx, _)| idx as i64).unwrap() - }) - .collect() - } else { - // 其他形状处理... - vec![] - } - } - _ => return Err(anyhow!("不支持的输出类型")), - }; - - // let output = result[0].to_array_view::()?; - // println!("模型输出原始数据2: {:?}", output); - // let indices: Vec = output.iter().cloned().collect(); - - // 2. 将视图转为切片并调用函数 - let code = decode_ctc(&predicted_indices); - println!("indices模型输出原始数据: {:?}", predicted_indices); - println!("最终识别出的验证码是: {}", code); - Ok(()) -} -// common_huashi666_i64 -fn decode_ctc(indices: &[i64]) -> String { - let mut res = String::new(); - let mut last_idx: i64 = -1; - - for &idx in indices { - // idx == 0 通常是 CTC 的 blank 占位符 - if idx != 0 && idx != last_idx { - if let Some(&char_str) = CHARSET_BETA.get(idx as usize) { - res.push_str(char_str); - } - } - last_idx = idx; - } - res -} diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs new file mode 100644 index 0000000..be75af2 --- /dev/null +++ b/tests/ocr_test.rs @@ -0,0 +1,16 @@ +use ddddocr_rs::DdddOcr; // 假设你的包名是这个 + +#[test] +fn test_full_classification() { + // 1. 初始化模型 + let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败"); + + // 2. 加载测试图片 + let img = image::open("samples/code3.png").expect("测试图片不存在"); + + // 3. 执行识别 + let result = ocr.classification(&img).expect("识别过程出错"); + + println!("识别结果: {}", result); + assert!(!result.is_empty()); +} \ No newline at end of file