249 lines
9.1 KiB
Rust
249 lines
9.1 KiB
Rust
use crate::image_io::png_rgba_white_preprocess;
|
||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
|
||
use anyhow::Context;
|
||
use image::DynamicImage;
|
||
use tract_onnx::prelude::tract_ndarray::s;
|
||
use tract_onnx::prelude::{
|
||
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
|
||
};
|
||
use crate::base::ModelArgs;
|
||
|
||
|
||
/// Custom color-filter range written as `(low RGB, high RGB)`, where each bound is an
/// `(r, g, b)` triple of `u8` channel values.
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));
|
||
|
||
/// Selector describing which range of characters is in play.
#[derive(Debug, Clone)]
pub enum CharsetRange {
    /// Every character.
    All,
    /// Digits.
    Digit,
    /// Letters.
    Letter,
    /// Letters and digits.
    Alphanumeric,
    /// One string.
    Single(String),
    /// Several strings.
    Multiple(Vec<String>),
    /// A character range described by its two endpoint characters.
    Range(char, char),
    /// An explicit list of characters.
    Custom(Vec<char>),
}
|
||
#[derive(Debug, Clone)]
|
||
pub struct PredictArgs{
|
||
/// 是否修复PNG格式问题
|
||
pub png_fix: bool,
|
||
/// 是否返回概率信息
|
||
pub probability: bool,
|
||
/// 颜色过滤:保留的颜色列表
|
||
pub color_filter_colors: Option<Vec<String>>,
|
||
/// 颜色过滤:自定义RGB范围
|
||
pub color_filter_custom_ranges: Option<Vec<ColorRange>>,
|
||
/// 字符集范围
|
||
pub charset_range: Option<CharsetRange>,
|
||
}
|
||
|
||
impl Default for PredictArgs {
|
||
fn default() -> Self {
|
||
Self {
|
||
png_fix: false,
|
||
probability: false,
|
||
color_filter_colors: None,
|
||
color_filter_custom_ranges: None,
|
||
charset_range: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl PredictArgs {
|
||
pub fn new() -> Self {
|
||
Self::default()
|
||
}
|
||
|
||
// Builder 模式方法
|
||
pub fn png_fix(mut self, enabled: bool) -> Self {
|
||
self.png_fix = enabled;
|
||
self
|
||
}
|
||
|
||
pub fn probability(mut self, enabled: bool) -> Self {
|
||
self.probability = enabled;
|
||
self
|
||
}
|
||
|
||
pub fn color_filter_colors(mut self, colors: Vec<String>) -> Self {
|
||
self.color_filter_colors = Some(colors);
|
||
self
|
||
}
|
||
|
||
pub fn color_filter_custom_ranges(mut self, ranges: Vec<ColorRange>) -> Self {
|
||
self.color_filter_custom_ranges = Some(ranges);
|
||
self
|
||
}
|
||
|
||
pub fn charset_range(mut self, range: CharsetRange) -> Self {
|
||
self.charset_range = Some(range);
|
||
self
|
||
}
|
||
|
||
// 便捷构造方法
|
||
pub fn quick() -> Self {
|
||
Self::default()
|
||
}
|
||
|
||
pub fn with_probability() -> Self {
|
||
Self::default().probability(true)
|
||
}
|
||
|
||
pub fn with_png_fix() -> Self {
|
||
Self::default().png_fix(true)
|
||
}
|
||
}
|
||
pub struct Ocr {
|
||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||
charset: Vec<String>,
|
||
}
|
||
impl ModelSession for Ocr {
|
||
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
|
||
let tensor = self.preprocess_image(image, png_fix)?;
|
||
//
|
||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||
// // 3. 解析结果
|
||
// // let output = result[0].to_array_view::<i64>()?;
|
||
let output = self.inference(tensor)?;
|
||
let output2 = self.process_text_output(&output)?;
|
||
Ok(Self::ctc_decode_indices(&output2))
|
||
// Ok("ocr result".to_string())
|
||
}
|
||
|
||
fn get_model_type(&self) -> ModelType {
|
||
ModelType::Ocr
|
||
}
|
||
}
|
||
impl Ocr {
|
||
pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
|
||
let session = ModelLoader::load_model(&model_path)?.session;
|
||
Ok(Self { session, charset })
|
||
}
|
||
/// 对应 Python 的 _preprocess_image
|
||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
|
||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||
let _ = if png_fix && img.color().has_alpha() {
|
||
png_rgba_white_preprocess(img)
|
||
} else {
|
||
img.clone()
|
||
};
|
||
|
||
let h = 64u32;
|
||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||
let gray_img = convert_to_grayscale(img);
|
||
let resized = resize_image(&gray_img, w, h);
|
||
// resized.save("debug_preprocessed.png").unwrap();
|
||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||
|
||
// 使用 tract_ndarray 构造,避免版本冲突
|
||
let array =
|
||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||
(pixel / 255.0 - 0.5) / 0.5
|
||
});
|
||
|
||
let tensor = Tensor::from(array);
|
||
|
||
Ok(tensor)
|
||
}
|
||
/// 对应 Python 的 _inference
|
||
fn inference(&self, tensor: Tensor) -> anyhow::Result<Tensor> {
|
||
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||
let mut result = self
|
||
.session
|
||
.run(tvec!(tensor.into()))
|
||
.context("执行模型推理失败")?;
|
||
println!("模型输出原始数据: {:?}", result);
|
||
Ok(result.remove(0).into_tensor())
|
||
}
|
||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||
fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result<Vec<i64>> {
|
||
let shape = raw_tensor.shape();
|
||
println!("模型输出shape数据: {:?}", shape);
|
||
let datum_type = raw_tensor.datum_type();
|
||
println!("模型输出datum_type数据: {:?}", datum_type);
|
||
|
||
match raw_tensor.datum_type() {
|
||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||
DatumType::I64 => {
|
||
let view = raw_tensor.to_array_view::<i64>()?;
|
||
Ok(view.iter().cloned().collect())
|
||
}
|
||
|
||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||
DatumType::F32 => {
|
||
let view = raw_tensor.to_array_view::<f32>()?;
|
||
let (steps, classes, data_view) = match shape.len() {
|
||
3 => {
|
||
if shape[1] == 1 {
|
||
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
||
(shape[0], shape[2], view.into_dyn())
|
||
} else if shape[0] == 1 {
|
||
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
||
(shape[1], shape[2], view.into_dyn())
|
||
} else {
|
||
// 默认取第一个 batch: [Batch, Steps, Classes]
|
||
// 使用 slice 对应 Python 的 output[0, :, :]
|
||
let sliced = view.slice(s![0, .., ..]);
|
||
(shape[1], shape[2], sliced.into_dyn())
|
||
}
|
||
}
|
||
2 => {
|
||
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
||
(shape[0], shape[1], view.into_dyn())
|
||
}
|
||
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
||
};
|
||
let array_2d = data_view.to_shape((steps, classes))?;
|
||
//
|
||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||
let indices = array_2d
|
||
.outer_iter()
|
||
.map(|row| {
|
||
row.iter()
|
||
.enumerate()
|
||
.max_by(|(_, a), (_, b)| {
|
||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||
})
|
||
.map(|(idx, _)| idx as i64)
|
||
.unwrap_or(0)
|
||
})
|
||
.collect();
|
||
Ok(indices)
|
||
}
|
||
_ => Err(anyhow::anyhow!(
|
||
"不支持的模型输出数据类型: {:?}",
|
||
raw_tensor.datum_type()
|
||
)),
|
||
}
|
||
}
|
||
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
|
||
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||
|
||
use crate::charset::CHARSET_BETA;
|
||
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
||
let mut res = String::new();
|
||
let mut prev_idx: i64 = -1;
|
||
|
||
for &idx in predicted_indices {
|
||
// 1. 跳过连续重复的索引
|
||
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
||
if idx != prev_idx && idx != 0 {
|
||
if let Ok(u_idx) = usize::try_from(idx) {
|
||
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
|
||
res.push_str(char_str);
|
||
}
|
||
}
|
||
}
|
||
prev_idx = idx;
|
||
}
|
||
println!("最终识别出的验证码是: {}", res);
|
||
res
|
||
}
|
||
}
|