feat: 重构模型初始化逻辑
- 重构 DdddOcr。 - 新增 DdddOcrBuilder。 - 其他优化
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
fn main() {
|
||||
let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap();
|
||||
let ocr = ddddocr_rs::DdddOcrBuilder::new().build().unwrap();
|
||||
let img = image::open("samples/code3.png").unwrap();
|
||||
println!("Result: {}", ocr.classification(&img).unwrap());
|
||||
}
|
||||
40
src/base.rs
Normal file
40
src/base.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
/// Read-only access to the arguments needed to load a model.
pub trait ModelArgs {
    /// Returns the path of the model file.
    fn model_path(&self) -> &str;

    /// Returns the charset, or `None` for models that carry none
    /// (detection models have no charset, hence the `Option`).
    fn charset(&self) -> Option<&str>;
}

/// Metadata for models that carry a charset (used by OCR and custom models).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HasCharset {
    pub charset: String,
}

/// Metadata marker for models without a charset (used by detection models).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoCharset;

/// A model description: the file path plus metadata of type `T`.
///
/// `T` decides at compile time whether a charset is available, see
/// [`HasCharset`] and [`NoCharset`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Model<T> {
    pub path: String,
    pub metadata: T,
}

// Models with a charset (OCR / custom).
impl ModelArgs for Model<HasCharset> {
    fn model_path(&self) -> &str {
        &self.path
    }

    fn charset(&self) -> Option<&str> {
        Some(&self.metadata.charset)
    }
}

// Models without a charset (detection).
impl ModelArgs for Model<NoCharset> {
    fn model_path(&self) -> &str {
        &self.path
    }

    fn charset(&self) -> Option<&str> {
        None
    }
}

/// OCR model description.
pub type OcrModel = Model<HasCharset>;
/// Detection model description.
pub type DetModel = Model<NoCharset>;
/// Custom model description; same shape as [`OcrModel`], so the type is reused.
pub type CustomModel = Model<HasCharset>;
|
||||
@@ -514,3 +514,6 @@ pub const CHARSET_BETA: &[&str] = &[
|
||||
"谬", "溝", "言", "哽", "婿", "猿", "跗", "獴", "俜", "呙", "弗", "凿", "窭", "铌", "友", "唉",
|
||||
"怫", "荘",
|
||||
];
|
||||
pub fn get_default_charset() -> Vec<String> {
|
||||
CHARSET_BETA.iter().map(|&s| s.to_string()).collect()
|
||||
}
|
||||
24
src/det_model.rs
Normal file
24
src/det_model.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use image::DynamicImage;
|
||||
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
|
||||
use tract_onnx::prelude::{Graph, RunnableModel, TypedFact, TypedOp};
|
||||
use crate::ocr_model::Ocr;
|
||||
|
||||
pub struct Det {
|
||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
}
|
||||
impl ModelSession for Det {
|
||||
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
|
||||
// OCR 识别逻辑 + CTC 解码
|
||||
Ok("ocr result".to_string())
|
||||
}
|
||||
|
||||
fn get_model_type(&self) -> ModelType {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
impl Det {
|
||||
pub fn new(model_path: String) -> Result<Self, anyhow::Error> {
|
||||
let session = ModelLoader::load_model(&model_path)?.session;
|
||||
Ok(Self { session })
|
||||
}
|
||||
}
|
||||
224
src/lib.rs
224
src/lib.rs
@@ -1,169 +1,95 @@
|
||||
pub mod base;
|
||||
mod charset;
|
||||
mod det_model;
|
||||
mod image_io;
|
||||
mod image_processor;
|
||||
mod model;
|
||||
mod model_loader;
|
||||
mod ocr_model;
|
||||
mod utils;
|
||||
|
||||
use crate::image_io::png_rgba_white_preprocess;
|
||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||
use anyhow::{Context, Result};
|
||||
use image::{DynamicImage, imageops::FilterType};
|
||||
use tract_onnx::prelude::*;
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
|
||||
// 关键点:直接使用 tract 重导出的 ndarray
|
||||
use tract_onnx::prelude::tract_ndarray::s;
|
||||
pub struct DdddOcr {
|
||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
use crate::det_model::Det;
|
||||
use crate::model_loader::ModelSession;
|
||||
use crate::ocr_model::Ocr;
|
||||
use crate::charset::get_default_charset;
|
||||
/// Selects which model the builder loads and the data it needs.
pub enum ModeType {
    /// Default OCR (uses the built-in model path).
    Ocr { path: String, charset: Vec<String> },
    /// Detection model; carries only a model path.
    Det { path: String },
    /// Custom OCR (the model path is supplied by the user).
    CustomOcr { path: String, charset: Vec<String> },
}
|
||||
|
||||
impl DdddOcr {
|
||||
pub fn new<P>(model_path: P) -> Result<Self>
|
||||
where
|
||||
P: AsRef<std::path::Path>,
|
||||
{
|
||||
let session = onnx()
|
||||
.model_for_path(model_path)
|
||||
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||
.into_optimized()?
|
||||
.into_runnable()?;
|
||||
Ok(Self { session })
|
||||
pub struct DdddOcrBuilder {
|
||||
mode: ModeType,
|
||||
}
|
||||
|
||||
impl DdddOcrBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
mode: ModeType::Ocr {
|
||||
path: "models/common.onnx".to_string(),
|
||||
charset: get_default_charset(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||
let tensor = self.preprocess_image(img, false)?;
|
||||
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
// 3. 解析结果
|
||||
// let output = result[0].to_array_view::<i64>()?;
|
||||
let output = self.inference(tensor)?;
|
||||
let output2 = self.process_text_output(&output)?;
|
||||
Ok(Self::ctc_decode_indices(&output2))
|
||||
/// 切换为检测模式
|
||||
pub fn det(mut self) -> Self {
|
||||
self.mode = ModeType::Det {
|
||||
path: "models/common_det.onnx".to_string(),
|
||||
};
|
||||
self
|
||||
}
|
||||
/// 对应 Python 的 _preprocess_image
|
||||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
||||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||
let _ = if png_fix && img.color().has_alpha() {
|
||||
png_rgba_white_preprocess(img)
|
||||
} else {
|
||||
img.clone()
|
||||
|
||||
/// 设置自定义 OCR 路径
|
||||
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
|
||||
// 直接重写枚举,替换掉之前的 Ocr 或 Det
|
||||
self.mode = ModeType::CustomOcr { path, charset };
|
||||
self
|
||||
}
|
||||
|
||||
/// 核心初始化逻辑
|
||||
pub fn build(self) -> Result<DdddOcr> {
|
||||
let session: Box<dyn ModelSession> = match self.mode {
|
||||
ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?),
|
||||
ModeType::Det { path } => Box::new(Det::new(path)?),
|
||||
ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?),
|
||||
};
|
||||
|
||||
let h = 64u32;
|
||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||
let gray_img = convert_to_grayscale(img);
|
||||
let resized = resize_image(&gray_img, w, h);
|
||||
// resized.save("debug_preprocessed.png").unwrap();
|
||||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||
|
||||
// 使用 tract_ndarray 构造,避免版本冲突
|
||||
let array =
|
||||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||||
(pixel / 255.0 - 0.5) / 0.5
|
||||
});
|
||||
|
||||
let tensor = Tensor::from(array);
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
/// 对应 Python 的 _inference
|
||||
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
|
||||
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
let mut result = self
|
||||
.session
|
||||
.run(tvec!(tensor.into()))
|
||||
.context("执行模型推理失败")?;
|
||||
println!("模型输出原始数据: {:?}", result);
|
||||
Ok(result.remove(0).into_tensor())
|
||||
}
|
||||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
||||
let shape = raw_tensor.shape();
|
||||
println!("模型输出shape数据: {:?}", shape);
|
||||
let datum_type = raw_tensor.datum_type();
|
||||
println!("模型输出datum_type数据: {:?}", datum_type);
|
||||
|
||||
match raw_tensor.datum_type() {
|
||||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||
DatumType::I64 => {
|
||||
let view = raw_tensor.to_array_view::<i64>()?;
|
||||
Ok(view.iter().cloned().collect())
|
||||
}
|
||||
|
||||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||
DatumType::F32 => {
|
||||
let view = raw_tensor.to_array_view::<f32>()?;
|
||||
let (steps, classes, data_view) = match shape.len() {
|
||||
3 => {
|
||||
if shape[1] == 1 {
|
||||
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
||||
(shape[0], shape[2], view.into_dyn())
|
||||
} else if shape[0] == 1 {
|
||||
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
||||
(shape[1], shape[2], view.into_dyn())
|
||||
} else {
|
||||
// 默认取第一个 batch: [Batch, Steps, Classes]
|
||||
// 使用 slice 对应 Python 的 output[0, :, :]
|
||||
let sliced = view.slice(s![0, .., ..]);
|
||||
(shape[1], shape[2], sliced.into_dyn())
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
||||
(shape[0], shape[1], view.into_dyn())
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
||||
};
|
||||
let array_2d = data_view.to_shape((steps, classes))?;
|
||||
//
|
||||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||
let indices = array_2d
|
||||
.outer_iter()
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(idx, _)| idx as i64)
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.collect();
|
||||
Ok(indices)
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(
|
||||
"不支持的模型输出数据类型: {:?}",
|
||||
raw_tensor.datum_type()
|
||||
)),
|
||||
}
|
||||
}
|
||||
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
|
||||
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||||
|
||||
use crate::charset::CHARSET_BETA;
|
||||
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
||||
let mut res = String::new();
|
||||
let mut prev_idx: i64 = -1;
|
||||
|
||||
for &idx in predicted_indices {
|
||||
// 1. 跳过连续重复的索引
|
||||
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
||||
if idx != prev_idx && idx != 0 {
|
||||
if let Ok(u_idx) = usize::try_from(idx) {
|
||||
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
|
||||
res.push_str(char_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
prev_idx = idx;
|
||||
}
|
||||
println!("最终识别出的验证码是: {}", res);
|
||||
res
|
||||
Ok(DdddOcr { session })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DdddOcr {
|
||||
session: Box<dyn ModelSession>,
|
||||
}
|
||||
impl DdddOcr {
|
||||
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||
self.session.predict(img, false)
|
||||
|
||||
// let tensor = self.preprocess_image(img, false)?;
|
||||
//
|
||||
// // let result = self.session.run(tvec!(tensor.into()))?;
|
||||
// // 3. 解析结果
|
||||
// // let output = result[0].to_array_view::<i64>()?;
|
||||
// let output = self.inference(tensor)?;
|
||||
// let output2 = self.process_text_output(&output)?;
|
||||
// Ok(Self::ctc_decode_indices(&output2))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
40
src/model_loader.rs
Normal file
40
src/model_loader.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use anyhow::Context;
|
||||
use image::DynamicImage;
|
||||
use tract_onnx::onnx;
|
||||
use tract_onnx::prelude::*;
|
||||
// 关键点:直接使用 tract 重导出的 ndarray
|
||||
use crate::image_io::png_rgba_white_preprocess;
|
||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||
use std::collections::HashMap;
|
||||
use tract_onnx::prelude::tract_ndarray::s;
|
||||
|
||||
/// Identifies which kind of model a session wraps.
///
/// Derives the common value traits so the result of `get_model_type` can be
/// compared, copied and printed by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Text recognition model.
    Ocr,
    /// Detection model.
    Det,
    /// User-supplied model.
    Custom,
}
|
||||
// 定义统一的 trait
|
||||
pub trait ModelSession {
|
||||
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error>;
|
||||
fn get_model_type(&self) -> ModelType;
|
||||
}
|
||||
|
||||
pub struct ModelLoader {
|
||||
pub session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
}
|
||||
|
||||
impl ModelLoader {
|
||||
pub fn load_model<P>(model_path: P) -> anyhow::Result<Self>
|
||||
where
|
||||
P: AsRef<std::path::Path>,
|
||||
{
|
||||
let session = onnx()
|
||||
.model_for_path(model_path)
|
||||
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||
.into_optimized()?
|
||||
.into_runnable()?;
|
||||
Ok(Self { session })
|
||||
}
|
||||
}
|
||||
248
src/ocr_model.rs
Normal file
248
src/ocr_model.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
use crate::image_io::png_rgba_white_preprocess;
|
||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
||||
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
|
||||
use anyhow::Context;
|
||||
use image::DynamicImage;
|
||||
use tract_onnx::prelude::tract_ndarray::s;
|
||||
use tract_onnx::prelude::{
|
||||
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
|
||||
};
|
||||
use crate::base::ModelArgs;
|
||||
|
||||
|
||||
/// Custom colour-filter range: `(low RGB, high RGB)`.
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));

/// Kinds of charset restriction that can be requested for a prediction.
#[derive(Debug, Clone)]
pub enum CharsetRange {
    /// All characters.
    All,
    /// Digits.
    Digit,
    /// Letters.
    Letter,
    /// Letters and digits.
    Alphanumeric,
    /// A single string.
    Single(String),
    /// Several strings.
    Multiple(Vec<String>),
    /// A character range.
    Range(char, char),
    /// A custom list of characters.
    Custom(Vec<char>),
}

/// Options for a single prediction call.
///
/// The default value has both flags off and every optional setting unset.
#[derive(Debug, Clone, Default)]
pub struct PredictArgs {
    /// Whether to apply the PNG fix.
    pub png_fix: bool,
    /// Whether to return probability information.
    pub probability: bool,
    /// Colour filter: list of colours to keep.
    pub color_filter_colors: Option<Vec<String>>,
    /// Colour filter: custom RGB ranges.
    pub color_filter_custom_ranges: Option<Vec<ColorRange>>,
    /// Charset range.
    pub charset_range: Option<CharsetRange>,
}

impl PredictArgs {
    /// Creates the default set of arguments.
    pub fn new() -> Self {
        Self::default()
    }

    // Builder-style setters: each consumes `self` and returns the updated value.

    /// Sets the `png_fix` flag.
    pub fn png_fix(self, enabled: bool) -> Self {
        Self {
            png_fix: enabled,
            ..self
        }
    }

    /// Sets the `probability` flag.
    pub fn probability(self, enabled: bool) -> Self {
        Self {
            probability: enabled,
            ..self
        }
    }

    /// Sets the list of colours to keep.
    pub fn color_filter_colors(self, colors: Vec<String>) -> Self {
        Self {
            color_filter_colors: Some(colors),
            ..self
        }
    }

    /// Sets the custom RGB ranges for the colour filter.
    pub fn color_filter_custom_ranges(self, ranges: Vec<ColorRange>) -> Self {
        Self {
            color_filter_custom_ranges: Some(ranges),
            ..self
        }
    }

    /// Sets the charset range.
    pub fn charset_range(self, range: CharsetRange) -> Self {
        Self {
            charset_range: Some(range),
            ..self
        }
    }

    // Convenience constructors.

    /// Same as [`PredictArgs::new`].
    pub fn quick() -> Self {
        Self::new()
    }

    /// Default arguments with `probability` enabled.
    pub fn with_probability() -> Self {
        Self::new().probability(true)
    }

    /// Default arguments with `png_fix` enabled.
    pub fn with_png_fix() -> Self {
        Self::new().png_fix(true)
    }
}
|
||||
pub struct Ocr {
|
||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
charset: Vec<String>,
|
||||
}
|
||||
impl ModelSession for Ocr {
|
||||
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
|
||||
let tensor = self.preprocess_image(image, png_fix)?;
|
||||
//
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
// // 3. 解析结果
|
||||
// // let output = result[0].to_array_view::<i64>()?;
|
||||
let output = self.inference(tensor)?;
|
||||
let output2 = self.process_text_output(&output)?;
|
||||
Ok(Self::ctc_decode_indices(&output2))
|
||||
// Ok("ocr result".to_string())
|
||||
}
|
||||
|
||||
fn get_model_type(&self) -> ModelType {
|
||||
ModelType::Ocr
|
||||
}
|
||||
}
|
||||
impl Ocr {
|
||||
pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
|
||||
let session = ModelLoader::load_model(&model_path)?.session;
|
||||
Ok(Self { session, charset })
|
||||
}
|
||||
/// 对应 Python 的 _preprocess_image
|
||||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
|
||||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||
let _ = if png_fix && img.color().has_alpha() {
|
||||
png_rgba_white_preprocess(img)
|
||||
} else {
|
||||
img.clone()
|
||||
};
|
||||
|
||||
let h = 64u32;
|
||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||
let gray_img = convert_to_grayscale(img);
|
||||
let resized = resize_image(&gray_img, w, h);
|
||||
// resized.save("debug_preprocessed.png").unwrap();
|
||||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||
|
||||
// 使用 tract_ndarray 构造,避免版本冲突
|
||||
let array =
|
||||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||||
(pixel / 255.0 - 0.5) / 0.5
|
||||
});
|
||||
|
||||
let tensor = Tensor::from(array);
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
/// 对应 Python 的 _inference
|
||||
fn inference(&self, tensor: Tensor) -> anyhow::Result<Tensor> {
|
||||
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||
let mut result = self
|
||||
.session
|
||||
.run(tvec!(tensor.into()))
|
||||
.context("执行模型推理失败")?;
|
||||
println!("模型输出原始数据: {:?}", result);
|
||||
Ok(result.remove(0).into_tensor())
|
||||
}
|
||||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||
fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result<Vec<i64>> {
|
||||
let shape = raw_tensor.shape();
|
||||
println!("模型输出shape数据: {:?}", shape);
|
||||
let datum_type = raw_tensor.datum_type();
|
||||
println!("模型输出datum_type数据: {:?}", datum_type);
|
||||
|
||||
match raw_tensor.datum_type() {
|
||||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||
DatumType::I64 => {
|
||||
let view = raw_tensor.to_array_view::<i64>()?;
|
||||
Ok(view.iter().cloned().collect())
|
||||
}
|
||||
|
||||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||
DatumType::F32 => {
|
||||
let view = raw_tensor.to_array_view::<f32>()?;
|
||||
let (steps, classes, data_view) = match shape.len() {
|
||||
3 => {
|
||||
if shape[1] == 1 {
|
||||
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
||||
(shape[0], shape[2], view.into_dyn())
|
||||
} else if shape[0] == 1 {
|
||||
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
||||
(shape[1], shape[2], view.into_dyn())
|
||||
} else {
|
||||
// 默认取第一个 batch: [Batch, Steps, Classes]
|
||||
// 使用 slice 对应 Python 的 output[0, :, :]
|
||||
let sliced = view.slice(s![0, .., ..]);
|
||||
(shape[1], shape[2], sliced.into_dyn())
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
||||
(shape[0], shape[1], view.into_dyn())
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
||||
};
|
||||
let array_2d = data_view.to_shape((steps, classes))?;
|
||||
//
|
||||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||
let indices = array_2d
|
||||
.outer_iter()
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(idx, _)| idx as i64)
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.collect();
|
||||
Ok(indices)
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(
|
||||
"不支持的模型输出数据类型: {:?}",
|
||||
raw_tensor.datum_type()
|
||||
)),
|
||||
}
|
||||
}
|
||||
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
|
||||
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||||
|
||||
use crate::charset::CHARSET_BETA;
|
||||
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
||||
let mut res = String::new();
|
||||
let mut prev_idx: i64 = -1;
|
||||
|
||||
for &idx in predicted_indices {
|
||||
// 1. 跳过连续重复的索引
|
||||
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
||||
if idx != prev_idx && idx != 0 {
|
||||
if let Ok(u_idx) = usize::try_from(idx) {
|
||||
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
|
||||
res.push_str(char_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
prev_idx = idx;
|
||||
}
|
||||
println!("最终识别出的验证码是: {}", res);
|
||||
res
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
use ddddocr_rs::DdddOcr; // 假设你的包名是这个
|
||||
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
|
||||
|
||||
#[test]
|
||||
fn test_full_classification() {
|
||||
// 1. 初始化模型
|
||||
let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败");
|
||||
let ocr = DdddOcrBuilder::new().build().expect("模型加载失败");
|
||||
|
||||
// 2. 加载测试图片
|
||||
let img = image::open("samples/code3.png").expect("测试图片不存在");
|
||||
|
||||
Reference in New Issue
Block a user