use crate::image_io::png_rgba_white_preprocess; use crate::image_processor::{convert_to_grayscale, resize_image}; use crate::model_loader::{ModelLoader, ModelSession, ModelType}; use anyhow::Context; use image::DynamicImage; use tract_onnx::prelude::tract_ndarray::s; use tract_onnx::prelude::{ DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec, }; use crate::base::ModelArgs; // 颜色过滤的自定义范围:(低值RGB, 高值RGB) pub type ColorRange = ((u8, u8, u8), (u8, u8, u8)); // 字符集范围类型 #[derive(Debug, Clone)] pub enum CharsetRange { All, // 所有字符 Digit, // 数字 Letter, // 字母 Alphanumeric, // 字母数字 Single(String), // 单字符串 Multiple(Vec), // 多个字符串 Range(char, char), // 字符范围 Custom(Vec), // 自定义字符列表 } #[derive(Debug, Clone)] pub struct PredictArgs{ /// 是否修复PNG格式问题 pub png_fix: bool, /// 是否返回概率信息 pub probability: bool, /// 颜色过滤:保留的颜色列表 pub color_filter_colors: Option>, /// 颜色过滤:自定义RGB范围 pub color_filter_custom_ranges: Option>, /// 字符集范围 pub charset_range: Option, } impl Default for PredictArgs { fn default() -> Self { Self { png_fix: false, probability: false, color_filter_colors: None, color_filter_custom_ranges: None, charset_range: None, } } } impl PredictArgs { pub fn new() -> Self { Self::default() } // Builder 模式方法 pub fn png_fix(mut self, enabled: bool) -> Self { self.png_fix = enabled; self } pub fn probability(mut self, enabled: bool) -> Self { self.probability = enabled; self } pub fn color_filter_colors(mut self, colors: Vec) -> Self { self.color_filter_colors = Some(colors); self } pub fn color_filter_custom_ranges(mut self, ranges: Vec) -> Self { self.color_filter_custom_ranges = Some(ranges); self } pub fn charset_range(mut self, range: CharsetRange) -> Self { self.charset_range = Some(range); self } // 便捷构造方法 pub fn quick() -> Self { Self::default() } pub fn with_probability() -> Self { Self::default().probability(true) } pub fn with_png_fix() -> Self { Self::default().png_fix(true) } } pub struct Ocr { session: RunnableModel, Graph>>, charset: Vec, } impl ModelSession for Ocr { fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { let tensor = self.preprocess_image(image, png_fix)?; // // let result = self.session.run(tvec!(tensor.into()))?; // // 3. 解析结果 // // let output = result[0].to_array_view::()?; let output = self.inference(tensor)?; let output2 = self.process_text_output(&output)?; Ok(Self::ctc_decode_indices(&output2)) // Ok("ocr result".to_string()) } fn get_model_type(&self) -> ModelType { ModelType::Ocr } } impl Ocr { pub fn new(model_path: String, charset: Vec) -> Result { let session = ModelLoader::load_model(&model_path)?.session; Ok(Self { session, charset }) } /// 对应 Python 的 _preprocess_image /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result { // A. 修复 PNG 透明背景 (内部逻辑你之前已实现) let _ = if png_fix && img.color().has_alpha() { png_rgba_white_preprocess(img) } else { img.clone() }; let h = 64u32; let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32; let gray_img = convert_to_grayscale(img); let resized = resize_image(&gray_img, w, h); // resized.save("debug_preprocessed.png").unwrap(); // 1. 预处理:转灰度 -> Resize -> 归一化 // let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8(); // 使用 tract_ndarray 构造,避免版本冲突 let array = tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| { let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32; (pixel / 255.0 - 0.5) / 0.5 }); let tensor = Tensor::from(array); Ok(tensor) } /// 对应 Python 的 _inference fn inference(&self, tensor: Tensor) -> anyhow::Result { // tract 的 run 会返回一个 Vec,我们通常只需要第一个输出 // let result = self.session.run(tvec!(tensor.into()))?; let mut result = self .session .run(tvec!(tensor.into())) .context("执行模型推理失败")?; println!("模型输出原始数据: {:?}", result); Ok(result.remove(0).into_tensor()) } /// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列 fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result> { let shape = raw_tensor.shape(); println!("模型输出shape数据: {:?}", shape); let datum_type = raw_tensor.datum_type(); println!("模型输出datum_type数据: {:?}", datum_type); match raw_tensor.datum_type() { // 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax) DatumType::I64 => { let view = raw_tensor.to_array_view::()?; Ok(view.iter().cloned().collect()) } // 情况 2: sml2h3 原版模型,输出 F32 概率矩阵 DatumType::F32 => { let view = raw_tensor.to_array_view::()?; let (steps, classes, data_view) = match shape.len() { 3 => { if shape[1] == 1 { // 形状: [Steps, 1, Classes] -> 你的原有逻辑 (shape[0], shape[2], view.into_dyn()) } else if shape[0] == 1 { // 形状: [1, Steps, Classes] -> 另一种常见导出格式 (shape[1], shape[2], view.into_dyn()) } else { // 默认取第一个 batch: [Batch, Steps, Classes] // 使用 slice 对应 Python 的 output[0, :, :] let sliced = view.slice(s![0, .., ..]); (shape[1], shape[2], sliced.into_dyn()) } } 2 => { // 形状: [Steps, Classes] -> 已经剥离了 Batch 维度 (shape[0], shape[1], view.into_dyn()) } _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)), }; let array_2d = data_view.to_shape((steps, classes))?; // // 对每一行执行 Argmax (寻找概率最大的字符索引) let indices = array_2d .outer_iter() .map(|row| { row.iter() .enumerate() .max_by(|(_, a), (_, b)| { a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) }) .map(|(idx, _)| idx as i64) .unwrap_or(0) }) .collect(); Ok(indices) } _ => Err(anyhow::anyhow!( "不支持的模型输出数据类型: {:?}", raw_tensor.datum_type() )), } } fn ctc_decode_indices(predicted_indices: &[i64]) -> String { println!("indices模型输出原始数据: {:?}", predicted_indices); use crate::charset::CHARSET_BETA; // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) let mut res = String::new(); let mut prev_idx: i64 = -1; for &idx in predicted_indices { // 1. 跳过连续重复的索引 // 2. 跳过 blank 字符 (假设索引 0 是 blank) if idx != prev_idx && idx != 0 { if let Ok(u_idx) = usize::try_from(idx) { if let Some(&char_str) = CHARSET_BETA.get(u_idx) { res.push_str(char_str); } } } prev_idx = idx; } println!("最终识别出的验证码是: {}", res); res } }