126 lines
3.5 KiB
Rust
126 lines
3.5 KiB
Rust
pub mod base;
|
||
mod charset;
|
||
mod det_model;
|
||
mod image_io;
|
||
mod image_processor;
|
||
mod model;
|
||
mod model_loader;
|
||
mod ocr_model;
|
||
mod utils;
|
||
pub mod slide_model;
|
||
mod cv2;
|
||
|
||
use anyhow::Result;
|
||
use image::DynamicImage;
|
||
use std::fmt::{Display, Formatter};
|
||
|
||
// Key point: use the ndarray that tract re-exports directly, rather than a standalone ndarray import.
|
||
use crate::charset::get_default_charset;
|
||
use crate::det_model::Det;
|
||
use crate::model_loader::ModelSession;
|
||
use crate::ocr_model::Ocr;
|
||
/// Selects which model a [`DdddOcrBuilder`] will load when it is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSpec {
    /// Default OCR model, loaded from [`ModelSpec::DEFAULT_OCR_PATH`] with the default charset.
    OcrModel,
    /// Default detection model, loaded from [`ModelSpec::DEFAULT_DET_PATH`].
    DetModel,
    /// Custom OCR model whose location and charset are supplied by the user.
    CustomOcrModel {
        /// Path to the model file, passed to `Ocr::new` as-is.
        path: String,
        /// Character set passed to `Ocr::new` together with `path`.
        charset: Vec<String>,
    },
}

impl ModelSpec {
    /// Built-in location of the default OCR model.
    ///
    /// This is a relative path; presumably resolved against the process working
    /// directory by the loader — TODO confirm in `Ocr::new`.
    const DEFAULT_OCR_PATH: &str = "models/common_sml2h3_f32.onnx";

    /// Built-in location of the default detection model (relative path, see above).
    const DEFAULT_DET_PATH: &str = "models/common_det.onnx";
}
|
||
pub enum Runtime {
|
||
Ocr(Ocr),
|
||
Det(Det),
|
||
}
|
||
impl Runtime {
|
||
// 统一获取描述的方法
|
||
pub fn desc(&self) -> String {
|
||
match self {
|
||
Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法
|
||
Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法
|
||
}
|
||
}
|
||
}
|
||
pub struct DdddOcrBuilder {
|
||
mode: ModelSpec,
|
||
}
|
||
|
||
impl DdddOcrBuilder {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
mode: ModelSpec::OcrModel,
|
||
}
|
||
}
|
||
|
||
/// 切换为检测模式
|
||
pub fn det(mut self) -> Self {
|
||
self.mode = ModelSpec::DetModel;
|
||
self
|
||
}
|
||
|
||
/// 设置自定义 OCR 路径
|
||
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
|
||
// 直接重写枚举,替换掉之前的 Ocr 或 Det
|
||
self.mode = ModelSpec::CustomOcrModel { path, charset };
|
||
self
|
||
}
|
||
|
||
/// 核心初始化逻辑
|
||
pub fn build(self) -> Result<DdddOcr> {
|
||
let runtime = match self.mode {
|
||
ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?),
|
||
ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
|
||
ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
|
||
};
|
||
|
||
Ok(DdddOcr { runtime })
|
||
}
|
||
}
|
||
|
||
pub struct DdddOcr {
|
||
runtime: Runtime,
|
||
}
|
||
|
||
impl Display for DdddOcr {
|
||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||
write!(f, "DdddOcr(session: {})", self.runtime.desc())
|
||
}
|
||
}
|
||
|
||
impl DdddOcr {
|
||
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||
match &self.runtime {
|
||
Runtime::Ocr(s) => s.predict(img, false),
|
||
Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
|
||
}
|
||
}
|
||
pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
|
||
match &self.runtime {
|
||
Runtime::Det(s) => s.predict(img),
|
||
Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder test for CTC index decoding.
    ///
    /// Decoding collapses consecutive repeated indices and skips the blank index `0`:
    /// `[1, 1] -> 1`, `[0] -> skipped`, `[1] -> 1`, `[2, 2] -> 2`, `[0] -> skipped`,
    /// `[2] -> 2`. So `[1, 1, 0, 1, 2, 2, 0, 2]` should yield indices `[1, 1, 2, 2]`.
    ///
    /// The decoder is not called from here yet, so the test is ignored rather than
    /// reporting a pass without asserting anything.
    #[test]
    #[ignore = "CTC decoder is not wired into this test yet"]
    fn test_ctc_decode_indices() {
        let input = [1, 1, 0, 1, 2, 2, 0, 2];
        let expected_indices = [1, 1, 2, 2];

        // TODO: call the real decoder once it is reachable from here, e.g.
        //     let result = dddd.ctc_decode_indices(&input);
        // The expected string depends on the charset in use (CHARSET_BETA), e.g.
        //     assert_eq!(result, "AABB");
        let _ = (input, expected_indices);
    }
}
|