refactor: 移除 OpenCV 依赖并实现纯 Rust 图像处理流水线
- 替换 opencv 为 image 库以简化交叉编译
- 修正 nms 逻辑中的 ArrayView 借用问题
- 增加 save_debug_image 方法用于可视化检测框
- 更新 Cargo.toml 依赖项
This commit is contained in:
87
src/lib.rs
87
src/lib.rs
@@ -10,84 +10,99 @@ mod utils;
|
||||
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
// 关键点:直接使用 tract 重导出的 ndarray
|
||||
use crate::charset::get_default_charset;
|
||||
use crate::det_model::Det;
|
||||
use crate::model_loader::ModelSession;
|
||||
use crate::ocr_model::Ocr;
|
||||
use crate::charset::get_default_charset;
|
||||
// Selects which model the builder should construct.
// NOTE(review): this span appears to interleave two versions of the same enum
// (`ModeType` with `Ocr`/`Det`/`CustomOcr` and `ModelSpec` with
// `OcrModel`/`DetModel`/`CustomOcrModel`) as rendered by a diff view —
// confirm which lines belong to the current revision.
pub enum ModeType {
|
||||
pub enum ModelSpec {
|
||||
/// Default OCR (uses the built-in path)
|
||||
Ocr {
|
||||
path: String,
|
||||
charset: Vec<String>,
|
||||
},
|
||||
Det {
|
||||
path: String,
|
||||
},
|
||||
OcrModel,
|
||||
DetModel,
|
||||
/// Custom OCR (path supplied by the user)
|
||||
CustomOcr {
|
||||
CustomOcrModel {
|
||||
path: String,
|
||||
charset: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
// Private associated constants holding the default model file paths.
impl ModelSpec {
|
||||
// The default paths are defined as internal associated constants.
|
||||
const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx";
|
||||
const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
|
||||
}
|
||||
/// Holds exactly one model value: either an [`Ocr`] or a [`Det`].
pub enum Runtime {
|
||||
Ocr(Ocr),
|
||||
Det(Det),
|
||||
}
|
||||
impl Runtime {
|
||||
// Single entry point for obtaining the description of the wrapped model.
// Returns whatever the wrapped value's own `desc()` returns.
|
||||
pub fn desc(&self) -> String {
|
||||
match self {
|
||||
Runtime::Ocr(s) => s.desc(), // delegates to the method on the Ocr struct
|
||||
Runtime::Det(s) => s.desc(), // delegates to the method on the Det struct
|
||||
}
|
||||
}
|
||||
}
|
||||
// Builder that records which model to construct in its `mode` field.
// NOTE(review): two `mode` fields are shown (`ModeType` and `ModelSpec`);
// this looks like the removed and added diff lines rendered together —
// confirm which one is current.
pub struct DdddOcrBuilder {
|
||||
mode: ModeType,
|
||||
mode: ModelSpec,
|
||||
}
|
||||
|
||||
// Builder methods: pick a mode with `new` / `det` / `custom_ocr`, then call
// `build` to construct the model and obtain a `DdddOcr`.
// NOTE(review): throughout this impl, statements using `ModeType` / `session`
// appear side by side with statements using `ModelSpec` / `runtime`; this
// looks like removed and added diff lines rendered together — confirm which
// set is current.
impl DdddOcrBuilder {
|
||||
// Creates a builder that starts in the default OCR mode.
pub fn new() -> Self {
|
||||
Self {
|
||||
mode: ModeType::Ocr {
|
||||
path: "models/common.onnx".to_string(),
|
||||
charset: get_default_charset(),
|
||||
},
|
||||
mode: ModelSpec::OcrModel,
|
||||
}
|
||||
}
|
||||
|
||||
/// Switches the builder to detection mode.
|
||||
pub fn det(mut self) -> Self {
|
||||
self.mode = ModeType::Det {
|
||||
path: "models/common_det.onnx".to_string(),
|
||||
};
|
||||
self.mode = ModelSpec::DetModel;
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects a custom OCR model given its path and charset.
|
||||
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
|
||||
// Overwrite the enum directly, replacing any previously selected Ocr or Det mode.
|
||||
self.mode = ModeType::CustomOcr { path, charset };
|
||||
self.mode = ModelSpec::CustomOcrModel { path, charset };
|
||||
self
|
||||
}
|
||||
|
||||
/// Core initialization logic: constructs the model selected by `mode`
/// and wraps it in a `DdddOcr`. Errors returned by `Ocr::new` / `Det::new`
/// are propagated to the caller via `?`.
|
||||
pub fn build(self) -> Result<DdddOcr> {
|
||||
let session: Box<dyn ModelSession> = match self.mode {
|
||||
ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?),
|
||||
ModeType::Det { path } => Box::new(Det::new(path)?),
|
||||
ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?),
|
||||
let runtime = match self.mode {
|
||||
ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?),
|
||||
ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
|
||||
ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
|
||||
};
|
||||
|
||||
Ok(DdddOcr { session })
|
||||
Ok(DdddOcr { runtime })
|
||||
}
|
||||
}
|
||||
|
||||
// Handle returned by `DdddOcrBuilder::build`; owns the constructed model.
// NOTE(review): both a `session` field and a `runtime` field are shown; this
// looks like the removed and added diff lines rendered together — confirm
// which one is current.
pub struct DdddOcr {
|
||||
session: Box<dyn ModelSession>,
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
// Formats the value as `DdddOcr(session: <desc>)`, where `<desc>` is the
// string returned by `self.runtime.desc()`.
impl Display for DdddOcr {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "DdddOcr(session: {})", self.runtime.desc())
|
||||
}
|
||||
}
|
||||
|
||||
// Inference entry points. Each one only works when `runtime` holds the
// matching model variant and returns an `anyhow` error otherwise.
impl DdddOcr {
|
||||
// Runs `predict(img, false)` on the wrapped OCR model and returns its
// `Result<String>`. If the runtime holds a detection model instead, an
// error is returned (message: "the current model is a detection model and
// cannot perform OCR").
// NOTE(review): the `self.session.predict(...)` line and the commented-out
// code below look like removed diff lines rendered next to the new
// `match` — confirm which body is current.
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||
self.session.predict(img, false)
|
||||
|
||||
// let tensor = self.preprocess_image(img, false)?;
|
||||
//
|
||||
// // let result = self.session.run(tvec!(tensor.into()))?;
|
||||
// // 3. Parse the result
|
||||
// // let output = result[0].to_array_view::<i64>()?;
|
||||
// let output = self.inference(tensor)?;
|
||||
// let output2 = self.process_text_output(&output)?;
|
||||
// Ok(Self::ctc_decode_indices(&output2))
|
||||
match &self.runtime {
|
||||
Runtime::Ocr(s) => s.predict(img, false),
|
||||
Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
|
||||
}
|
||||
}
|
||||
// Runs `predict(img)` on the wrapped detection model and returns its
// result. If the runtime holds an OCR model instead, an error is returned
// (message: "the current model is an OCR model and cannot perform
// detection").
// NOTE(review): presumably `img` is an encoded image buffer and each inner
// `Vec<i32>` is one bounding box — verify against `Det::predict`.
pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
|
||||
match &self.runtime {
|
||||
Runtime::Det(s) => s.predict(img),
|
||||
Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user