feat: 优化 项目目录结构
This commit is contained in:
40
src/models/loader.rs
Normal file
40
src/models/loader.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use anyhow::Context;
|
||||
use image::DynamicImage;
|
||||
use tract_onnx::onnx;
|
||||
use tract_onnx::prelude::*;
|
||||
// 关键点:直接使用 tract 重导出的 ndarray
|
||||
use crate::utils::image_io::png_rgba_white_preprocess;
|
||||
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
|
||||
use std::collections::HashMap;
|
||||
use tract_onnx::prelude::tract_ndarray::s;
|
||||
|
||||
/// OCR 模型:包含路径和字符集
|
||||
|
||||
pub enum ModelType {
|
||||
Ocr,
|
||||
Det,
|
||||
Custom,
|
||||
}
|
||||
// 定义统一的 trait
|
||||
pub trait ModelSession {
|
||||
fn get_model_type(&self) -> ModelType;
|
||||
fn desc(&self) -> String;
|
||||
}
|
||||
|
||||
pub struct ModelLoader {
|
||||
pub session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||
}
|
||||
|
||||
impl ModelLoader {
|
||||
pub fn load_model<P>(model_path: P) -> anyhow::Result<Self>
|
||||
where
|
||||
P: AsRef<std::path::Path>,
|
||||
{
|
||||
let session = onnx()
|
||||
.model_for_path(model_path)
|
||||
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||
.into_optimized()?
|
||||
.into_runnable()?;
|
||||
Ok(Self { session })
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user