feat: 实现 DdddOcr 核心推理流水线与图像预处理
- 封装 `preprocess_image` 方法,实现 PNG 透明背景修复、灰度化、比例缩放及 NCHW 张量转换。 - 提取 `inference` 逻辑,支持通过 tract-onnx 执行模型推理。 - 实现 `extract_indices` 解析输出张量,支持 I64 索引直接读取与 F32 概率矩阵的 Argmax 处理。 - 完善 `decode_ctc` 解码算法,支持标准 CTC 贪婪搜索与字符集映射。 - 重构 `classification` 主入口,将预处理、推理、解析、解码逻辑解耦,提升代码可维护性。
This commit is contained in:
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
@@ -5,6 +5,7 @@ edition = "2024"
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tract-onnx = { version = "0.21.1" }
|
tract-onnx = { version = "0.21.10" }
|
||||||
anyhow = "1.0.102"
|
anyhow = "1.0.102"
|
||||||
image = "0.25.10"
|
image = "0.25.10"
|
||||||
|
base64 = "0.22.1"
|
||||||
|
|||||||
62
src/image_io.rs
Normal file
62
src/image_io.rs
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
|
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use tract_onnx::prelude::tract_ndarray::Array3;
|
||||||
|
|
||||||
|
/// 定义支持的输入类型枚举
|
||||||
|
pub enum ImageInput {
|
||||||
|
Bytes(Vec<u8>),
|
||||||
|
Array(Array3<u8>),
|
||||||
|
Path(PathBuf),
|
||||||
|
Base64(String),
|
||||||
|
DynamicImage(DynamicImage),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 模拟 Python 的 load_image_from_input
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> {
|
||||||
|
match input {
|
||||||
|
ImageInput::DynamicImage(img) => Ok(img),
|
||||||
|
_ => todo!("后续补充"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python 的 png_rgba_black_preprocess
|
||||||
|
/// 将带有透明通道的图片转换为白色背景的 RGB 图片
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
|
||||||
|
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回
|
||||||
|
if !img.color().has_alpha() {
|
||||||
|
return img.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
let (width, height) = img.dimensions();
|
||||||
|
|
||||||
|
// 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255)
|
||||||
|
let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8]));
|
||||||
|
|
||||||
|
// 3. 获取原图的 RGBA 视图
|
||||||
|
let rgba_img = img.to_rgba8();
|
||||||
|
|
||||||
|
// 4. 遍历像素并手动进行 Alpha 混合
|
||||||
|
// 对应 Python 的 image.paste(img, ..., mask=img)
|
||||||
|
for (x, y, pixel) in rgba_img.enumerate_pixels() {
|
||||||
|
let alpha = pixel[3] as f32 / 255.0;
|
||||||
|
|
||||||
|
if alpha >= 1.0 {
|
||||||
|
// 完全不透明,直接覆盖
|
||||||
|
background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]]));
|
||||||
|
} else if alpha > 0.0 {
|
||||||
|
// 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha))
|
||||||
|
let bg_pixel = background.get_pixel(x, y);
|
||||||
|
let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8;
|
||||||
|
let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8;
|
||||||
|
let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8;
|
||||||
|
background.put_pixel(x, y, Rgb([r, g, b]));
|
||||||
|
}
|
||||||
|
// alpha == 0 的情况不需要处理,因为背景已经是白色了
|
||||||
|
}
|
||||||
|
|
||||||
|
DynamicImage::ImageRgb8(background)
|
||||||
|
}
|
||||||
27
src/image_processor.rs
Normal file
27
src/image_processor.rs
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
use image::{DynamicImage, GrayImage, imageops::FilterType};
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
/// 对应 Python 的 convert_to_grayscale
|
||||||
|
/// 将图像转换为灰度图 (L模式)
|
||||||
|
pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage {
|
||||||
|
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
|
||||||
|
image.to_luma8()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python 的 resize_image
|
||||||
|
/// 调整图像尺寸。当前版本仅实现 keep_aspect_ratio=false
|
||||||
|
pub fn resize_image(
|
||||||
|
image: &GrayImage,
|
||||||
|
target_width: u32,
|
||||||
|
target_height: u32,
|
||||||
|
// resample 参数我们直接使用 FilterType,Lanczos3 是最接近 Python LANCZOS 的
|
||||||
|
) -> GrayImage {
|
||||||
|
// 使用 resize 算法进行精确缩放
|
||||||
|
image::imageops::resize(
|
||||||
|
image,
|
||||||
|
target_width,
|
||||||
|
target_height,
|
||||||
|
FilterType::Lanczos3
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
145
src/lib.rs
Normal file
145
src/lib.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
mod model;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
mod charset;
|
||||||
|
mod image_io;
|
||||||
|
mod image_processor;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use image::{DynamicImage, imageops::FilterType};
|
||||||
|
use tract_onnx::prelude::*;
|
||||||
|
// 关键点:直接使用 tract 重导出的 ndarray
|
||||||
|
use crate::image_io::png_rgba_white_preprocess;
|
||||||
|
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||||
|
use tract_onnx::prelude::tract_itertools::Itertools;
|
||||||
|
|
||||||
|
pub struct DdddOcr {
|
||||||
|
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DdddOcr {
|
||||||
|
pub fn new<P>(model_path: P) -> Result<Self>
|
||||||
|
where
|
||||||
|
P: AsRef<std::path::Path>,
|
||||||
|
{
|
||||||
|
let session = onnx()
|
||||||
|
.model_for_path(model_path)
|
||||||
|
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||||
|
.into_optimized()?
|
||||||
|
.into_runnable()?;
|
||||||
|
Ok(Self { session })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||||
|
let tensor = self.preprocess_image(img, false)?;
|
||||||
|
|
||||||
|
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||||
|
// 3. 解析结果
|
||||||
|
// let output = result[0].to_array_view::<i64>()?;
|
||||||
|
let output = self.inference(tensor)?;
|
||||||
|
let output2 = self.extract_indices(&output)?;
|
||||||
|
Ok(self.decode_ctc(&output2))
|
||||||
|
}
|
||||||
|
/// 对应 Python 的 _preprocess_image
|
||||||
|
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||||
|
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
||||||
|
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||||
|
let processed_img = if png_fix && img.color().has_alpha() {
|
||||||
|
png_rgba_white_preprocess(img)
|
||||||
|
} else {
|
||||||
|
img.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let h = 64u32;
|
||||||
|
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||||
|
let gray_img = convert_to_grayscale(img);
|
||||||
|
let resized = resize_image(&gray_img, w, h);
|
||||||
|
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||||
|
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||||
|
|
||||||
|
// 使用 tract_ndarray 构造,避免版本冲突
|
||||||
|
let array =
|
||||||
|
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||||||
|
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||||||
|
(pixel / 255.0 - 0.5) / 0.5
|
||||||
|
});
|
||||||
|
|
||||||
|
let tensor = Tensor::from(array);
|
||||||
|
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
/// 对应 Python 的 _inference
|
||||||
|
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
|
||||||
|
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||||||
|
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||||
|
let mut result = self
|
||||||
|
.session
|
||||||
|
.run(tvec!(tensor.into()))
|
||||||
|
.context("执行模型推理失败")?;
|
||||||
|
|
||||||
|
Ok(result.remove(0).into_tensor())
|
||||||
|
}
|
||||||
|
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||||
|
fn extract_indices(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
||||||
|
let shape = raw_tensor.shape();
|
||||||
|
|
||||||
|
match raw_tensor.datum_type() {
|
||||||
|
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||||
|
DatumType::I64 => {
|
||||||
|
let view = raw_tensor.to_array_view::<i64>()?;
|
||||||
|
Ok(view.iter().cloned().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||||
|
DatumType::F32 => {
|
||||||
|
let view = raw_tensor.to_array_view::<f32>()?;
|
||||||
|
|
||||||
|
// 处理典型的 CTC 输出形状 [TimeSteps, Batch:1, Classes]
|
||||||
|
if shape.len() == 3 {
|
||||||
|
let steps = shape[0];
|
||||||
|
let classes = shape[2];
|
||||||
|
|
||||||
|
// 将一维视图重新整理为二维 [steps, classes]
|
||||||
|
let array_2d = view.to_shape((steps, classes))?;
|
||||||
|
|
||||||
|
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||||
|
let indices = array_2d
|
||||||
|
.outer_iter()
|
||||||
|
.map(|row| {
|
||||||
|
row.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| {
|
||||||
|
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.map(|(idx, _)| idx as i64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(indices)
|
||||||
|
} else {
|
||||||
|
Err(anyhow::anyhow!("不支持的 F32 输出形状: {:?}", shape))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(anyhow::anyhow!(
|
||||||
|
"不支持的模型输出数据类型: {:?}",
|
||||||
|
raw_tensor.datum_type()
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn decode_ctc(&self, indices: &[i64]) -> String {
|
||||||
|
use crate::charset::CHARSET_BETA;
|
||||||
|
let mut res = String::new();
|
||||||
|
let mut last_idx: i64 = -1;
|
||||||
|
|
||||||
|
for &idx in indices {
|
||||||
|
// ddddocr 的 blank 通常是 0
|
||||||
|
if idx != 0 && idx != last_idx {
|
||||||
|
if let Some(&char_str) = CHARSET_BETA.get(idx as usize) {
|
||||||
|
res.push_str(char_str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last_idx = idx;
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
56
src/main.rs
56
src/main.rs
@@ -1,6 +1,6 @@
|
|||||||
mod charset;
|
mod charset;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::{anyhow, Result};
|
||||||
use charset::CHARSET_BETA;
|
use charset::CHARSET_BETA;
|
||||||
use image::{imageops::FilterType, open};
|
use image::{imageops::FilterType, open};
|
||||||
use tract_onnx::prelude::*;
|
use tract_onnx::prelude::*;
|
||||||
@@ -8,7 +8,7 @@ use tract_onnx::prelude::*;
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
// 1. 加载并优化模型 (假设模型文件在根目录)
|
// 1. 加载并优化模型 (假设模型文件在根目录)
|
||||||
let model = onnx()
|
let model = onnx()
|
||||||
.model_for_path("model/common.onnx")? // 这里替换成你提取的 ddddocr 模型名
|
.model_for_path("model/common_huashi666_i64.onnx")? // 这里替换成你提取的 ddddocr 模型名
|
||||||
.into_optimized()?
|
.into_optimized()?
|
||||||
.into_runnable()?;
|
.into_runnable()?;
|
||||||
|
|
||||||
@@ -33,20 +33,60 @@ fn main() -> Result<()> {
|
|||||||
// 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor
|
// 注意:这里需要根据 ddddocr 的要求将图片转为 Tensor
|
||||||
// 简化逻辑:
|
// 简化逻辑:
|
||||||
// let tensor: Tensor = tract_ndarray::Array4::<f32>::zeros((1, 1, 30, 64)).into();
|
// let tensor: Tensor = tract_ndarray::Array4::<f32>::zeros((1, 1, 30, 64)).into();
|
||||||
|
let raw_tensor = &result[0];
|
||||||
// 3. 运行推理
|
// 3. 运行推理
|
||||||
// let result = model.run(tvec!(tensor.into()))?;
|
// let result = model.run(tvec!(tensor.into()))?;
|
||||||
println!("模型输出原始数据: {:?}", result);
|
println!("模型输出原始数据: {:?}", result);
|
||||||
let output = result[0].to_array_view::<i64>()?;
|
let shape = result[0].shape();
|
||||||
let indices: Vec<i64> = output.iter().cloned().collect();
|
println!("模型输出shape数据: {:?}", shape);
|
||||||
|
let datum_type = result[0].datum_type();
|
||||||
|
println!("模型输出datum_type数据: {:?}", datum_type);
|
||||||
|
|
||||||
|
let predicted_indices: Vec<i64> = match raw_tensor.datum_type() {
|
||||||
|
// 情况 1: huashi666 式模型,直接输出 i64 索引
|
||||||
|
DatumType::I64 => {
|
||||||
|
raw_tensor.to_array_view::<i64>()?.iter().cloned().collect()
|
||||||
|
}
|
||||||
|
// 情况 2: sml2h3 原版模型,输出 F32 概率
|
||||||
|
DatumType::F32 => {
|
||||||
|
let view = raw_tensor.to_array_view::<f32>()?;
|
||||||
|
|
||||||
|
// 模仿 Python 的维度判断逻辑
|
||||||
|
if shape.len() == 3 {
|
||||||
|
// 假设形状是 [21, 1, 8210]
|
||||||
|
let steps = shape[0];
|
||||||
|
let classes = shape[2];
|
||||||
|
let array_2d = view.to_shape((
|
||||||
|
(steps, classes),
|
||||||
|
tract_onnx::prelude::tract_ndarray::Order::RowMajor
|
||||||
|
))?;
|
||||||
|
|
||||||
|
array_2d.outer_iter()
|
||||||
|
.map(|row| {
|
||||||
|
row.iter().enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||||
|
.map(|(idx, _)| idx as i64).unwrap()
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
// 其他形状处理...
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow!("不支持的输出类型")),
|
||||||
|
};
|
||||||
|
|
||||||
|
// let output = result[0].to_array_view::<i64>()?;
|
||||||
|
// println!("模型输出原始数据2: {:?}", output);
|
||||||
|
// let indices: Vec<i64> = output.iter().cloned().collect();
|
||||||
|
|
||||||
// 2. 将视图转为切片并调用函数
|
// 2. 将视图转为切片并调用函数
|
||||||
let code = decode_ctc(&indices);
|
let code = decode_ctc(&predicted_indices);
|
||||||
println!("indices模型输出原始数据: {:?}", indices);
|
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||||||
println!("最终识别出的验证码是: {}", code);
|
println!("最终识别出的验证码是: {}", code);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
// common_huashi666_i64
|
||||||
fn decode_ctc(indices: &[i64]) -> String {
|
fn decode_ctc(indices: &[i64]) -> String {
|
||||||
let mut res = String::new();
|
let mut res = String::new();
|
||||||
let mut last_idx: i64 = -1;
|
let mut last_idx: i64 = -1;
|
||||||
|
|||||||
0
src/model.rs
Normal file
0
src/model.rs
Normal file
0
src/utils.rs
Normal file
0
src/utils.rs
Normal file
Reference in New Issue
Block a user