refactor: 移除 OpenCV 依赖并实现纯 Rust 图像处理流水线

- 替换 opencv 为 image 库以简化交叉编译
- 修正 nms 逻辑中的 ArrayView 借用问题
- 增加 save_debug_image 方法用于可视化检测框
- 更新 Cargo.toml 依赖项
This commit is contained in:
2026-05-06 17:37:38 +08:00
parent cfeb68ad04
commit 8fcfa2096e
8 changed files with 338 additions and 71 deletions

BIN
code3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

BIN
samples/det1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

BIN
samples/det2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

View File

@@ -1,24 +1,213 @@
use image::DynamicImage;
use crate::model_loader::{ModelLoader, ModelSession, ModelType}; use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use tract_onnx::prelude::{Graph, RunnableModel, TypedFact, TypedOp}; use anyhow::{Context, Result};
use crate::ocr_model::Ocr; use image::{DynamicImage, GenericImageView, imageops::FilterType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, prelude::*};
use tract_onnx::prelude::{Graph, RunnableModel, Tensor, TypedFact, TypedOp, tvec};
use image::{GenericImage,RgbImage,Rgb, Rgba};
use imageproc::drawing::draw_hollow_rect_mut;
use imageproc::rect::Rect;
use std::path::Path;
pub struct Det { pub struct Det {
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>, session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
} }
impl ModelSession for Det { impl ModelSession for Det {
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
// OCR 识别逻辑 + CTC 解码
Ok("ocr result".to_string())
}
fn get_model_type(&self) -> ModelType { fn get_model_type(&self) -> ModelType {
todo!() todo!()
} }
fn desc(&self) -> String {
"Detection Model 加载成功".to_string()
}
} }
impl Det { impl Det {
pub fn new(model_path: String) -> Result<Self, anyhow::Error> { pub fn new(model_path: String) -> Result<Self, anyhow::Error> {
let session = ModelLoader::load_model(&model_path)?.session; let session = ModelLoader::load_model(&model_path)?.session;
Ok(Self { session }) Ok(Self { session })
} }
/// Public detection entry point mirroring the Python API: decodes raw
/// `image_bytes` and returns one `[x1, y1, x2, y2]` box per detected object.
pub fn predict(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
// File/PIL-style conversion is handled at the call site in Rust; go straight to the core logic.
self.get_bbox(image_bytes)
}
/// Letterbox preprocessing in pure Rust (OpenCV replacement).
///
/// Scales `img` to fit inside `input_size` while preserving aspect ratio,
/// pads the remainder with gray (114, 114, 114), and packs the pixels into
/// a `(1, 3, H, W)` f32 tensor. Returns the tensor plus the scale ratio so
/// the caller can map detections back to original-image coordinates.
fn preproc(&self, img: &DynamicImage, input_size: (u32, u32)) -> Result<(Tensor, f32)> {
    let (target_h, target_w) = input_size;
    let (src_w, src_h) = img.dimensions();

    // Uniform scale factor so the whole image fits in the target canvas.
    let ratio = f32::min(
        target_h as f32 / src_h as f32,
        target_w as f32 / src_w as f32,
    );
    let scaled_h = (src_h as f32 * ratio) as u32;
    let scaled_w = (src_w as f32 * ratio) as u32;

    // Resize, then force RGB so every pixel is exactly three u8 channels.
    let scaled = img
        .resize_exact(scaled_w, scaled_h, FilterType::Triangle)
        .to_rgb8();

    // Gray canvas (value 114 matches the Python reference implementation),
    // resized image pasted into the top-left corner — the classic letterbox.
    let mut canvas =
        image::ImageBuffer::from_pixel(target_w, target_h, image::Rgb([114u8, 114, 114]));
    image::imageops::overlay(&mut canvas, &scaled, 0, 0);

    // Pack into NCHW. Values stay in 0..255 (no normalization here); channel
    // order is RGB — swap indices if the model expects BGR.
    let mut chw = Array4::<f32>::zeros((1, 3, target_h as usize, target_w as usize));
    for (x, y, px) in canvas.enumerate_pixels() {
        for c in 0..3 {
            chw[[0, c, y as usize, x as usize]] = px[c] as f32;
        }
    }
    Ok((chw.into(), ratio))
}
/// Decodes the raw YOLOX head output in place (same math as the Python
/// `demo_postprocess`): grid-cell offsets are added to the predicted
/// centers, and width/height channels are exponentiated, both scaled by
/// the stride of the feature map they came from.
fn demo_postprocess(&self, mut outputs: Array3<f32>, img_size: (i32, i32)) -> Array3<f32> {
    let (input_h, input_w) = img_size;
    let mut base = 0usize;
    for &stride in &[8i32, 16, 32] {
        let rows = input_h / stride;
        let cols = input_w / stride;
        let s = stride as f32;
        for gy in 0..rows {
            for gx in 0..cols {
                let i = base + (gy * cols + gx) as usize;
                // Restore absolute center coordinates from grid offsets.
                outputs[[0, i, 0]] = (outputs[[0, i, 0]] + gx as f32) * s;
                outputs[[0, i, 1]] = (outputs[[0, i, 1]] + gy as f32) * s;
                // Width / height are predicted in log space.
                outputs[[0, i, 2]] = outputs[[0, i, 2]].exp() * s;
                outputs[[0, i, 3]] = outputs[[0, i, 3]].exp() * s;
            }
        }
        base += (rows * cols) as usize;
    }
    outputs
}
/// Greedy non-maximum suppression.
///
/// `boxes` holds one `[x1, y1, x2, y2]` row per candidate and `scores` the
/// matching confidences. Returns indices of the surviving boxes, ordered
/// by descending score.
fn nms(&self, boxes: &Array2<f32>, scores: &Array1<f32>, nms_thr: f32) -> Vec<usize> {
    let x1 = boxes.column(0);
    let y1 = boxes.column(1);
    let x2 = boxes.column(2);
    let y2 = boxes.column(3);

    // Per-box areas. ndarray view arithmetic needs explicit `&` so the
    // views are borrowed rather than moved.
    let areas = (&x2 - &x1 + 1.0) * (&y2 - &y1 + 1.0);

    // Candidate indices, best score first.
    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by(|&a, &b| {
        scores[b]
            .partial_cmp(&scores[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut keep = Vec::new();
    while !order.is_empty() {
        let best = order[0];
        keep.push(best);
        // Drop the winner and every remaining box overlapping it too much.
        order = order
            .into_iter()
            .skip(1)
            .filter(|&j| {
                let ix1 = x1[best].max(x1[j]);
                let iy1 = y1[best].max(y1[j]);
                let ix2 = x2[best].min(x2[j]);
                let iy2 = y2[best].min(y2[j]);
                let iw = (ix2 - ix1 + 1.0).max(0.0);
                let ih = (iy2 - iy1 + 1.0).max(0.0);
                let inter = iw * ih;
                let iou = inter / (areas[best] + areas[j] - inter);
                iou <= nms_thr
            })
            .collect();
    }
    keep
}
/// Class-aware NMS: takes the best-scoring class per box, keeps boxes with
/// score above `score_thr`, then runs class-agnostic greedy NMS with
/// `nms_thr`.
///
/// Each returned detection is `[x1, y1, x2, y2, score, class_id]`.
fn multiclass_nms(
    &self,
    boxes: &Array2<f32>,
    scores: &Array2<f32>,
    nms_thr: f32,
    score_thr: f32,
) -> Vec<Vec<f32>> {
    let mut result = Vec::new();
    for i in 0..scores.nrows() {
        let row = scores.row(i);
        // NaN-safe argmax: the previous `partial_cmp(..).unwrap()` would
        // panic if the model ever emitted a NaN score.
        let best = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));
        // Guard against a zero-column score matrix instead of unwrap-panicking.
        let Some((cls_id, &score)) = best else {
            continue;
        };
        if score > score_thr {
            let mut det = boxes.row(i).to_vec();
            det.push(score);
            det.push(cls_id as f32);
            result.push(det);
        }
    }
    if result.is_empty() {
        return vec![];
    }
    // Re-pack the surviving boxes/scores for the class-agnostic NMS pass.
    let b_subset = Array2::from_shape_vec(
        (result.len(), 4),
        result.iter().flat_map(|r| r[0..4].to_vec()).collect(),
    )
    .expect("each result row carries exactly 4 coordinates");
    let s_subset = Array1::from_vec(result.iter().map(|r| r[4]).collect());
    let keep = self.nms(&b_subset, &s_subset, nms_thr);
    keep.into_iter().map(|idx| result[idx].clone()).collect()
}
/// End-to-end detection, fully decoupled from OpenCV: decode the image
/// bytes, letterbox to 416x416, run the ONNX session, decode the YOLOX
/// grid, apply multi-class NMS, and return `[x1, y1, x2, y2]` boxes
/// clamped to the original image bounds.
pub fn get_bbox(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
    const INPUT_SIDE: u32 = 416;

    // Decode with the pure-Rust `image` crate.
    let decoded = image::load_from_memory(image_bytes).context("Failed to decode image")?;
    let (orig_w, orig_h) = decoded.dimensions();

    let (input_tensor, ratio) = self.preproc(&decoded, (INPUT_SIDE, INPUT_SIDE))?;

    // tract inference; view the first output as a (1, N, 5 + classes) array.
    let outputs = self.session.run(tvec!(input_tensor.into()))?;
    let raw = outputs[0]
        .to_array_view::<f32>()?
        .to_owned()
        .into_dimensionality::<Ix3>()?;

    let predictions = self.demo_postprocess(raw, (INPUT_SIDE as i32, INPUT_SIDE as i32));
    let pred = predictions.slice(s![0, .., ..]);
    let boxes = pred.slice(s![.., 0..4]);
    // Final score = objectness * per-class probability (broadcast multiply).
    let scores = &pred.slice(s![.., 4..5]) * &pred.slice(s![.., 5..]);

    // Convert (cx, cy, w, h) to corner form and undo the letterbox scale.
    let mut boxes_xyxy = Array2::<f32>::zeros(boxes.raw_dim());
    for i in 0..boxes.nrows() {
        let (cx, cy) = (boxes[[i, 0]], boxes[[i, 1]]);
        let (half_w, half_h) = (boxes[[i, 2]] / 2.0, boxes[[i, 3]] / 2.0);
        boxes_xyxy[[i, 0]] = (cx - half_w) / ratio;
        boxes_xyxy[[i, 1]] = (cy - half_h) / ratio;
        boxes_xyxy[[i, 2]] = (cx + half_w) / ratio;
        boxes_xyxy[[i, 3]] = (cy + half_h) / ratio;
    }

    let dets = self.multiclass_nms(&boxes_xyxy, &scores, 0.45, 0.1);

    // Clamp every coordinate into the original image rectangle.
    Ok(dets
        .into_iter()
        .map(|d| {
            let clamp_w = |v: f32| (v as i32).max(0).min(orig_w as i32);
            let clamp_h = |v: f32| (v as i32).max(0).min(orig_h as i32);
            vec![clamp_w(d[0]), clamp_h(d[1]), clamp_w(d[2]), clamp_h(d[3])]
        })
        .collect())
}
} }

View File

@@ -10,84 +10,99 @@ mod utils;
use anyhow::Result; use anyhow::Result;
use image::DynamicImage; use image::DynamicImage;
use std::fmt::{Display, Formatter};
// 关键点:直接使用 tract 重导出的 ndarray // 关键点:直接使用 tract 重导出的 ndarray
use crate::charset::get_default_charset;
use crate::det_model::Det; use crate::det_model::Det;
use crate::model_loader::ModelSession; use crate::model_loader::ModelSession;
use crate::ocr_model::Ocr; use crate::ocr_model::Ocr;
use crate::charset::get_default_charset; pub enum ModelSpec {
pub enum ModeType {
/// 默认 OCR (使用内置路径) /// 默认 OCR (使用内置路径)
Ocr { OcrModel,
path: String, DetModel,
charset: Vec<String>,
},
Det {
path: String,
},
/// 自定义 OCR (路径由用户提供) /// 自定义 OCR (路径由用户提供)
CustomOcr { CustomOcrModel {
path: String, path: String,
charset: Vec<String>, charset: Vec<String>,
}, },
} }
impl ModelSpec {
// Built-in model locations, defined as associated constants so the
// defaults live right next to the enum they configure.
const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx";
const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
}
/// Concrete model runtime chosen at build time: either an OCR session or
/// a detection session.
pub enum Runtime {
    Ocr(Ocr),
    Det(Det),
}

impl Runtime {
    /// Human-readable description of the loaded model, delegated to the
    /// underlying session type.
    pub fn desc(&self) -> String {
        match self {
            Self::Ocr(ocr) => ocr.desc(),
            Self::Det(det) => det.desc(),
        }
    }
}
pub struct DdddOcrBuilder { pub struct DdddOcrBuilder {
mode: ModeType, mode: ModelSpec,
} }
impl DdddOcrBuilder { impl DdddOcrBuilder {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
mode: ModeType::Ocr { mode: ModelSpec::OcrModel,
path: "models/common.onnx".to_string(),
charset: get_default_charset(),
},
} }
} }
/// 切换为检测模式 /// 切换为检测模式
pub fn det(mut self) -> Self { pub fn det(mut self) -> Self {
self.mode = ModeType::Det { self.mode = ModelSpec::DetModel;
path: "models/common_det.onnx".to_string(),
};
self self
} }
/// 设置自定义 OCR 路径 /// 设置自定义 OCR 路径
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self { pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
// 直接重写枚举,替换掉之前的 Ocr 或 Det // 直接重写枚举,替换掉之前的 Ocr 或 Det
self.mode = ModeType::CustomOcr { path, charset }; self.mode = ModelSpec::CustomOcrModel { path, charset };
self self
} }
/// 核心初始化逻辑 /// 核心初始化逻辑
pub fn build(self) -> Result<DdddOcr> { pub fn build(self) -> Result<DdddOcr> {
let session: Box<dyn ModelSession> = match self.mode { let runtime = match self.mode {
ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?), ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?),
ModeType::Det { path } => Box::new(Det::new(path)?), ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?), ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
}; };
Ok(DdddOcr { session }) Ok(DdddOcr { runtime })
} }
} }
pub struct DdddOcr { pub struct DdddOcr {
session: Box<dyn ModelSession>, runtime: Runtime,
} }
impl Display for DdddOcr {
    /// Renders as `DdddOcr(session: <model description>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = self.runtime.desc();
        write!(f, "DdddOcr(session: {label})")
    }
}
impl DdddOcr { impl DdddOcr {
pub fn classification(&self, img: &DynamicImage) -> Result<String> { pub fn classification(&self, img: &DynamicImage) -> Result<String> {
self.session.predict(img, false) match &self.runtime {
Runtime::Ocr(s) => s.predict(img, false),
// let tensor = self.preprocess_image(img, false)?; Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
// }
// // let result = self.session.run(tvec!(tensor.into()))?; }
// // 3. 解析结果 pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
// // let output = result[0].to_array_view::<i64>()?; match &self.runtime {
// let output = self.inference(tensor)?; Runtime::Det(s) => s.predict(img),
// let output2 = self.process_text_output(&output)?; Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
// Ok(Self::ctc_decode_indices(&output2)) }
} }
} }

View File

@@ -17,8 +17,8 @@ pub enum ModelType {
} }
// 定义统一的 trait // 定义统一的 trait
pub trait ModelSession { pub trait ModelSession {
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error>;
fn get_model_type(&self) -> ModelType; fn get_model_type(&self) -> ModelType;
fn desc(&self) -> String;
} }
pub struct ModelLoader { pub struct ModelLoader {

View File

@@ -1,3 +1,4 @@
use crate::base::ModelArgs;
use crate::image_io::png_rgba_white_preprocess; use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image}; use crate::image_processor::{convert_to_grayscale, resize_image};
use crate::model_loader::{ModelLoader, ModelSession, ModelType}; use crate::model_loader::{ModelLoader, ModelSession, ModelType};
@@ -7,8 +8,6 @@ use tract_onnx::prelude::tract_ndarray::s;
use tract_onnx::prelude::{ use tract_onnx::prelude::{
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec, DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
}; };
use crate::base::ModelArgs;
// 颜色过滤的自定义范围:(低值RGB, 高值RGB) // 颜色过滤的自定义范围:(低值RGB, 高值RGB)
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8)); pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));
@@ -16,17 +15,17 @@ pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));
// 字符集范围类型 // 字符集范围类型
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum CharsetRange { pub enum CharsetRange {
All, // 所有字符 All, // 所有字符
Digit, // 数字 Digit, // 数字
Letter, // 字母 Letter, // 字母
Alphanumeric, // 字母数字 Alphanumeric, // 字母数字
Single(String), // 单字符串 Single(String), // 单字符串
Multiple(Vec<String>), // 多个字符串 Multiple(Vec<String>), // 多个字符串
Range(char, char), // 字符范围 Range(char, char), // 字符范围
Custom(Vec<char>), // 自定义字符列表 Custom(Vec<char>), // 自定义字符列表
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PredictArgs{ pub struct PredictArgs {
/// 是否修复PNG格式问题 /// 是否修复PNG格式问题
pub png_fix: bool, pub png_fix: bool,
/// 是否返回概率信息 /// 是否返回概率信息
@@ -100,7 +99,19 @@ pub struct Ocr {
charset: Vec<String>, charset: Vec<String>,
} }
impl ModelSession for Ocr { impl ModelSession for Ocr {
fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> { fn get_model_type(&self) -> ModelType {
todo!()
}
fn desc(&self) -> String {
"Ocr Model 加载成功".to_string()
}
}
impl Ocr {
pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
let session = ModelLoader::load_model(&model_path)?.session;
Ok(Self { session, charset })
}
pub fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
let tensor = self.preprocess_image(image, png_fix)?; let tensor = self.preprocess_image(image, png_fix)?;
// //
// let result = self.session.run(tvec!(tensor.into()))?; // let result = self.session.run(tvec!(tensor.into()))?;
@@ -108,19 +119,9 @@ impl ModelSession for Ocr {
// // let output = result[0].to_array_view::<i64>()?; // // let output = result[0].to_array_view::<i64>()?;
let output = self.inference(tensor)?; let output = self.inference(tensor)?;
let output2 = self.process_text_output(&output)?; let output2 = self.process_text_output(&output)?;
Ok(Self::ctc_decode_indices(&output2)) Ok(self.ctc_decode_indices(&output2))
// Ok("ocr result".to_string()) // Ok("ocr result".to_string())
} }
fn get_model_type(&self) -> ModelType {
ModelType::Ocr
}
}
impl Ocr {
pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
let session = ModelLoader::load_model(&model_path)?.session;
Ok(Self { session, charset })
}
/// 对应 Python 的 _preprocess_image /// 对应 Python 的 _preprocess_image
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> { fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
@@ -222,10 +223,9 @@ impl Ocr {
)), )),
} }
} }
fn ctc_decode_indices(predicted_indices: &[i64]) -> String { fn ctc_decode_indices(&self, predicted_indices: &[i64]) -> String {
println!("indices模型输出原始数据: {:?}", predicted_indices); println!("indices模型输出原始数据: {:?}", predicted_indices);
use crate::charset::CHARSET_BETA;
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
let mut res = String::new(); let mut res = String::new();
let mut prev_idx: i64 = -1; let mut prev_idx: i64 = -1;
@@ -235,8 +235,11 @@ impl Ocr {
// 2. 跳过 blank 字符 (假设索引 0 是 blank) // 2. 跳过 blank 字符 (假设索引 0 是 blank)
if idx != prev_idx && idx != 0 { if idx != prev_idx && idx != 0 {
if let Ok(u_idx) = usize::try_from(idx) { if let Ok(u_idx) = usize::try_from(idx) {
if let Some(&char_str) = CHARSET_BETA.get(u_idx) { if let Some(char_str) = self.charset.get(u_idx) {
res.push_str(char_str); res.push_str(char_str);
} else {
// 保护逻辑:如果模型预测的索引超出了字符集范围
eprintln!("警告: 预测索引 {} 超出字符集范围", u_idx);
} }
} }
} }

View File

@@ -1,5 +1,44 @@
use std::fs;
use image::Rgb;
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
/// Draws each detection box in red on the decoded image and writes the
/// result to `output_path` — used to eyeball detector output.
///
/// Boxes are `[x1, y1, x2, y2]`. Coordinates are clamped to the image so
/// `put_pixel` cannot panic, and malformed entries (fewer than 4 values)
/// are skipped instead of panicking on `bbox[3]`. The parameter is now a
/// slice (`&[Vec<i32>]`) rather than `&Vec<Vec<i32>>`; existing callers
/// passing `&vec` still compile via deref coercion.
fn save_debug_image(image_bytes: &[u8], bboxes: &[Vec<i32>], output_path: &str) -> anyhow::Result<()> {
    let mut img = image::load_from_memory(image_bytes)?.to_rgb8();
    let (width, height) = img.dimensions();
    // Guard the `width as i32 - 1` clamps below against a zero-sized image.
    if width == 0 || height == 0 {
        anyhow::bail!("decoded image has zero width or height");
    }
    let red = Rgb([255u8, 0, 0]);
    for bbox in bboxes {
        if bbox.len() < 4 {
            eprintln!("warning: skipping malformed bbox {:?}", bbox);
            continue;
        }
        // Clamp into the valid pixel range.
        let x1 = bbox[0].max(0).min(width as i32 - 1) as u32;
        let y1 = bbox[1].max(0).min(height as i32 - 1) as u32;
        let x2 = bbox[2].max(0).min(width as i32 - 1) as u32;
        let y2 = bbox[3].max(0).min(height as i32 - 1) as u32;
        // Top and bottom edges (2 px thick where space allows).
        for x in x1..=x2 {
            img.put_pixel(x, y1, red);
            img.put_pixel(x, y2, red);
            if y1 + 1 < height { img.put_pixel(x, y1 + 1, red); }
            if y2.saturating_sub(1) > 0 { img.put_pixel(x, y2 - 1, red); }
        }
        // Left and right edges (2 px thick where space allows).
        for y in y1..=y2 {
            img.put_pixel(x1, y, red);
            img.put_pixel(x2, y, red);
            if x1 + 1 < width { img.put_pixel(x1 + 1, y, red); }
            if x2.saturating_sub(1) > 0 { img.put_pixel(x2 - 1, y, red); }
        }
    }
    img.save(output_path)?;
    Ok(())
}
#[test] #[test]
fn test_full_classification() { fn test_full_classification() {
// 1. 初始化模型 // 1. 初始化模型
@@ -13,4 +52,25 @@ fn test_full_classification() {
println!("识别结果: {}", result); println!("识别结果: {}", result);
assert!(!result.is_empty()); assert!(!result.is_empty());
}
/// Integration check: builds the detection model, runs it on a sample
/// image, prints the detections, and writes a debug visualization when
/// any boxes are found.
#[test]
fn test_det_load() -> anyhow::Result<()> {
    let det = DdddOcrBuilder::new().det().build()?;

    let image_path = "samples/det1.png";
    let image_bytes = fs::read(image_path)
        .map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?;
    println!("图片读取成功,字节大小: {}", image_bytes.len());

    let bboxes = det.detection(&image_bytes)?;
    // Fixed format string: the original `println!(":?{}", det)` printed a
    // literal ":?" prefix; `{}` uses the Display impl on DdddOcr.
    println!("{}", det);
    println!("检测到的目标数量: {}", bboxes.len());

    if bboxes.is_empty() {
        println!("未检测到任何目标。");
    } else {
        save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?;
        for (i, bbox) in bboxes.iter().enumerate() {
            println!(
                "目标 [{}]: x1={}, y1={}, x2={}, y2={}",
                i, bbox[0], bbox[1], bbox[2], bbox[3]
            );
        }
    }
    Ok(())
}