7 Commits

Author SHA1 Message Date
a51147c888 refactor: 优化 slide_model.rs
- 新增 cv2.rs 模拟 opencv
2026-05-09 17:52:34 +08:00
e8b365dced feat: 优化 image_io.rs 模块
- 新增 base64_to_image等工具函数。
2026-05-08 22:35:17 +08:00
f0db625bd1 refactor: 完成图像加载模块重构,对齐 ddddocr Python 原版 IO 逻辑 2026-05-08 17:59:42 +08:00
21bd1c93bf feat: 完成 Rust 滑块匹配算法,修复透明留白导致的坐标偏移
- 实现灰度与边缘两种匹配模式
- 对齐 OpenCV NCC 算法逻辑
- 优化图像灰度化与 Alpha 通道转换
- 提升坐标计算精度至像素级
2026-05-08 16:03:33 +08:00
1a329ca273 refactor: 优化Det算法
- 优化 demo_postprocess,nms算法
- 新增 Slide 滑块识别
- 更新 Cargo.toml 依赖项
2026-05-07 18:00:39 +08:00
8fcfa2096e refactor: 移除 OpenCV 依赖并实现纯 Rust 图像处理流水线
- 替换 opencv 为 image 库以简化交叉编译
- 修正 nms 逻辑中的 ArrayView 借用问题
- 增加 save_debug_image 方法用于可视化检测框
- 更新 Cargo.toml 依赖项
2026-05-06 17:37:38 +08:00
cfeb68ad04 feat: 重构模型初始化逻辑
- 重构 DdddOcr。
- 新增 DdddOcrBuilder。
- 其他优化
2026-05-05 22:18:12 +08:00
21 changed files with 1427 additions and 179 deletions

View File

@@ -9,3 +9,4 @@ tract-onnx = { version = "0.21.10" }
anyhow = "1.0.102" anyhow = "1.0.102"
image = "0.25.10" image = "0.25.10"
base64 = "0.22.1" base64 = "0.22.1"
imageproc = { version = "0.26.2", default-features = true }

View File

@@ -2,8 +2,43 @@
带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。 带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。
🧩 滑块识别算法核心知识点总结
本项目实现了两种核心匹配模式,其底层逻辑与 OpenCV 的对齐情况如下:
1. 匹配模式对比 (Match Modes)
|**模式**|**算法原理**|**适用场景**|**备注**|
|---|---|---|---|
|**边缘模式** (Edge-based)|基于 **Canny 边缘检测** 提取轮廓后再进行匹配。|**推荐方案**
。适用于绝大多数拼图滑块。|天然免疫拼图周边的透明/黑色留白干扰,坐标最精准。|
|**简单模式** (Simple/Gray)|直接基于 **灰度像素值** 进行归一化互相关计算。|适用于无明显边缘、靠颜色差异识别的场景。|对背景和透明边框敏感,可能存在重心偏移。|
2. 数学公式差异 (NCC vs. CCOEFF)
在简单模式下,本项目采用的是 归一化互相关 (NCC),对应 OpenCV 中的 TM_CCORR_NORMED。
逻辑对齐:Rust 的 match_template 结果与 Python cv2.TM_CCORR_NORMED 完全一致。
关于偏移:若拼图原始图片(Target)四周包含大量的透明留白:
CCORR (本项目):会将留白视为图像的一部分,计算出的是整张图片框的中心。
CCOEFF (OpenCV 默认):会自动进行“均值中心化”,在一定程度上能削弱留白的影响。
最佳实践:若发现坐标有固定位移,建议优先切换至 边缘模式,或对滑块图进行 Bounding Box 裁剪 后再匹配。
3. 图像预处理一致性
为确保识别精度,本项目在 Rust 中完美复刻了 Python OpenCV 的预处理链路:
- **灰度化权重**:采用 OpenCV 标准感光公式 $0.299R + 0.587G + 0.114B$。
- **Alpha 处理**:在将 PNG 转为 RGB 时,自动将透明区域填充为黑色,确保与 PIL (Python Imaging Library) 行为一致。
- **坐标定义**:所有返回坐标均为匹配区域的 **几何中心点** $(x + w/2, y + h/2)$。
💡 开发者建议:
如果识别结果在 $X$ 轴上有大约 $10px$ 左右的固定误差,通常是因为滑块原图自带了透明边距(留白)。此时请确保
simple_target=false。该模式会通过 Canny 边缘检测 提取轮廓特征,能自动锁定拼图实体并忽略背景留白的像素干扰。
鸣谢 (Credits) 鸣谢 (Credits)
- 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。 - 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。

BIN
code3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

View File

@@ -1,5 +1,5 @@
fn main() { fn main() {
let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap(); let ocr = ddddocr_rs::DdddOcrBuilder::new().build().unwrap();
let img = image::open("samples/code3.png").unwrap(); let img = image::open("samples/code3.png").unwrap();
println!("Result: {}", ocr.classification(&img).unwrap()); println!("Result: {}", ocr.classification(&img).unwrap());
} }

BIN
samples/det1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

BIN
samples/det2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

BIN
samples/det3.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
samples/hua.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.1 KiB

BIN
samples/huatu.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

BIN
samples/ken.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.3 KiB

BIN
samples/kenyuan.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

40
src/base.rs Normal file
View File

@@ -0,0 +1,40 @@
/// Marker metadata for models that carry a character set (Ocr / Custom).
pub struct HasCharset {
    pub charset: String,
}

/// Marker metadata for models without a character set (Det).
pub struct NoCharset;

/// A model description: its on-disk path plus type-state metadata `T`
/// (`HasCharset` or `NoCharset`).
pub struct Model<T> {
    pub path: String,
    pub metadata: T,
}

/// Common read-only view over any model description.
pub trait ModelArgs {
    /// Path of the ONNX model on disk.
    fn model_path(&self) -> &str;
    /// The character set, when the model has one (Det does not, hence `Option`).
    fn charset(&self) -> Option<&str>;
}

// Models that do carry a charset expose it as `Some(..)`.
impl ModelArgs for Model<HasCharset> {
    fn model_path(&self) -> &str {
        self.path.as_str()
    }
    fn charset(&self) -> Option<&str> {
        Some(self.metadata.charset.as_str())
    }
}

// Charset-less models (Det) always answer `None`.
impl ModelArgs for Model<NoCharset> {
    fn model_path(&self) -> &str {
        self.path.as_str()
    }
    fn charset(&self) -> Option<&str> {
        None
    }
}

pub type OcrModel = Model<HasCharset>;
pub type DetModel = Model<NoCharset>;
// Ocr and Custom share identical logic, so the alias is reused.
pub type CustomModel = Model<HasCharset>;

View File

@@ -514,3 +514,6 @@ pub const CHARSET_BETA: &[&str] = &[
"", "", "", "", "婿", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "婿", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "",
]; ];
/// Build an owned copy of the built-in `CHARSET_BETA` table.
pub fn get_default_charset() -> Vec<String> {
    CHARSET_BETA.iter().map(|entry| String::from(*entry)).collect()
}

57
src/cv2.rs Normal file
View File

@@ -0,0 +1,57 @@
use image::{ImageBuffer, Luma};
use tract_onnx::prelude::tract_ndarray::{azip, Array2, Array3, ArrayView2, ArrayView3};
/// 1. Element-wise absolute difference of two byte arrays (mirrors `cv2.absdiff`).
///
/// Panics if the two views differ in shape (the `azip!` macro asserts shape
/// equality before iterating).
pub fn abs_diff(a: &ArrayView3<u8>, b: &ArrayView3<u8>) -> Array3<u8> {
    let mut out = Array3::zeros(a.dim());
    // Widen to i32 so the subtraction cannot wrap; |x - y| <= 255 for u8
    // inputs, so the final cast back to u8 is lossless.
    azip!((dst in &mut out, &lhs in a, &rhs in b) {
        *dst = (i32::from(lhs) - i32::from(rhs)).unsigned_abs() as u8;
    });
    out
}
/// RGB -> grayscale using the OpenCV luma weights (0.299 R, 0.587 G, 0.114 B).
/// A 4th (alpha) component, if present, is ignored entirely.
pub fn rgb_to_gray(rgb: ArrayView3<u8>) -> Array2<u8> {
    const W_R: f32 = 0.299;
    const W_G: f32 = 0.587;
    const W_B: f32 = 0.114;
    let (rows, cols, _) = rgb.dim();
    Array2::from_shape_fn((rows, cols), |(r, c)| {
        let lum = W_R * f32::from(rgb[[r, c, 0]])
            + W_G * f32::from(rgb[[r, c, 1]])
            + W_B * f32::from(rgb[[r, c, 2]]);
        // Truncating cast, same as the original `as u8`.
        lum as u8
    })
}
/// Scan a matching-score map for its maximum value and location
/// (the max half of `cv2.minMaxLoc`).
///
/// Returns `(max_value, (x, y))`. For an empty map this returns
/// `(f32::NEG_INFINITY, (0, 0))`.
pub fn min_max_loc(result_map: &ImageBuffer<Luma<f32>, Vec<f32>>) -> (f32, (u32, u32)) {
    // Seed with -inf instead of -1.0: with mean-centred metrics (e.g.
    // TM_CCOEFF_NORMED) every score can be <= -1.0, and the old seed would
    // then silently report (0, 0) instead of the true argmax.
    let mut max_val = f32::NEG_INFINITY;
    let mut max_loc = (0, 0);
    // Walk the score map pixel by pixel.
    for (x, y, score) in result_map.enumerate_pixels() {
        let s = score.0[0];
        // A start-column filter (e.g. `if x < 15 { continue; }`) can be
        // inserted here if the caller wants to skip the left margin.
        if s > max_val {
            max_val = s;
            max_loc = (x, y);
        }
    }
    (max_val, max_loc)
}
/// Copy a 2-D `u8` ndarray (row-major, indexed `[y, x]`) into an 8-bit
/// grayscale `ImageBuffer`.
pub fn ndarray_to_luma8(array: ArrayView2<u8>) -> ImageBuffer<Luma<u8>, Vec<u8>> {
    let (height, width) = array.dim();
    // `from_fn` visits every (x, y) exactly once, matching the original loop.
    ImageBuffer::from_fn(width as u32, height as u32, |x, y| {
        Luma([array[[y as usize, x as usize]]])
    })
}

263
src/det_model.rs Normal file
View File

@@ -0,0 +1,263 @@
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use anyhow::{Context, Result};
use image::{DynamicImage, GenericImageView, imageops::FilterType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, Axis, prelude::*, s};
use tract_onnx::prelude::{Graph, RunnableModel, Tensor, TypedFact, TypedOp, tvec};
/// Target-detection model wrapper: owns the optimized, runnable tract-onnx plan.
pub struct Det {
    session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
}
impl ModelSession for Det {
fn get_model_type(&self) -> ModelType {
todo!()
}
fn desc(&self) -> String {
"Detection Model 加载成功".to_string()
}
}
impl Det {
    /// Load the detection model at `model_path` into a runnable session.
    pub fn new(model_path: String) -> Result<Self, anyhow::Error> {
        let session = ModelLoader::load_model(&model_path)?.session;
        Ok(Self { session })
    }
    /// Run detection on encoded image bytes; returns clamped `[x1, y1, x2, y2]` boxes.
    pub fn predict(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
        // In Rust the file/PIL conversion is done by the caller, so we go
        // straight into the core pipeline.
        self.get_bbox(image_bytes)
    }
    /// 2. preproc: pure-Rust letterbox preprocessing (replaces OpenCV).
    ///
    /// Returns the NCHW input tensor (BGR channel order, unnormalized 0..255
    /// floats) and the letterbox scale factor `r` used to map boxes back.
    fn preproc(&self, img: &DynamicImage, input_size: (u32, u32)) -> Result<(Tensor, f32)> {
        let (target_h, target_w) = input_size;
        let (img_w, img_h) = img.dimensions();
        // Letterbox scale: fit the image inside the target without distortion.
        let r = (target_h as f32 / img_h as f32).min(target_w as f32 / img_w as f32);
        let new_h = (img_h as f32 * r) as u32;
        let new_w = (img_w as f32 * r) as u32;
        // Resize the image.
        let resized = img.resize_exact(new_w, new_h, FilterType::Triangle);
        // 2. Key step: explicitly convert the DynamicImage to RgbImage (Rgb<u8>).
        let resized_rgb = resized.to_rgb8();
        // Background canvas filled with gray value 114 (YOLO letterbox convention).
        let mut base_img =
            image::ImageBuffer::from_pixel(target_w, target_h, image::Rgb([114u8, 114, 114]));
        // Paste the resized image into the top-left corner
        // (equivalent to `padded_img[:h, :w]` in the Python original).
        image::imageops::overlay(&mut base_img, &resized_rgb, 0, 0);
        // Build the NCHW tensor.
        let mut array = Array4::<f32>::zeros((1, 3, target_h as usize, target_w as usize));
        for (x, y, pixel) in base_img.enumerate_pixels() {
            let x = x as usize;
            let y = y as usize;
            // Matches the Python BGR logic:
            // pixel[0] is R, pixel[1] is G, pixel[2] is B.
            // If the model expected RGB instead:
            // array[[0, 0, y as usize, x as usize]] = pixel[0] as f32;
            // array[[0, 1, y as usize, x as usize]] = pixel[1] as f32;
            // array[[0, 2, y as usize, x as usize]] = pixel[2] as f32;
            array[[0, 0, y, x]] = pixel[2] as f32; // B
            array[[0, 1, y, x]] = pixel[1] as f32; // G
            array[[0, 2, y, x]] = pixel[0] as f32; // R
        }
        Ok((array.into(), r))
    }
    /// 3. demo_postprocess: decode grid-relative outputs into absolute boxes
    /// (same logic as the Python version).
    ///
    /// NOTE(review): assumes the model emits anchors for strides 8/16/32 laid
    /// out contiguously per stride — confirm against the exported ONNX graph.
    fn demo_postprocess(&self, mut outputs: Array3<f32>, img_size: (i32, i32)) -> Array3<f32> {
        let strides = [8, 16, 32];
        // Iterate over every batch (supports dynamic-batch inference).
        for mut batch in outputs.axis_iter_mut(Axis(0)) {
            let mut offset = 0;
            for &stride in &strides {
                // Feature-map size at this stride.
                let h = img_size.0 / stride;
                let w = img_size.1 / stride;
                let f_stride = stride as f32;
                for y in 0..h {
                    for x in 0..w {
                        // Linear index of this grid cell among all anchors.
                        let idx = offset + (y * w + x) as usize;
                        // 1. Recover the centre point (cx, cy):
                        //    (output + grid_offset) * stride
                        batch[[idx, 0]] = (batch[[idx, 0]] + x as f32) * f_stride;
                        batch[[idx, 1]] = (batch[[idx, 1]] + y as f32) * f_stride;
                        // 2. Recover width/height:
                        //    exp(output) * stride
                        batch[[idx, 2]] = batch[[idx, 2]].exp() * f_stride;
                        batch[[idx, 3]] = batch[[idx, 3]].exp() * f_stride;
                    }
                }
                // Advance to the first anchor of the next stride.
                offset += (h * w) as usize;
            }
        }
        outputs
    }
    /// 4. nms: greedy IoU non-maximum suppression.
    ///
    /// `boxes` rows are `[x1, y1, x2, y2]`; returns the indices to keep,
    /// highest score first.
    fn nms(&self, boxes: &Array2<f32>, scores: &Array1<f32>, nms_thr: f32) -> Vec<usize> {
        let mut keep = Vec::new();
        let x1 = boxes.column(0);
        let y1 = boxes.column(1);
        let x2 = boxes.column(2);
        let y2 = boxes.column(3);
        // Prefix each view with `&` so ndarray's view arithmetic applies
        // (view ops require `&view1 - &view2`). The +1.0 keeps parity with
        // the Python implementation's inclusive pixel areas.
        let areas = (&x2 - &x1 + 1.0) * (&y2 - &y1 + 1.0);
        // Indices sorted by descending score.
        let mut v: Vec<usize> = (0..scores.len()).collect();
        v.sort_unstable_by(|&i, &j| {
            scores[j]
                .partial_cmp(&scores[i])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // Instead of `v.remove(0)` (O(n) per pop) we operate on an index pool.
        let mut active_indices = v;
        while !active_indices.is_empty() {
            // Highest-scoring box still in the pool (always the first element).
            let i = active_indices[0];
            keep.push(i);
            // Only one left — done.
            if active_indices.len() == 1 {
                break;
            }
            // 5. Core step: one `retain` pass drops both
            // (a) the current box itself (idx == i), and
            // (b) boxes overlapping it too much (iou > nms_thr).
            active_indices.retain(|&idx| {
                // The box being processed is already in `keep`; drop it from the pool.
                if idx == i {
                    return false;
                }
                // IoU of box `i` vs box `idx`.
                let xx1 = x1[i].max(x1[idx]);
                let yy1 = y1[i].max(y1[idx]);
                let xx2 = x2[i].min(x2[idx]);
                let yy2 = y2[i].min(y2[idx]);
                let w = (xx2 - xx1 + 1.0).max(0.0);
                let h = (yy2 - yy1 + 1.0).max(0.0);
                let inter = w * h;
                let iou = inter / (areas[i] + areas[idx] - inter);
                // Keep only boxes below the overlap threshold.
                iou <= nms_thr
            });
        }
        keep
    }
    /// 5. multiclass_nms: per-anchor argmax over classes, score filtering,
    /// then class-agnostic NMS.
    ///
    /// Returns rows of `[x1, y1, x2, y2, score, class_id]`.
    pub fn multiclass_nms(
        &self,
        boxes: &Array2<f32>,  // [N, 4] in xyxy format
        scores: &Array2<f32>, // [N, classes], already multiplied by objectness
        nms_thr: f32,
        score_thr: f32,
    ) -> Vec<Vec<f32>> {
        let mut candidates = Vec::new();
        // 1. Select high-score boxes (argmax + threshold in a single pass).
        for i in 0..scores.nrows() {
            let row = scores.row(i);
            // Best class for this anchor.
            let mut max_score = 0.0;
            let mut cls_id = 0;
            for (j, &s) in row.iter().enumerate() {
                if s > max_score {
                    max_score = s;
                    cls_id = j;
                }
            }
            // Keep only candidates above the score threshold.
            if max_score > score_thr {
                // Store index + metadata instead of materializing big arrays.
                candidates.push((i, max_score, cls_id));
            }
        }
        if candidates.is_empty() {
            return vec![];
        }
        // 2. Build the NMS input subsets for the surviving candidates.
        let mut b_subset = Array2::<f32>::zeros((candidates.len(), 4));
        let mut s_subset = Array1::<f32>::zeros(candidates.len());
        for (new_idx, &(orig_idx, score, _)) in candidates.iter().enumerate() {
            b_subset.row_mut(new_idx).assign(&boxes.row(orig_idx));
            s_subset[new_idx] = score;
        }
        // 3. Run NMS (returns indices into the subset).
        let keep = self.nms(&b_subset, &s_subset, nms_thr);
        // 4. Assemble the final rows [x1, y1, x2, y2, score, class_id].
        keep.into_iter()
            .map(|k_idx| {
                let (orig_idx, score, cls_id) = candidates[k_idx];
                let b = boxes.row(orig_idx);
                vec![b[0], b[1], b[2], b[3], score, cls_id as f32]
            })
            .collect()
    }
    /// 6. get_bbox: full pipeline, fully decoupled from OpenCV.
    /// Decode -> letterbox -> infer -> grid decode -> NMS -> clamp to image.
    pub fn get_bbox(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
        // Decode with the `image` crate.
        let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode image")?;
        let (orig_w, orig_h) = dynamic_img.dimensions();
        let (input_tensor, ratio) = self.preproc(&dynamic_img, (416, 416))?;
        // tract inference.
        let outputs = self.session.run(tvec!(input_tensor.into()))?;
        let output_array = outputs[0]
            .to_array_view::<f32>()?
            .to_owned()
            .into_dimensionality::<Ix3>()?;
        let predictions = self.demo_postprocess(output_array, (416, 416));
        let pred = predictions.slice(s![0, .., ..]);
        let boxes = pred.slice(s![.., 0..4]);
        // Class scores scaled by objectness (column 4).
        let scores = &pred.slice(s![.., 4..5]) * &pred.slice(s![.., 5..]);
        // cxcywh -> xyxy, undoing the letterbox scale.
        let mut boxes_xyxy = Array2::<f32>::zeros(boxes.raw_dim());
        for i in 0..boxes.nrows() {
            boxes_xyxy[[i, 0]] = (boxes[[i, 0]] - boxes[[i, 2]] / 2.0) / ratio;
            boxes_xyxy[[i, 1]] = (boxes[[i, 1]] - boxes[[i, 3]] / 2.0) / ratio;
            boxes_xyxy[[i, 2]] = (boxes[[i, 0]] + boxes[[i, 2]] / 2.0) / ratio;
            boxes_xyxy[[i, 3]] = (boxes[[i, 1]] + boxes[[i, 3]] / 2.0) / ratio;
        }
        let detections = self.multiclass_nms(&boxes_xyxy, &scores, 0.45, 0.1);
        // Clamp every coordinate into the original image bounds.
        Ok(detections
            .into_iter()
            .map(|d| {
                vec![
                    (d[0] as i32).max(0).min(orig_w as i32),
                    (d[1] as i32).max(0).min(orig_h as i32),
                    (d[2] as i32).max(0).min(orig_w as i32),
                    (d[3] as i32).max(0).min(orig_h as i32),
                ]
            })
            .collect())
    }
}

View File

@@ -1,34 +1,125 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result, anyhow, bail};
use base64::{Engine as _, engine::general_purpose}; use base64::{Engine as _, engine::general_purpose};
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage}; use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba};
use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use tract_onnx::prelude::tract_ndarray::Array3; use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD};
/// Pixel-layout selector mirroring PIL's mode strings ("RGB" / "RGBA" / "L").
#[derive(Debug)]
pub enum ColorMode {
    RGB,
    RGBA,
    // Single-channel grayscale (PIL mode "L").
    L,
}
/// 定义支持的输入类型枚举 /// 定义支持的输入类型枚举
pub enum ImageInput { pub enum ImageInput {
Bytes(Vec<u8>), Bytes(Vec<u8>),
Array(Array3<u8>), Array(ArrayD<u8>), // 对应 numpy 数组
Path(PathBuf), Path(PathBuf),
Base64(String), Base64(String),
DynamicImage(DynamicImage), DynamicImage(DynamicImage),
} }
/// 模拟 Python 的 load_image_from_input /// 模拟 Python 的 load_image_from_input
#[allow(dead_code)] #[allow(dead_code)]
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> { pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
match input { match img_input {
ImageInput::DynamicImage(img) => Ok(img), // 2. 处理字节流 (Bytes)
_ => todo!("后续补充"), ImageInput::Bytes(bytes) => {
image::load_from_memory(&bytes).context("Failed to load image from bytes")
}
// 1. 已经是 DynamicImage
ImageInput::DynamicImage(i) => Ok(i),
// 5. 处理 ndarray (Numpy-like)
// 假设输入是 HWC 格式的 Array3<u8>
ImageInput::Array(a) => numpy_to_pil_image(a.view()),
// 4. 处理 Base64 字符串
ImageInput::Base64(b) => base64_to_image(&b),
// 3. 处理文件路径 (Path)
ImageInput::Path(p) => image::open(p).context("Failed to open image from path"),
}
}
/// Decode a base64-encoded image, tolerating a data-URL prefix
/// (e.g. `"data:image/png;base64,..."`).
///
/// # Errors
/// Fails if the payload is not valid standard base64 or does not decode to a
/// supported image format.
fn base64_to_image(b64_str: &str) -> Result<DynamicImage> {
    // Strip everything up to and including the first ','. `split_once` avoids
    // the original's manual `find` + slicing and its `&&str` else-branch.
    let clean_b64 = b64_str
        .split_once(',')
        .map_or(b64_str, |(_, payload)| payload);
    let bytes = general_purpose::STANDARD
        .decode(clean_b64.trim())
        .map_err(|e| anyhow!("Base64 decode error: {}", e))?;
    image::load_from_memory(&bytes).context("Failed to load image from decoded base64")
}
/// 读取图片文件并转换为 base64 编码字符串
/// 对应 Python 版 get_img_base64
pub fn get_img_base64<P: AsRef<Path>>(image_path: P) -> Result<String> {
// 1. 读取文件原始字节流
// 使用 AsRef<Path> 泛型可以让函数同时支持 String, &str, PathBuf 等类型
let image_data = fs::read(&image_path)
.with_context(|| format!("Failed to read image file: {:?}", image_path.as_ref()))?;
// 2. 进行 Base64 编码
// 使用 STANDARD 引擎对齐 Python 的 base64.b64encode
let b64_string = general_purpose::STANDARD.encode(image_data);
Ok(b64_string)
}
/// Build a `DynamicImage` from a numpy-style `u8` array, shaped either
/// (H, W) or (H, W, C) — port of the Python `_numpy_to_pil_image`.
fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
    let shape = array.shape();
    // Force C-contiguous memory (slices/transposes get copied here) before
    // handing the raw buffer to `ImageBuffer::from_raw`.
    let contiguous = array.as_standard_layout();
    let (pixels, _offset) = contiguous.to_owned().into_raw_vec_and_offset();
    match shape {
        // 2-D: grayscale (H, W).
        [h, w] => ImageBuffer::<Luma<u8>, _>::from_raw(*w as u32, *h as u32, pixels)
            .map(DynamicImage::ImageLuma8)
            .ok_or_else(|| anyhow!("Failed to create Luma image from 2D array")),
        // 3-D: (H, W, C) with C in {1, 3, 4}.
        [h, w, c] => {
            let (w32, h32) = (*w as u32, *h as u32);
            let built = match *c {
                1 => ImageBuffer::<Luma<u8>, _>::from_raw(w32, h32, pixels)
                    .map(DynamicImage::ImageLuma8),
                3 => ImageBuffer::<Rgb<u8>, _>::from_raw(w32, h32, pixels)
                    .map(DynamicImage::ImageRgb8),
                4 => ImageBuffer::<Rgba<u8>, _>::from_raw(w32, h32, pixels)
                    .map(DynamicImage::ImageRgba8),
                _ => return Err(anyhow!("不支持的通道数: {}", c)),
            };
            built.ok_or_else(|| anyhow!("转换彩色图失败"))
        }
        other => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", other.len())),
    }
}
} }
/// 对应 Python 的 png_rgba_black_preprocess /// 对应 Python 的 png_rgba_black_preprocess
/// 将带有透明通道的图片转换为白色背景的 RGB 图片 /// 将带有透明通道的图片转换为白色背景的 RGB 图片
#[allow(dead_code)]
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回 // 1. 检查是否包含透明通道,如果没有,直接克隆并返回
if !img.color().has_alpha() { if !img.color().has_alpha() {
return img.clone(); return DynamicImage::ImageRgb8(img.to_rgb8());
} }
let (width, height) = img.dimensions(); let (width, height) = img.dimensions();
@@ -41,22 +132,133 @@ pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
// 4. 遍历像素并手动进行 Alpha 混合 // 4. 遍历像素并手动进行 Alpha 混合
// 对应 Python 的 image.paste(img, ..., mask=img) // 对应 Python 的 image.paste(img, ..., mask=img)
for (x, y, pixel) in rgba_img.enumerate_pixels() { // 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销
let alpha = pixel[3] as f32 / 255.0; for (x, y, bg_pixel) in background.enumerate_pixels_mut() {
// 安全性说明x, y 源自 background 尺寸,与 rgba_img 一致get_pixel 是安全的
let src_pixel = rgba_img.get_pixel(x, y);
let alpha_u8 = src_pixel[3];
if alpha >= 1.0 { match alpha_u8 {
// 完全不透明,直接覆盖 // 情况 A完全不透明,直接覆盖背景色
background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]])); 255 => {
} else if alpha > 0.0 { bg_pixel.0 = [src_pixel[0], src_pixel[1], src_pixel[2]];
// 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha)) }
let bg_pixel = background.get_pixel(x, y); // 情况 B完全透明保持背景色白色无需操作
let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8; 0 => {
let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8; continue;
let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8; }
background.put_pixel(x, y, Rgb([r, g, b])); // 情况 C半透明进行 Alpha 混合计算
_ => {
let alpha = alpha_u8 as f32 / 255.0;
let inv_alpha = 1.0 - alpha;
bg_pixel[0] = (src_pixel[0] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
bg_pixel[1] = (src_pixel[1] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
bg_pixel[2] = (src_pixel[2] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
}
} }
// alpha == 0 的情况不需要处理,因为背景已经是白色了
} }
DynamicImage::ImageRgb8(background) DynamicImage::ImageRgb8(background)
} }
/// Convert a `DynamicImage` into an (H, W, C) `u8` ndarray in the requested
/// mode — the counterpart of `image.convert(mode)` + `np.array(...)` in Python.
/// Kept for now; a later cleanup may fold it into `image_to_ndarray`.
pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result<Array3<u8>> {
    let (width, height) = image.dimensions();
    // Mode conversion maps onto the image crate's to_luma8/to_rgb8/to_rgba8.
    let (channels, raw) = match mode {
        ColorMode::L => (1, image.to_luma8().into_raw()),
        ColorMode::RGB => (3, image.to_rgb8().into_raw()),
        ColorMode::RGBA => (4, image.to_rgba8().into_raw()),
    };
    Array3::from_shape_vec((height as usize, width as usize, channels), raw)
        .map_err(|e| anyhow!("Failed to build ndarray: {}", e))
}
/// Build a `DynamicImage` from an (H, W, C) `u8` array, validating that the
/// channel count matches the requested `ColorMode`.
pub fn numpy_to_image(array: ArrayViewD<u8>, mode: ColorMode) -> Result<DynamicImage> {
    let shape = array.shape();
    // 1. Must be a 3-D (H, W, C) array.
    if shape.len() != 3 {
        bail!("Expected a 3D array (H, W, C), but got {}D", shape.len());
    }
    let (height, width, channels) = (shape[0] as u32, shape[1] as u32, shape[2]);
    // 2. Channel count must agree with the mode.
    let expected_channels = match mode {
        ColorMode::L => 1,
        ColorMode::RGB => 3,
        ColorMode::RGBA => 4,
    };
    if channels != expected_channels {
        bail!(
            "Mode {:?} expects {} channels, but array has {}",
            mode,
            expected_channels,
            channels
        );
    }
    // Normalize to C-order so the raw buffer is laid out row-major.
    let (raw, _) = array.as_standard_layout().to_owned().into_raw_vec_and_offset();
    let built = match mode {
        ColorMode::L => {
            ImageBuffer::<Luma<u8>, _>::from_raw(width, height, raw).map(DynamicImage::ImageLuma8)
        }
        ColorMode::RGB => {
            ImageBuffer::<Rgb<u8>, _>::from_raw(width, height, raw).map(DynamicImage::ImageRgb8)
        }
        ColorMode::RGBA => {
            ImageBuffer::<Rgba<u8>, _>::from_raw(width, height, raw).map(DynamicImage::ImageRgba8)
        }
    };
    built.ok_or_else(|| anyhow!("Failed to construct ImageBuffer. Buffer size might be incorrect."))
}
/// `DynamicImage` -> (H, W, 3) `u8` ndarray, always RGB.
/// Alpha is dropped, matching Python's `target_mode='RGB'`.
pub fn image_to_ndarray(img: &DynamicImage) -> Array3<u8> {
    let (width, height) = img.dimensions();
    // Force RGB8 and take the raw row-major pixel buffer.
    let raw = img.to_rgb8().into_raw();
    // An explicit panic beats silently returning a black image on size mismatch.
    Array3::from_shape_vec((height as usize, width as usize, 3), raw)
        .expect("Failed to construct ndarray from image")
}
/// Debug helper: normalize an f32 score map to 0..=255 and save it to
/// `filename` as an 8-bit grayscale heat-map.
#[allow(dead_code)]
fn save_rust_result(result: &ImageBuffer<Luma<f32>, Vec<f32>>, filename: &str) {
    let (width, height) = result.dimensions();
    // Pass 1: global min/max for normalization.
    let mut min_val = f32::MAX;
    let mut max_val = f32::MIN;
    for px in result.pixels() {
        let v = px.0[0];
        min_val = min_val.min(v);
        max_val = max_val.max(v);
    }
    let range = max_val - min_val;
    // Pass 2: rescale every pixel into an 8-bit buffer (a flat map maps to 0).
    let out_buf = ImageBuffer::from_fn(width, height, |x, y| {
        let v = result.get_pixel(x, y).0[0];
        let scaled = if max_val > min_val {
            ((v - min_val) / range * 255.0) as u8
        } else {
            0u8
        };
        Luma([scaled])
    });
    // Persist the heat-map.
    DynamicImage::ImageLuma8(out_buf).save(filename).unwrap();
    println!("Rust 结果热力图已保存至: {}", filename);
}

View File

@@ -1,169 +1,112 @@
pub mod base;
mod charset; mod charset;
mod det_model;
mod image_io; mod image_io;
mod image_processor; mod image_processor;
mod model; mod model;
mod model_loader;
mod ocr_model;
mod utils; mod utils;
pub mod slide_model;
mod cv2;
use anyhow::Result;
use image::DynamicImage;
use std::fmt::{Display, Formatter};
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use anyhow::{Context, Result};
use image::{DynamicImage, imageops::FilterType};
use tract_onnx::prelude::*;
// 关键点:直接使用 tract 重导出的 ndarray // 关键点:直接使用 tract 重导出的 ndarray
use tract_onnx::prelude::tract_ndarray::s; use crate::charset::get_default_charset;
use crate::det_model::Det;
use crate::model_loader::ModelSession;
use crate::ocr_model::Ocr;
/// Which model a `DdddOcr` instance should be built around.
pub enum ModelSpec {
    /// Default OCR model (uses the built-in path).
    OcrModel,
    /// Default detection model (built-in path).
    DetModel,
    /// Custom OCR model (path supplied by the user).
    CustomOcrModel {
        path: String,
        charset: Vec<String>,
    },
}
impl ModelSpec {
    // Default model locations, kept as associated constants so they live next
    // to the enum that selects them.
    const DEFAULT_OCR_PATH: &'static str = "models/common_sml2h3_f32.onnx";
    const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
}
/// The concrete model session backing a `DdddOcr` instance.
pub enum Runtime {
    Ocr(Ocr),
    Det(Det),
}
impl Runtime {
// 统一获取描述的方法
pub fn desc(&self) -> String {
match self {
Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法
Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法
}
}
}
/// Builder that selects which model a `DdddOcr` is constructed with.
pub struct DdddOcrBuilder {
    // Selected model kind; defaults to the built-in OCR model.
    mode: ModelSpec,
}
impl DdddOcrBuilder {
    /// Start a builder targeting the default OCR model.
    pub fn new() -> Self {
        Self {
            mode: ModelSpec::OcrModel,
        }
    }
    /// Switch to the bundled detection model.
    pub fn det(mut self) -> Self {
        self.mode = ModelSpec::DetModel;
        self
    }
    /// Use a custom OCR model at `path` with the given charset.
    pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
        // Overwrites whatever mode was previously selected (Ocr or Det).
        self.mode = ModelSpec::CustomOcrModel { path, charset };
        self
    }
    /// Core initialization: load the selected model and produce a ready `DdddOcr`.
    ///
    /// # Errors
    /// Propagates any failure from loading/optimizing the ONNX model.
    pub fn build(self) -> Result<DdddOcr> {
        let runtime = match self.mode {
            ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(
                ModelSpec::DEFAULT_OCR_PATH.into(),
                get_default_charset(),
            )?),
            ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
            ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
        };
        Ok(DdddOcr { runtime })
    }
}

// `new()` takes no arguments, so also provide `Default` (clippy::new_without_default).
impl Default for DdddOcrBuilder {
    fn default() -> Self {
        Self::new()
    }
}
pub struct DdddOcr { pub struct DdddOcr {
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>, runtime: Runtime,
}
impl Display for DdddOcr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "DdddOcr(session: {})", self.runtime.desc())
}
} }
impl DdddOcr { impl DdddOcr {
pub fn new<P>(model_path: P) -> Result<Self>
where
P: AsRef<std::path::Path>,
{
let session = onnx()
.model_for_path(model_path)
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
.into_optimized()?
.into_runnable()?;
Ok(Self { session })
}
pub fn classification(&self, img: &DynamicImage) -> Result<String> { pub fn classification(&self, img: &DynamicImage) -> Result<String> {
let tensor = self.preprocess_image(img, false)?; match &self.runtime {
Runtime::Ocr(s) => s.predict(img, false),
// let result = self.session.run(tvec!(tensor.into()))?; Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
// 3. 解析结果
// let output = result[0].to_array_view::<i64>()?;
let output = self.inference(tensor)?;
let output2 = self.process_text_output(&output)?;
Ok(Self::ctc_decode_indices(&output2))
}
/// 对应 Python 的 _preprocess_image
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
let _ = if png_fix && img.color().has_alpha() {
png_rgba_white_preprocess(img)
} else {
img.clone()
};
let h = 64u32;
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
let gray_img = convert_to_grayscale(img);
let resized = resize_image(&gray_img, w, h);
// resized.save("debug_preprocessed.png").unwrap();
// 1. 预处理:转灰度 -> Resize -> 归一化
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
// 使用 tract_ndarray 构造,避免版本冲突
let array =
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
(pixel / 255.0 - 0.5) / 0.5
});
let tensor = Tensor::from(array);
Ok(tensor)
}
/// 对应 Python 的 _inference
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
// let result = self.session.run(tvec!(tensor.into()))?;
let mut result = self
.session
.run(tvec!(tensor.into()))
.context("执行模型推理失败")?;
println!("模型输出原始数据: {:?}", result);
Ok(result.remove(0).into_tensor())
}
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
let shape = raw_tensor.shape();
println!("模型输出shape数据: {:?}", shape);
let datum_type = raw_tensor.datum_type();
println!("模型输出datum_type数据: {:?}", datum_type);
match raw_tensor.datum_type() {
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
DatumType::I64 => {
let view = raw_tensor.to_array_view::<i64>()?;
Ok(view.iter().cloned().collect())
}
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
DatumType::F32 => {
let view = raw_tensor.to_array_view::<f32>()?;
let (steps, classes, data_view) = match shape.len() {
3 => {
if shape[1] == 1 {
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
(shape[0], shape[2], view.into_dyn())
} else if shape[0] == 1 {
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
(shape[1], shape[2], view.into_dyn())
} else {
// 默认取第一个 batch: [Batch, Steps, Classes]
// 使用 slice 对应 Python 的 output[0, :, :]
let sliced = view.slice(s![0, .., ..]);
(shape[1], shape[2], sliced.into_dyn())
} }
} }
2 => { pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度 match &self.runtime {
(shape[0], shape[1], view.into_dyn()) Runtime::Det(s) => s.predict(img),
} Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
};
let array_2d = data_view.to_shape((steps, classes))?;
//
// 对每一行执行 Argmax (寻找概率最大的字符索引)
let indices = array_2d
.outer_iter()
.map(|row| {
row.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(idx, _)| idx as i64)
.unwrap_or(0)
})
.collect();
Ok(indices)
}
_ => Err(anyhow::anyhow!(
"不支持的模型输出数据类型: {:?}",
raw_tensor.datum_type()
)),
}
}
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
println!("indices模型输出原始数据: {:?}", predicted_indices);
use crate::charset::CHARSET_BETA;
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
let mut res = String::new();
let mut prev_idx: i64 = -1;
for &idx in predicted_indices {
// 1. 跳过连续重复的索引
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
if idx != prev_idx && idx != 0 {
if let Ok(u_idx) = usize::try_from(idx) {
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
res.push_str(char_str);
} }
} }
} }
prev_idx = idx;
}
println!("最终识别出的验证码是: {}", res);
res
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

40
src/model_loader.rs Normal file
View File

@@ -0,0 +1,40 @@
use anyhow::Context;
use image::DynamicImage;
use tract_onnx::onnx;
use tract_onnx::prelude::*;
// 关键点:直接使用 tract 重导出的 ndarray
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use std::collections::HashMap;
use tract_onnx::prelude::tract_ndarray::s;
/// Kind of model hosted by a session: OCR, detection, or a user-supplied
/// custom model. (The previous comment described a different type.)
pub enum ModelType {
    Ocr,
    Det,
    Custom,
}
// Unified trait implemented by every loaded model session.
pub trait ModelSession {
    // Which kind of model this session hosts.
    fn get_model_type(&self) -> ModelType;
    // Short human-readable load/status description.
    fn desc(&self) -> String;
}
/// Thin wrapper owning an optimized, runnable tract-onnx inference plan.
pub struct ModelLoader {
    pub session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
}
impl ModelLoader {
    /// Load an ONNX model from disk, optimize it, and wrap the runnable plan.
    pub fn load_model<P>(model_path: P) -> anyhow::Result<Self>
    where
        P: AsRef<std::path::Path>,
    {
        let parsed = onnx()
            .model_for_path(model_path)
            .with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?;
        let session = parsed.into_optimized()?.into_runnable()?;
        Ok(Self { session })
    }
}

251
src/ocr_model.rs Normal file
View File

@@ -0,0 +1,251 @@
use crate::base::ModelArgs;
use crate::image_io::png_rgba_white_preprocess;
use crate::image_processor::{convert_to_grayscale, resize_image};
use crate::model_loader::{ModelLoader, ModelSession, ModelType};
use anyhow::Context;
use image::DynamicImage;
use tract_onnx::prelude::tract_ndarray::s;
use tract_onnx::prelude::{
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
};
/// Custom color-filter range: (low RGB, high RGB).
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));

/// Restriction on which characters the decoder may emit.
#[derive(Debug, Clone)]
pub enum CharsetRange {
    /// No restriction.
    All,
    /// Digits only.
    Digit,
    /// Letters only.
    Letter,
    /// Letters and digits.
    Alphanumeric,
    /// A single string of allowed characters.
    Single(String),
    /// Several strings of allowed characters.
    Multiple(Vec<String>),
    /// An inclusive character range.
    Range(char, char),
    /// An explicit list of allowed characters.
    Custom(Vec<char>),
}

/// Options controlling a single OCR prediction.
///
/// `Default` (derived) yields: no PNG fix, no probabilities, no color filter
/// and no charset restriction — exactly what `PredictArgs::new()` returns.
#[derive(Debug, Clone, Default)]
pub struct PredictArgs {
    /// Repair PNG transparency before recognition.
    pub png_fix: bool,
    /// Also return probability information.
    pub probability: bool,
    /// Color filter: names of colors to keep.
    pub color_filter_colors: Option<Vec<String>>,
    /// Color filter: custom RGB ranges to keep.
    pub color_filter_custom_ranges: Option<Vec<ColorRange>>,
    /// Optional restriction of the output charset.
    pub charset_range: Option<CharsetRange>,
}

impl PredictArgs {
    /// All-defaults argument set.
    pub fn new() -> Self {
        Self::default()
    }
    // Builder-style setters.
    /// Enable/disable the PNG transparency fix.
    pub fn png_fix(mut self, enabled: bool) -> Self {
        self.png_fix = enabled;
        self
    }
    /// Enable/disable probability output.
    pub fn probability(mut self, enabled: bool) -> Self {
        self.probability = enabled;
        self
    }
    /// Keep only the named colors.
    pub fn color_filter_colors(mut self, colors: Vec<String>) -> Self {
        self.color_filter_colors = Some(colors);
        self
    }
    /// Keep only pixels inside the given RGB ranges.
    pub fn color_filter_custom_ranges(mut self, ranges: Vec<ColorRange>) -> Self {
        self.color_filter_custom_ranges = Some(ranges);
        self
    }
    /// Restrict the output charset.
    pub fn charset_range(mut self, range: CharsetRange) -> Self {
        self.charset_range = Some(range);
        self
    }
    // Convenience constructors.
    /// Shorthand for `Self::default()`.
    pub fn quick() -> Self {
        Self::default()
    }
    /// Defaults plus probability output.
    pub fn with_probability() -> Self {
        Self::default().probability(true)
    }
    /// Defaults plus the PNG transparency fix.
    pub fn with_png_fix() -> Self {
        Self::default().png_fix(true)
    }
}
/// CTC-based text recognizer backed by a tract ONNX session.
pub struct Ocr {
    // Optimized, runnable tract model (built once in `new`).
    session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    // charset[i] is the string emitted for model output index i
    // (index 0 is treated as the CTC blank, see `ctc_decode_indices`).
    charset: Vec<String>,
}
impl ModelSession for Ocr {
    /// Identifier for this model kind.
    ///
    /// NOTE(review): still unimplemented — calling this panics via `todo!()`.
    /// TODO: return the appropriate `ModelType` variant for OCR models.
    fn get_model_type(&self) -> ModelType {
        todo!()
    }

    /// Human-readable description reported after a successful load.
    fn desc(&self) -> String {
        "Ocr Model 加载成功".to_string()
    }
}
impl Ocr {
    /// Creates an OCR engine from an ONNX model file and its character set.
    ///
    /// `charset[i]` is the string emitted for model output index `i`
    /// (index 0 is the CTC blank).
    pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
        let session = ModelLoader::load_model(&model_path)?.session;
        Ok(Self { session, charset })
    }

    /// Runs the full recognition pipeline:
    /// preprocess -> inference -> argmax/parse -> CTC decode.
    ///
    /// `png_fix` flattens a transparent PNG onto a white background before
    /// grayscaling, mirroring the Python `png_fix` option.
    pub fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
        let tensor = self.preprocess_image(image, png_fix)?;
        let output = self.inference(tensor)?;
        let indices = self.process_text_output(&output)?;
        Ok(self.ctc_decode_indices(&indices))
    }

    /// Mirrors Python `_preprocess_image`:
    /// optional alpha flattening -> grayscale -> aspect-preserving resize to
    /// height 64 -> normalize to [-1, 1] -> NCHW tensor [1, 1, 64, w].
    fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
        // BUG FIX: the flattened image was previously discarded (`let _ = ...`)
        // and the pipeline kept using the original `img`, so `png_fix` had no
        // effect. Bind the result and use it for the rest of the pipeline.
        let fixed;
        let img = if png_fix && img.color().has_alpha() {
            fixed = png_rgba_white_preprocess(img);
            &fixed
        } else {
            img
        };

        // Fixed height of 64px; width scaled to preserve the aspect ratio.
        let h = 64u32;
        let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;

        let gray_img = convert_to_grayscale(img);
        let resized = resize_image(&gray_img, w, h);

        // Normalize each pixel to [-1, 1]: (p / 255 - 0.5) / 0.5.
        let array =
            tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
                let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
                (pixel / 255.0 - 0.5) / 0.5
            });
        Ok(Tensor::from(array))
    }

    /// Mirrors Python `_inference`: runs the session and returns the first
    /// (and only relevant) output tensor.
    fn inference(&self, tensor: Tensor) -> anyhow::Result<Tensor> {
        let mut result = self
            .session
            .run(tvec!(tensor.into()))
            .context("执行模型推理失败")?;
        Ok(result.remove(0).into_tensor())
    }

    /// Converts the raw output tensor into a sequence of character indices.
    ///
    /// Supports two model families:
    /// * `i64` tensors — the model already performed argmax internally
    ///   (huashi666-style exports);
    /// * `f32` tensors — a per-step probability matrix (sml2h3 original
    ///   exports); argmax is applied per time step here.
    fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result<Vec<i64>> {
        let shape = raw_tensor.shape();
        match raw_tensor.datum_type() {
            // Case 1: index tensor, use it directly.
            DatumType::I64 => {
                let view = raw_tensor.to_array_view::<i64>()?;
                Ok(view.iter().cloned().collect())
            }
            // Case 2: probability matrix; normalize the export layout to
            // (steps, classes) and argmax each row.
            DatumType::F32 => {
                let view = raw_tensor.to_array_view::<f32>()?;
                let (steps, classes, data_view) = match shape.len() {
                    3 => {
                        if shape[1] == 1 {
                            // Layout [Steps, 1, Classes].
                            (shape[0], shape[2], view.into_dyn())
                        } else if shape[0] == 1 {
                            // Layout [1, Steps, Classes].
                            (shape[1], shape[2], view.into_dyn())
                        } else {
                            // Layout [Batch, Steps, Classes]: take batch 0,
                            // like Python's `output[0, :, :]`.
                            let sliced = view.slice(s![0, .., ..]);
                            (shape[1], shape[2], sliced.into_dyn())
                        }
                    }
                    2 => {
                        // Layout [Steps, Classes]: batch already stripped.
                        (shape[0], shape[1], view.into_dyn())
                    }
                    _ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
                };
                let array_2d = data_view.to_shape((steps, classes))?;

                // Per-step argmax (index of the most probable class).
                let indices = array_2d
                    .outer_iter()
                    .map(|row| {
                        row.iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| {
                                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(idx, _)| idx as i64)
                            .unwrap_or(0)
                    })
                    .collect();
                Ok(indices)
            }
            other => Err(anyhow::anyhow!("不支持的模型输出数据类型: {:?}", other)),
        }
    }

    /// CTC greedy decode: collapse consecutive repeats, drop blanks
    /// (index 0), and map the surviving indices through the charset.
    fn ctc_decode_indices(&self, predicted_indices: &[i64]) -> String {
        let mut res = String::new();
        let mut prev_idx: i64 = -1;
        for &idx in predicted_indices {
            // Skip repeats of the previous index and the blank symbol.
            if idx != prev_idx && idx != 0 {
                if let Ok(u_idx) = usize::try_from(idx) {
                    if let Some(char_str) = self.charset.get(u_idx) {
                        res.push_str(char_str);
                    } else {
                        // Guard: the model predicted an index outside the
                        // provided charset; warn instead of panicking.
                        eprintln!("警告: 预测索引 {} 超出字符集范围", u_idx);
                    }
                }
            }
            prev_idx = idx;
        }
        res
    }
}

276
src/slide_model.rs Normal file
View File

@@ -0,0 +1,276 @@
use crate::cv2::{min_max_loc, rgb_to_gray, ndarray_to_luma8, abs_diff};
use crate::image_io::image_to_ndarray;
use anyhow::{Context, Result, anyhow};
use image::{DynamicImage, GenericImageView};
use image::{ImageBuffer, Luma};
use imageproc::distance_transform::Norm;
use imageproc::edges::canny;
use imageproc::morphology::{close, open};
use imageproc::region_labelling::{Connectivity, connected_components};
use imageproc::template_matching::{MatchTemplateMethod, match_template};
use std::cmp::{max, min};
use imageproc::contrast::{threshold, ThresholdType};
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
/// Result of a slider match: the located gap position and match confidence.
pub struct SlideResult {
    // [x, y] of the matched region's centre; duplicates target_x/target_y
    // for callers that prefer an array.
    pub target: [i32; 2],
    // Centre x of the best match, in background-image pixel coordinates.
    pub target_x: i32,
    // Centre y of the best match, in background-image pixel coordinates.
    pub target_y: i32,
    // Matching score: the NCC maximum for template matching, or a fixed 1.0
    // in comparison mode (0.0 when nothing was found).
    pub confidence: f64,
}

/// Stateless slider-captcha matching engine (no model required).
pub struct Slide;
impl Slide {
pub fn new() -> Self {
Self
}
/// 对应 Python: slide_match
pub fn slide_match(
&self,
target_image: &DynamicImage,
background_image: &DynamicImage,
simple_target: bool,
) -> Result<SlideResult> {
let target_array = image_to_ndarray(target_image);
let background_array = image_to_ndarray(background_image);
self.perform_slide_match(target_array.view(), background_array.view(), simple_target)
.map_err(|e| anyhow!("滑块匹配失败: {}", e))
}
/// 对应 Python: slide_comparison
/// 用于比较带坑位的图片与原始背景图,定位差异点
pub fn slide_comparison(
&self,
target_image: &DynamicImage,
background_image: &DynamicImage,
) -> Result<SlideResult> {
// 1. 转换为 ndarray (HWC RGB)
let target_array = image_to_ndarray(target_image);
let background_array = image_to_ndarray(background_image);
// 2. 执行比较逻辑 (对应 _perform_slide_comparison)
self.perform_slide_comparison(target_array.view(), background_array.view())
.map_err(|e| anyhow!("滑块比较执行失败: {}", e))
}
/// 对应 Python: _perform_slide_comparison
pub fn perform_slide_comparison(
&self,
target: ArrayView3<u8>,
background: ArrayView3<u8>,
) -> Result<SlideResult> {
let (h, w, _) = target.dim();
// 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor)
// 使用 OpenCV 标准权重公式0.299R + 0.587G + 0.114B
// let mut diff_buffer = ImageBuffer::new(w as u32, h as u32);
// for y in 0..h {
// for x in 0..w {
// let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32;
// let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32;
// let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32;
//
// let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8;
// diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff]));
// }
// }
// 1. 计算差异数组 (复用 cv2::absdiff)
let diff_array = abs_diff(&target, &background);
// 2. 转换为灰度数组 (复用你的 cv2::rgb_to_gray)
let gray_array = rgb_to_gray(diff_array.view());
// 3. 转为 ImageBuffer 以使用 imageproc 的高级功能
let gray_buffer = ndarray_to_luma8(gray_array.view());
// 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY))
// let mut binary = ImageBuffer::new(w as u32, h as u32);
// for (x, y, pixel) in diff_buffer.enumerate_pixels() {
// let val = if pixel.0[0] > 30 { 255u8 } else { 0u8 };
// binary.put_pixel(x, y, Luma([val]));
// }
let binary = threshold(&gray_buffer, 30, ThresholdType::Binary);
// 3. 形态学操作去噪 (对应 cv2.morphologyEx)
// 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞
// 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点
let norm = Norm::LInf; // 对应 3x3 的矩形内核
let radius = 1u8; // 1 表示 3x3 的范围2 表示 5x5 的范围
let closed = close(&binary, norm, radius);
let cleaned = open(&closed, norm, radius);
// 4. 寻找最大连通区域 (对应 findContours + max area)
// connected_components 会给每个独立的白色区域打上不同的标签 (ID)
let background_label = Luma([0u8]);
let labelled = connected_components(&cleaned, Connectivity::Eight, background_label);
// 统计每个标签出现的频率(即面积)
let mut max_label = 0;
let mut max_area = 0;
let mut areas = std::collections::HashMap::new();
for pixel in labelled.pixels() {
let label = pixel.0[0];
if label == 0 {
continue;
} // 跳过背景
let count = areas.entry(label).or_insert(0);
*count += 1;
if *count > max_area {
max_area = *count;
max_label = label;
}
}
if max_label == 0 {
return Ok(SlideResult {
target: [0, 0],
target_x: 0,
target_y: 0,
confidence: 0.0,
});
}
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
let mut min_x = w as u32;
let mut max_x = 0;
let mut min_y = h as u32;
let mut max_y = 0;
for (x, y, pixel) in labelled.enumerate_pixels() {
if pixel.0[0] == max_label {
min_x = min(min_x, x);
max_x = max(max_x, x);
min_y = min(min_y, y);
max_y = max(max_y, y);
}
}
// 6. 计算中心点
let rect_w = max_x - min_x;
let rect_h = max_y - min_y;
let center_x = (min_x + rect_w / 2) as i32;
let center_y = (min_y + rect_h / 2) as i32;
Ok(SlideResult {
target: [center_x, center_y],
target_x: center_x,
target_y: center_y,
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
})
}
/// 对应 Python: _perform_slide_match
// 在 SlideEngine 中修改此入口进行测试
fn perform_slide_match(
&self,
target: ArrayView3<u8>,
background: ArrayView3<u8>,
simple_target: bool, // 增加这个参数
) -> Result<SlideResult> {
// 1. 统一灰度化
let target_gray = rgb_to_gray(target);
let background_gray = rgb_to_gray(background);
if simple_target {
// 2a. 简单模式:直接在灰度图上匹配
self.simple_template_match(target_gray.view(), background_gray.view())
} else {
// 2b. 复杂模式:先提取边缘,再匹配
self.edge_based_match(target_gray.view(), background_gray.view())
}
}
/// 对应 Python: _simple_template_match
/// 使用 SAD (Sum of Absolute Differences) 算法
/// 核心模板匹配SAD + 有效像素过滤
fn simple_template_match(
&self,
target: ArrayView2<u8>,
background: ArrayView2<u8>,
) -> Result<SlideResult> {
// 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换)
// let (bh, bw) = background.dim();
// 转换逻辑 (假设你已经有方法转回 ImageBuffer)
let t_buf = ndarray_to_luma8(target);
let b_buf = ndarray_to_luma8(background);
// t_buf.save("debug_rust_target.png").unwrap();
// 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED)
// 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
let result = match_template(
&b_buf,
&t_buf,
MatchTemplateMethod::CrossCorrelationNormalized,
);
// save_rust_result(&result, "debug_rust_target2.png");
// 3. 寻找最大值 (等价于 cv2.minMaxLoc)
let (max_val, max_loc) = min_max_loc(&result);
// 4. 计算中心点 (与 Python 逻辑完全一致)
let (th, tw) = target.dim();
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
let center_y = max_loc.1 as i32 + (th as i32 / 2);
// println!("Rust Target Width (tw): {}", tw);
// println!("Rust Best Max Loc X: {}", max_loc.0);
// println!("Rust Final Center X: {}", center_x);
Ok(SlideResult {
target: [center_x, center_y],
target_x: center_x,
target_y: center_y,
confidence: max_val as f64,
})
}
/// 对应 Python: _edge_based_match
/// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match)
pub fn edge_based_match(
&self,
target: ArrayView2<u8>,
background: ArrayView2<u8>,
) -> Result<SlideResult> {
// 1. 将 ndarray 转换为 ImageBuffer
// 注意Canny 和 match_template 需要 ImageBuffer 格式
let t_buf = ndarray_to_luma8(target);
let b_buf = ndarray_to_luma8(background);
// 2. 边缘检测 (完全对齐 cv2.Canny(50, 150))
// 这步会生成黑底白线的二值化边缘图
let target_edges = canny(&t_buf, 50.0, 150.0);
let background_edges = canny(&b_buf, 50.0, 150.0);
// target_edges.save("debug_target_edges.png").ok();
// background_edges.save("debug_bg_edges.png").ok();
// 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
// 在边缘图上计算归一化互相关系数
let result = match_template(
&background_edges,
&target_edges,
MatchTemplateMethod::CrossCorrelationNormalized,
);
// 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc)
let (max_val, max_loc) = min_max_loc(&result);
// 5. 计算中心位置 (对齐 Python 逻辑)
// target_w, target_h 来自输入数组的维度
let (th, tw) = target.dim();
let center_x = max_loc.0 as i32 + (tw as i32 / 2);
let center_y = max_loc.1 as i32 + (th as i32 / 2);
// 打印调试信息,方便与 Python 对比
// println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc);
println!("-Rust Target Width (tw): {}", tw);
println!("-Rust Best Max Loc X: {}", max_loc.0);
println!("-Rust Final Center X: {}", center_x);
Ok(SlideResult {
target: [center_x, center_y],
target_x: center_x,
target_y: center_y,
confidence: max_val as f64,
})
}
}

View File

@@ -1,12 +1,63 @@
use ddddocr_rs::DdddOcr; // 假设你的包名是这个 use std::fs;
use std::path::Path;
use image::Rgb;
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
use ddddocr_rs::slide_model::Slide;
/// Opens an image from disk, wrapping any I/O or decode failure in an anyhow
/// error that names the offending path.
fn load_image<P: AsRef<Path>>(path: P) -> anyhow::Result<image::DynamicImage> {
    // Borrow the concrete &Path once so it stays usable inside the closure.
    let p = path.as_ref();
    image::open(p).map_err(|e| anyhow::anyhow!("无法加载图片 {:?}: {}", p, e))
}
/// Draws every detected bounding box (a ~2px red outline) onto the image
/// decoded from `image_bytes` and writes the annotated copy to `output_path`.
///
/// Coordinates are clamped into the image bounds; boxes with fewer than four
/// entries are skipped instead of panicking (the original indexed `bbox[3]`
/// unconditionally). Takes `&[Vec<i32>]` instead of `&Vec<Vec<i32>>`; callers
/// passing `&vec` still compile via deref coercion.
fn save_debug_image(image_bytes: &[u8], bboxes: &[Vec<i32>], output_path: &str) -> anyhow::Result<()> {
    let dynamic_img = image::load_from_memory(image_bytes)?;
    let mut img = dynamic_img.to_rgb8();
    let (width, height) = img.dimensions();
    let red = Rgb([255u8, 0, 0]);

    for bbox in bboxes {
        // Robustness: ignore malformed boxes rather than panicking on index.
        if bbox.len() < 4 {
            continue;
        }
        // Clamp every corner into the valid pixel range.
        let x1 = bbox[0].max(0).min(width as i32 - 1) as u32;
        let y1 = bbox[1].max(0).min(height as i32 - 1) as u32;
        let x2 = bbox[2].max(0).min(width as i32 - 1) as u32;
        let y2 = bbox[3].max(0).min(height as i32 - 1) as u32;

        // Top and bottom edges (doubled for a ~2px stroke).
        for x in x1..=x2 {
            img.put_pixel(x, y1, red);
            img.put_pixel(x, y2, red);
            if y1 + 1 < height {
                img.put_pixel(x, y1 + 1, red);
            }
            if y2 > 1 {
                img.put_pixel(x, y2 - 1, red);
            }
        }
        // Left and right edges (doubled for a ~2px stroke).
        for y in y1..=y2 {
            img.put_pixel(x1, y, red);
            img.put_pixel(x2, y, red);
            if x1 + 1 < width {
                img.put_pixel(x1 + 1, y, red);
            }
            if x2 > 1 {
                img.put_pixel(x2 - 1, y, red);
            }
        }
    }
    img.save(output_path)?;
    Ok(())
}
#[test] #[test]
fn test_full_classification() { fn test_full_classification() {
// 1. 初始化模型 // 1. 初始化模型
let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败"); let ocr = DdddOcrBuilder::new().build().expect("模型加载失败");
// 2. 加载测试图片 // 2. 加载测试图片
let img = image::open("samples/code3.png").expect("测试图片不存在"); let img = image::open("samples/code2.png").expect("测试图片不存在");
// 3. 执行识别 // 3. 执行识别
let result = ocr.classification(&img).expect("识别过程出错"); let result = ocr.classification(&img).expect("识别过程出错");
@@ -14,3 +65,89 @@ fn test_full_classification() {
println!("识别结果: {}", result); println!("识别结果: {}", result);
assert!(!result.is_empty()); assert!(!result.is_empty());
} }
/// End-to-end detection smoke test: build the det model, run it on a sample
/// image, and write an annotated copy next to the input.
#[test]
fn test_det_load() -> anyhow::Result<()> {
    let det = DdddOcrBuilder::new().det().build()?;

    let image_path = "samples/det1.png";
    let image_bytes = fs::read(image_path)
        .map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?;
    println!("图片读取成功,字节大小: {}", image_bytes.len());

    let bboxes = det.detection(&image_bytes)?;
    // Fixed: was `println!(":?{}", det)` — a mistyped format string that
    // printed a literal ":?" prefix.
    println!("{}", det);
    println!("检测到的目标数量: {}", bboxes.len());

    if bboxes.is_empty() {
        println!("未检测到任何目标。");
    } else {
        save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?;
        for (i, bbox) in bboxes.iter().enumerate() {
            println!("目标 [{}]: x1={}, y1={}, x2={}, y2={}", i, bbox[0], bbox[1], bbox[2], bbox[3]);
        }
    }
    Ok(())
}
/// Integration test for simple-mode slider matching against fixture images.
#[test]
fn test_real_slide_match() {
    let engine = Slide::new();

    // Sample images shipped in the repo's samples/ folder.
    let target_img = load_image("samples/hua.png").expect("请确保 samples/hua.png 存在");
    let bg_img = load_image("samples/huatu.png").expect("请确保 samples/huatu.png 存在");

    // Simple (grayscale) mode; for pieces with strong shadow edges prefer
    // passing `false` (edge mode). Time the call for a rough perf signal.
    let start = std::time::Instant::now();
    let result = engine
        .slide_match(&target_img, &bg_img, true)
        .expect("Slide match 执行失败");
    let duration = start.elapsed();

    println!("-------------------------------------------");
    println!("滑块匹配测试结果:");
    println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y);
    println!("置信度: {:.4}", result.confidence);
    println!("耗时: {:?}", duration);
    println!("-------------------------------------------");

    // Known-good coordinates for these fixture images.
    assert_eq!(result.target_x, 237);
    assert_eq!(result.target_y, 77);
    assert!(result.confidence > 0.0);
}
#[test]
fn test_real_slide_comparison() {
let engine = Slide::new();
// 1. 加载你准备好的测试图
// 假设图片放在项目根目录下的 assets 文件夹
let target_img = load_image("samples/ken.jpg")
.expect("请确保 samples/ken.jpg 存在");
let bg_img = load_image("samples/kenyuan.jpg")
.expect("请确保 samples/kenyuan.jpg 存在");
// 2. 执行匹配
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
let start = std::time::Instant::now();
let result = engine.slide_comparison(&target_img, &bg_img)
.expect("Slide match 执行失败");
let duration = start.elapsed();
// 3. 打印结果
println!("-------------------------------------------");
println!("滑块匹配测试结果:");
println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y);
println!("置信度: {:.4}", result.confidence);
println!("耗时: {:?}", duration);
println!("-------------------------------------------");
// 验证基本逻辑:坐标不应为 0 (除非匹配失败)
assert_eq!(result.target_x, 171);
assert_eq!(result.target_y, 90);
assert!(result.confidence > 0.0);
}