diff --git a/code3.png b/code3.png new file mode 100644 index 0000000..ac30b0b Binary files /dev/null and b/code3.png differ diff --git a/samples/det1.png b/samples/det1.png new file mode 100644 index 0000000..dd4cb46 Binary files /dev/null and b/samples/det1.png differ diff --git a/samples/det2.png b/samples/det2.png new file mode 100644 index 0000000..affd615 Binary files /dev/null and b/samples/det2.png differ diff --git a/src/det_model.rs b/src/det_model.rs index 3dd7dbd..79e81f4 100644 --- a/src/det_model.rs +++ b/src/det_model.rs @@ -1,24 +1,213 @@ -use image::DynamicImage; use crate::model_loader::{ModelLoader, ModelSession, ModelType}; -use tract_onnx::prelude::{Graph, RunnableModel, TypedFact, TypedOp}; -use crate::ocr_model::Ocr; +use anyhow::{Context, Result}; +use image::{DynamicImage, GenericImageView, imageops::FilterType}; +use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, prelude::*}; +use tract_onnx::prelude::{Graph, RunnableModel, Tensor, TypedFact, TypedOp, tvec}; + +use image::{GenericImage,RgbImage,Rgb, Rgba}; +use imageproc::drawing::draw_hollow_rect_mut; +use imageproc::rect::Rect; +use std::path::Path; pub struct Det { session: RunnableModel, Graph>>, } impl ModelSession for Det { - fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { - // OCR 识别逻辑 + CTC 解码 - Ok("ocr result".to_string()) - } - fn get_model_type(&self) -> ModelType { todo!() } + fn desc(&self) -> String { + "Detection Model 加载成功".to_string() + } } impl Det { pub fn new(model_path: String) -> Result { let session = ModelLoader::load_model(&model_path)?.session; Ok(Self { session }) } + pub fn predict(&self, image_bytes: &[u8]) -> Result>> { + // Rust 中通常在调用层处理文件/PIL转换,这里直接进入核心逻辑 + self.get_bbox(image_bytes) + } + /// 2. preproc: 纯 Rust 实现 (替代 OpenCV) + fn preproc(&self, img: &DynamicImage, input_size: (u32, u32)) -> Result<(Tensor, f32)> { + let (target_h, target_w) = input_size; + let (img_w, img_h) = img.dimensions(); + + // 计算缩放比例 (Letterbox) + let r = (target_h as f32 / img_h as f32).min(target_w as f32 / img_w as f32); + let new_h = (img_h as f32 * r) as u32; + let new_w = (img_w as f32 * r) as u32; + + // Resize 图像 + let resized = img.resize_exact(new_w, new_h, FilterType::Triangle); + // 2. 关键:将 DynamicImage 显式转换为 RgbImage (Rgb) + let resized_rgb = resized.to_rgb8(); + // 创建 114 灰度填充的背景 + let mut base_img = + image::ImageBuffer::from_pixel(target_w, target_h, image::Rgb([114u8, 114, 114])); + + // 将 resize 后的图像覆盖到左上角 (类似于原始代码中的 padded_img[:h, :w]) + image::imageops::overlay(&mut base_img, &resized_rgb, 0, 0); + + // 构造 NCHW Tensor + let mut array = Array4::::zeros((1, 3, target_h as usize, target_w as usize)); + for (x, y, pixel) in base_img.enumerate_pixels() { + // RGB 顺序归一化 (根据模型需求,若需 BGR 则调换索引) + array[[0, 0, y as usize, x as usize]] = pixel[0] as f32; + array[[0, 1, y as usize, x as usize]] = pixel[1] as f32; + array[[0, 2, y as usize, x as usize]] = pixel[2] as f32; + } + + Ok((array.into(), r)) + } + + /// 3. demo_postprocess (逻辑与 Python 一致) + fn demo_postprocess(&self, mut outputs: Array3, img_size: (i32, i32)) -> Array3 { + let strides = [8, 16, 32]; + let mut offset = 0; + + for stride in strides { + let h = img_size.0 / stride; + let w = img_size.1 / stride; + for y in 0..h { + for x in 0..w { + let idx = offset + (y * w + x) as usize; + // cx, cy 还原 + outputs[[0, idx, 0]] = (outputs[[0, idx, 0]] + x as f32) * stride as f32; + outputs[[0, idx, 1]] = (outputs[[0, idx, 1]] + y as f32) * stride as f32; + // w, h 还原 + outputs[[0, idx, 2]] = outputs[[0, idx, 2]].exp() * stride as f32; + outputs[[0, idx, 3]] = outputs[[0, idx, 3]].exp() * stride as f32; + } + } + offset += (h * w) as usize; + } + outputs + } + + /// 4. nms + fn nms(&self, boxes: &Array2, scores: &Array1, nms_thr: f32) -> Vec { + let mut keep = Vec::new(); + let x1 = boxes.column(0); + let y1 = boxes.column(1); + let x2 = boxes.column(2); + let y2 = boxes.column(3); + // 在每一项前加上 &,并确保括号内的计算顺序 + // 注意:ndarray 的 View 运算需要 &view1 - &view2 + let areas = (&x2 - &x1 + 1.0) * (&y2 - &y1 + 1.0); + + let mut v: Vec = (0..scores.len()).collect(); + v.sort_by(|&i, &j| { + scores[j] + .partial_cmp(&scores[i]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + while let Some(i) = v.first().cloned() { + keep.push(i); + if v.len() == 1 { + break; + } + v.remove(0); + v.retain(|&idx| { + let xx1 = x1[i].max(x1[idx]); + let yy1 = y1[i].max(y1[idx]); + let xx2 = x2[i].min(x2[idx]); + let yy2 = y2[i].min(y2[idx]); + let w = (xx2 - xx1 + 1.0).max(0.0); + let h = (yy2 - yy1 + 1.0).max(0.0); + let inter = w * h; + let iou = inter / (areas[i] + areas[idx] - inter); + iou <= nms_thr + }); + } + keep + } + + /// 5. multiclass_nms + fn multiclass_nms( + &self, + boxes: &Array2, + scores: &Array2, + nms_thr: f32, + score_thr: f32, + ) -> Vec> { + let mut result = Vec::new(); + for i in 0..scores.nrows() { + let row = scores.row(i); + let (cls_id, &score) = row + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .unwrap(); + + if score > score_thr { + let mut det = boxes.row(i).to_vec(); + det.push(score); + det.push(cls_id as f32); + result.push(det); + } + } + + if result.is_empty() { + return vec![]; + } + + let b_subset = Array2::from_shape_vec( + (result.len(), 4), + result.iter().flat_map(|r| r[0..4].to_vec()).collect(), + ) + .unwrap(); + let s_subset = Array1::from_vec(result.iter().map(|r| r[4]).collect()); + + let keep = self.nms(&b_subset, &s_subset, nms_thr); + keep.into_iter().map(|idx| result[idx].clone()).collect() + } + + /// 6. get_bbox (完全解耦 OpenCV) + pub fn get_bbox(&self, image_bytes: &[u8]) -> Result>> { + // 使用 image crate 解码 + let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode image")?; + let (orig_w, orig_h) = dynamic_img.dimensions(); + + let (input_tensor, ratio) = self.preproc(&dynamic_img, (416, 416))?; + + // tract 推理 + let outputs = self.session.run(tvec!(input_tensor.into()))?; + let output_array = outputs[0] + .to_array_view::()? + .to_owned() + .into_dimensionality::()?; + + let predictions = self.demo_postprocess(output_array, (416, 416)); + let pred = predictions.slice(s![0, .., ..]); + + let boxes = pred.slice(s![.., 0..4]); + let scores = &pred.slice(s![.., 4..5]) * &pred.slice(s![.., 5..]); + + let mut boxes_xyxy = Array2::::zeros(boxes.raw_dim()); + for i in 0..boxes.nrows() { + boxes_xyxy[[i, 0]] = (boxes[[i, 0]] - boxes[[i, 2]] / 2.0) / ratio; + boxes_xyxy[[i, 1]] = (boxes[[i, 1]] - boxes[[i, 3]] / 2.0) / ratio; + boxes_xyxy[[i, 2]] = (boxes[[i, 0]] + boxes[[i, 2]] / 2.0) / ratio; + boxes_xyxy[[i, 3]] = (boxes[[i, 1]] + boxes[[i, 3]] / 2.0) / ratio; + } + + let dets = self.multiclass_nms(&boxes_xyxy, &scores, 0.45, 0.1); + + Ok(dets + .into_iter() + .map(|d| { + vec![ + (d[0] as i32).max(0).min(orig_w as i32), + (d[1] as i32).max(0).min(orig_h as i32), + (d[2] as i32).max(0).min(orig_w as i32), + (d[3] as i32).max(0).min(orig_h as i32), + ] + }) + .collect()) + } + + } diff --git a/src/lib.rs b/src/lib.rs index 87504ff..f1be9ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,84 +10,99 @@ mod utils; use anyhow::Result; use image::DynamicImage; +use std::fmt::{Display, Formatter}; // 关键点:直接使用 tract 重导出的 ndarray +use crate::charset::get_default_charset; use crate::det_model::Det; use crate::model_loader::ModelSession; use crate::ocr_model::Ocr; -use crate::charset::get_default_charset; -pub enum ModeType { +pub enum ModelSpec { /// 默认 OCR (使用内置路径) - Ocr { - path: String, - charset: Vec, - }, - Det { - path: String, - }, + OcrModel, + DetModel, /// 自定义 OCR (路径由用户提供) - CustomOcr { + CustomOcrModel { path: String, charset: Vec, }, } - +impl ModelSpec { + // 将默认路径定义为内部关联常量 + const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx"; + const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx"; +} +pub enum Runtime { + Ocr(Ocr), + Det(Det), +} +impl Runtime { + // 统一获取描述的方法 + pub fn desc(&self) -> String { + match self { + Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法 + Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法 + } + } +} pub struct DdddOcrBuilder { - mode: ModeType, + mode: ModelSpec, } impl DdddOcrBuilder { pub fn new() -> Self { Self { - mode: ModeType::Ocr { - path: "models/common.onnx".to_string(), - charset: get_default_charset(), - }, + mode: ModelSpec::OcrModel, } } /// 切换为检测模式 pub fn det(mut self) -> Self { - self.mode = ModeType::Det { - path: "models/common_det.onnx".to_string(), - }; + self.mode = ModelSpec::DetModel; self } /// 设置自定义 OCR 路径 pub fn custom_ocr(mut self, path: String, charset: Vec) -> Self { // 直接重写枚举,替换掉之前的 Ocr 或 Det - self.mode = ModeType::CustomOcr { path, charset }; + self.mode = ModelSpec::CustomOcrModel { path, charset }; self } /// 核心初始化逻辑 pub fn build(self) -> Result { - let session: Box = match self.mode { - ModeType::Ocr { path, charset } => Box::new(Ocr::new(path, charset)?), - ModeType::Det { path } => Box::new(Det::new(path)?), - ModeType::CustomOcr { path, charset } => Box::new(Ocr::new(path, charset)?), + let runtime = match self.mode { + ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(ModelSpec::DEFAULT_OCR_PATH.into(), get_default_charset())?), + ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?), + ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?), }; - Ok(DdddOcr { session }) + Ok(DdddOcr { runtime }) } } pub struct DdddOcr { - session: Box, + runtime: Runtime, } + +impl Display for DdddOcr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DdddOcr(session: {})", self.runtime.desc()) + } +} + impl DdddOcr { pub fn classification(&self, img: &DynamicImage) -> Result { - self.session.predict(img, false) - - // let tensor = self.preprocess_image(img, false)?; - // - // // let result = self.session.run(tvec!(tensor.into()))?; - // // 3. 解析结果 - // // let output = result[0].to_array_view::()?; - // let output = self.inference(tensor)?; - // let output2 = self.process_text_output(&output)?; - // Ok(Self::ctc_decode_indices(&output2)) + match &self.runtime { + Runtime::Ocr(s) => s.predict(img, false), + Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")), + } + } + pub fn detection(&self, img: &[u8]) -> Result>> { + match &self.runtime { + Runtime::Det(s) => s.predict(img), + Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")), + } } } diff --git a/src/model_loader.rs b/src/model_loader.rs index f8f5c64..e912c9d 100644 --- a/src/model_loader.rs +++ b/src/model_loader.rs @@ -17,8 +17,8 @@ pub enum ModelType { } // 定义统一的 trait pub trait ModelSession { - fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result; fn get_model_type(&self) -> ModelType; + fn desc(&self) -> String; } pub struct ModelLoader { diff --git a/src/ocr_model.rs b/src/ocr_model.rs index 38a8ac8..9f993dc 100644 --- a/src/ocr_model.rs +++ b/src/ocr_model.rs @@ -1,3 +1,4 @@ +use crate::base::ModelArgs; use crate::image_io::png_rgba_white_preprocess; use crate::image_processor::{convert_to_grayscale, resize_image}; use crate::model_loader::{ModelLoader, ModelSession, ModelType}; @@ -7,8 +8,6 @@ use tract_onnx::prelude::tract_ndarray::s; use tract_onnx::prelude::{ DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec, }; -use crate::base::ModelArgs; - // 颜色过滤的自定义范围:(低值RGB, 高值RGB) pub type ColorRange = ((u8, u8, u8), (u8, u8, u8)); @@ -16,17 +15,17 @@ pub type ColorRange = ((u8, u8, u8), (u8, u8, u8)); // 字符集范围类型 #[derive(Debug, Clone)] pub enum CharsetRange { - All, // 所有字符 - Digit, // 数字 - Letter, // 字母 - Alphanumeric, // 字母数字 - Single(String), // 单字符串 - Multiple(Vec), // 多个字符串 - Range(char, char), // 字符范围 - Custom(Vec), // 自定义字符列表 + All, // 所有字符 + Digit, // 数字 + Letter, // 字母 + Alphanumeric, // 字母数字 + Single(String), // 单字符串 + Multiple(Vec), // 多个字符串 + Range(char, char), // 字符范围 + Custom(Vec), // 自定义字符列表 } #[derive(Debug, Clone)] -pub struct PredictArgs{ +pub struct PredictArgs { /// 是否修复PNG格式问题 pub png_fix: bool, /// 是否返回概率信息 @@ -100,7 +99,19 @@ pub struct Ocr { charset: Vec, } impl ModelSession for Ocr { - fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { + fn get_model_type(&self) -> ModelType { + todo!() + } + fn desc(&self) -> String { + "Ocr Model 加载成功".to_string() + } +} +impl Ocr { + pub fn new(model_path: String, charset: Vec) -> Result { + let session = ModelLoader::load_model(&model_path)?.session; + Ok(Self { session, charset }) + } + pub fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result { let tensor = self.preprocess_image(image, png_fix)?; // // let result = self.session.run(tvec!(tensor.into()))?; @@ -108,19 +119,9 @@ impl ModelSession for Ocr { // // let output = result[0].to_array_view::()?; let output = self.inference(tensor)?; let output2 = self.process_text_output(&output)?; - Ok(Self::ctc_decode_indices(&output2)) + Ok(self.ctc_decode_indices(&output2)) // Ok("ocr result".to_string()) } - - fn get_model_type(&self) -> ModelType { - ModelType::Ocr - } -} -impl Ocr { - pub fn new(model_path: String, charset: Vec) -> Result { - let session = ModelLoader::load_model(&model_path)?.session; - Ok(Self { session, charset }) - } /// 对应 Python 的 _preprocess_image /// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换 fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result { @@ -222,10 +223,9 @@ impl Ocr { )), } } - fn ctc_decode_indices(predicted_indices: &[i64]) -> String { + fn ctc_decode_indices(&self, predicted_indices: &[i64]) -> String { println!("indices模型输出原始数据: {:?}", predicted_indices); - use crate::charset::CHARSET_BETA; // 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0) let mut res = String::new(); let mut prev_idx: i64 = -1; @@ -235,8 +235,11 @@ impl Ocr { // 2. 跳过 blank 字符 (假设索引 0 是 blank) if idx != prev_idx && idx != 0 { if let Ok(u_idx) = usize::try_from(idx) { - if let Some(&char_str) = CHARSET_BETA.get(u_idx) { + if let Some(char_str) = self.charset.get(u_idx) { res.push_str(char_str); + } else { + // 保护逻辑:如果模型预测的索引超出了字符集范围 + eprintln!("警告: 预测索引 {} 超出字符集范围", u_idx); } } } diff --git a/tests/ocr_test.rs b/tests/ocr_test.rs index 776827d..4cfccfc 100644 --- a/tests/ocr_test.rs +++ b/tests/ocr_test.rs @@ -1,5 +1,44 @@ +use std::fs; +use image::Rgb; use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个 +/// 将检测结果绘制在图像上并保存 +fn save_debug_image( image_bytes: &[u8], bboxes: &Vec>, output_path: &str) -> anyhow::Result<()> { + + + let dynamic_img = image::load_from_memory(image_bytes)?; + let mut img = dynamic_img.to_rgb8(); + let (width, height) = img.dimensions(); + let red = Rgb([255u8, 0, 0]); + + for bbox in bboxes { + // 基础边界检查 + let x1 = bbox[0].max(0).min(width as i32 - 1) as u32; + let y1 = bbox[1].max(0).min(height as i32 - 1) as u32; + let x2 = bbox[2].max(0).min(width as i32 - 1) as u32; + let y2 = bbox[3].max(0).min(height as i32 - 1) as u32; + + // 绘制横向线条 + for x in x1..=x2 { + img.put_pixel(x, y1, red); + img.put_pixel(x, y2, red); + // 如果要加粗,多画一行 + if y1 + 1 < height { img.put_pixel(x, y1 + 1, red); } + if y2.saturating_sub(1) > 0 { img.put_pixel(x, y2 - 1, red); } + } + // 绘制纵向线条 + for y in y1..=y2 { + img.put_pixel(x1, y, red); + img.put_pixel(x2, y, red); + // 如果要加粗,多画一列 + if x1 + 1 < width { img.put_pixel(x1 + 1, y, red); } + if x2.saturating_sub(1) > 0 { img.put_pixel(x2 - 1, y, red); } + } + } + + img.save(output_path)?; + Ok(()) +} #[test] fn test_full_classification() { // 1. 初始化模型 @@ -13,4 +52,25 @@ fn test_full_classification() { println!("识别结果: {}", result); assert!(!result.is_empty()); +} +#[test] +fn test_det_load()->anyhow::Result<()>{ + let det = DdddOcrBuilder::new().det().build()?; + let image_path = "samples/det1.png"; + let image_bytes = fs::read(image_path) + .map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?; + + println!("图片读取成功,字节大小: {}", image_bytes.len()); + let bboxes =det.detection(&image_bytes)?; + println!(":?{}",det); + println!("检测到的目标数量: {}", bboxes.len()); + if bboxes.is_empty() { + println!("未检测到任何目标。"); + } else { + save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?; + for (i, bbox) in bboxes.iter().enumerate() { + println!("目标 [{}]: x1={}, y1={}, x2={}, y2={}", i, bbox[0], bbox[1], bbox[2], bbox[3]); + } + } + Ok(()) } \ No newline at end of file