pub mod base; mod charset; mod det_model; mod image_io; mod image_processor; mod model; mod model_loader; mod ocr_model; mod utils; pub mod slide_model; mod cv2; use anyhow::Result; use image::DynamicImage; use std::fmt::{Display, Formatter}; // 关键点:直接使用 tract 重导出的 ndarray use crate::charset::get_default_charset; use crate::det_model::Det; use crate::model_loader::ModelSession; use crate::ocr_model::Ocr; pub enum ModelSpec { /// 默认 OCR (使用内置路径) OcrModel, DetModel, /// 自定义 OCR (路径由用户提供) CustomOcrModel { path: String, charset: Vec, }, } impl ModelSpec { // 将默认路径定义为内部关联常量 const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx"; const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx"; } pub enum Runtime { Ocr(Ocr), Det(Det), } impl Runtime { // 统一获取描述的方法 pub fn desc(&self) -> String { match self { Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法 Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法 } } } pub struct DdddOcrBuilder { mode: ModelSpec, } impl DdddOcrBuilder { pub fn new() -> Self { Self { mode: ModelSpec::OcrModel, } } /// 切换为检测模式 pub fn det(mut self) -> Self { self.mode = ModelSpec::DetModel; self } /// 设置自定义 OCR 路径 pub fn custom_ocr(mut self, path: String, charset: Vec) -> Self { // 直接重写枚举,替换掉之前的 Ocr 或 Det self.mode = ModelSpec::CustomOcrModel { path, charset }; self } /// 核心初始化逻辑 pub fn build(self) -> Result { let runtime = match self.mode { ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?), ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?), ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?), }; Ok(DdddOcr { runtime }) } } pub struct DdddOcr { runtime: Runtime, } impl Display for DdddOcr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DdddOcr(session: {})", self.runtime.desc()) } } impl DdddOcr { pub fn classification(&self, img: &DynamicImage) -> Result { match &self.runtime { Runtime::Ocr(s) => s.predict(img, false), Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")), } } pub fn detection(&self, img: &[u8]) -> Result>> { match &self.runtime { Runtime::Det(s) => s.predict(img), Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_ctc_decode_indices() { // 模拟一个 DdddOcr 实例(如果 decode 不依赖 session,可以设为相关函数) // 这里假设你的 decode_ctc 是公开或内部可访问的 let input = vec![1, 1, 0, 1, 2, 2, 0, 2]; // 逻辑:[1, 1] -> 1, [0] -> 跳过, [1] -> 1, [2, 2] -> 2, [0] -> 跳过, [2] -> 2 // 预期结果索引应该是 [1, 1, 2, 2] 对应的字符 // 具体的断言取决于你的 CHARSET_BETA // let result = dddd.ctc_decode_indices(&input); // assert_eq!(result, "AABB"); } }