41 lines
1.1 KiB
Rust
41 lines
1.1 KiB
Rust
use anyhow::Context;
|
|
use image::DynamicImage;
|
|
use tract_onnx::onnx;
|
|
use tract_onnx::prelude::*;
|
|
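// Key point: use the ndarray re-exported by tract (`tract_onnx::prelude::tract_ndarray`) directly, rather than importing a standalone `ndarray` crate.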
|
|
use crate::utils::image_io::png_rgba_white_preprocess;
|
|
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
|
|
use std::collections::HashMap;
|
|
use tract_onnx::prelude::tract_ndarray::s;
|
|
|
|
/// The kind of model wrapped by a [`ModelSession`].
///
/// Returned by value from [`ModelSession::get_model_type`], so it is `Copy`
/// and comparable rather than having to be matched every time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// OCR (text recognition) model.
    Ocr,
    /// Presumably a detection model (inferred from the name) — confirm.
    Det,
    /// Presumably any other, user-supplied model (inferred from the name) — confirm.
    Custom,
}
|
|
// 定义统一的 trait
|
|
pub trait ModelSession {
|
|
fn get_model_type(&self) -> ModelType;
|
|
fn desc(&self) -> String;
|
|
}
|
|
|
|
pub struct ModelLoader {
|
|
pub session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
|
}
|
|
|
|
impl ModelLoader {
|
|
pub fn load_model<P>(model_path: P) -> anyhow::Result<Self>
|
|
where
|
|
P: AsRef<std::path::Path>,
|
|
{
|
|
let session = onnx()
|
|
.model_for_path(model_path)
|
|
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
|
.into_optimized()?
|
|
.into_runnable()?;
|
|
Ok(Self { session })
|
|
}
|
|
}
|