Files
ddddocr-rs/src/lib.rs
CNWei 1a329ca273 refactor: 优化Det算法
- 优化 demo_postprocess,nms算法
- 新增 Slide 滑块识别
- 更新 Cargo.toml 依赖项
2026-05-07 18:00:39 +08:00

125 lines
3.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

pub mod base;
mod charset;
mod det_model;
mod image_io;
mod image_processor;
mod model;
mod model_loader;
mod ocr_model;
mod utils;
pub mod slide_model;
use anyhow::Result;
use image::DynamicImage;
use std::fmt::{Display, Formatter};
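// Key point: use the ndarray re-exported by tract directly.
// NOTE(review): no ndarray import follows in this file — confirm this comment is still relevant.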
use crate::charset::get_default_charset;
use crate::det_model::Det;
use crate::model_loader::ModelSession;
use crate::ocr_model::Ocr;
/// Selects which model a `DdddOcrBuilder` will construct.
///
/// The common comparison/diagnostic traits are derived so a selection can be
/// logged, copied and compared by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSpec {
    /// Default OCR model (uses the built-in path).
    OcrModel,
    /// Default detection model (uses the built-in path).
    DetModel,
    /// Custom OCR model (path and charset are provided by the caller).
    CustomOcrModel {
        /// Location of the model file, handed to `Ocr::new` unchanged.
        path: String,
        /// Charset handed to `Ocr::new` together with `path`.
        charset: Vec<String>,
    },
}
impl ModelSpec {
// Default model paths, defined as private associated constants.
// `DdddOcrBuilder::build` passes them to `Ocr::new` / `Det::new`.
// NOTE(review): both paths are relative; presumably resolved against the
// process working directory — confirm in the model loader.
const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx";
const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
}
/// The model held by a `DdddOcr` instance: exactly one of the supported kinds.
pub enum Runtime {
/// An OCR model (`Ocr`).
Ocr(Ocr),
/// A detection model (`Det`).
Det(Det),
}
impl Runtime {
// 统一获取描述的方法
pub fn desc(&self) -> String {
match self {
Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法
Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法
}
}
}
/// Builder that records which model to use and then creates a `DdddOcr`.
pub struct DdddOcrBuilder {
// The currently selected model; `build` turns it into a `Runtime`.
mode: ModelSpec,
}
impl DdddOcrBuilder {
    /// Creates a builder with the default OCR model selected.
    #[must_use]
    pub fn new() -> Self {
        Self {
            mode: ModelSpec::OcrModel,
        }
    }

    /// Switches the builder to detection mode (default detection model).
    ///
    /// Replaces any previously selected model.
    #[must_use]
    pub fn det(mut self) -> Self {
        self.mode = ModelSpec::DetModel;
        self
    }

    /// Selects a custom OCR model located at `path`, decoded with `charset`.
    ///
    /// Overwrites the previous selection, whether it was OCR or detection.
    #[must_use]
    pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
        self.mode = ModelSpec::CustomOcrModel { path, charset };
        self
    }

    /// Core initialization: constructs the runtime for the selected model and
    /// wraps it in a `DdddOcr`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `Ocr::new` or `Det::new` when the
    /// selected model cannot be constructed.
    pub fn build(self) -> Result<DdddOcr> {
        let runtime = match self.mode {
            ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(
                ModelSpec::DEFAULT_OCR_PATH.into(),
                get_default_charset(),
            )?),
            ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
            ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
        };
        Ok(DdddOcr { runtime })
    }
}

/// `Default` mirrors [`DdddOcrBuilder::new`], so the builder can be used with
/// `..Default::default()` and `#[derive(Default)]` containers.
impl Default for DdddOcrBuilder {
    fn default() -> Self {
        Self::new()
    }
}
/// Entry point exposing `classification` and `detection` on top of a single
/// model runtime. Instances are created by `DdddOcrBuilder::build`.
pub struct DdddOcr {
// The one model this instance holds; it decides which operation can succeed.
runtime: Runtime,
}
impl Display for DdddOcr {
    /// Writes `DdddOcr(session: <description>)`, where the description comes
    /// from the wrapped runtime's `desc()`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let session_desc = self.runtime.desc();
        f.write_fmt(format_args!("DdddOcr(session: {})", session_desc))
    }
}
impl DdddOcr {
    /// Runs OCR on `img` and returns the text produced by the OCR model.
    ///
    /// # Errors
    ///
    /// Fails when this instance holds a detection model, and otherwise
    /// propagates any error from `Ocr::predict`.
    pub fn classification(&self, img: &DynamicImage) -> Result<String> {
        match &self.runtime {
            // Wrong model kind: report it instead of attempting a prediction.
            Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
            // NOTE(review): the meaning of the `false` flag is defined by
            // `Ocr::predict`, which is not visible from this file.
            Runtime::Ocr(ocr) => ocr.predict(img, false),
        }
    }

    /// Runs detection on the raw bytes in `img` and returns the detection
    /// model's output.
    ///
    /// NOTE(review): each inner `Vec<i32>` is presumably one detected box —
    /// confirm against `Det::predict`.
    ///
    /// # Errors
    ///
    /// Fails when this instance holds an OCR model, and otherwise propagates
    /// any error from `Det::predict`.
    pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
        match &self.runtime {
            // Wrong model kind: report it instead of attempting a prediction.
            Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
            Runtime::Det(det) => det.predict(img),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Placeholder for a CTC index-decoding test.
// NOTE(review): this test currently asserts nothing and `input` is unused;
// the real call and assertion are still commented out below.
#[test]
fn test_ctc_decode_indices() {
// Simulate a DdddOcr instance (if decoding does not depend on the session,
// it could be made an associated function instead).
// This assumes decode_ctc is public or otherwise accessible from here.
let input = vec![1, 1, 0, 1, 2, 2, 0, 2];
// Logic: [1, 1] -> 1, [0] -> skipped, [1] -> 1, [2, 2] -> 2, [0] -> skipped, [2] -> 2
// The expected result is the characters for indices [1, 1, 2, 2].
// The exact assertion depends on CHARSET_BETA.
// let result = dddd.ctc_decode_indices(&input);
// assert_eq!(result, "AABB");
}
}