Compare commits
9 Commits
feature-v0
...
feature-v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 0923d92150 | |||
| 0df9022411 | |||
| a51147c888 | |||
| e8b365dced | |||
| f0db625bd1 | |||
| 21bd1c93bf | |||
| 1a329ca273 | |||
| 8fcfa2096e | |||
| cfeb68ad04 |
@@ -9,3 +9,4 @@ tract-onnx = { version = "0.21.10" }
|
|||||||
anyhow = "1.0.102"
|
anyhow = "1.0.102"
|
||||||
image = "0.25.10"
|
image = "0.25.10"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
|
imageproc = { version = "0.26.2", default-features = true }
|
||||||
35
README.md
@@ -2,8 +2,43 @@
|
|||||||
|
|
||||||
带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。
|
带带弟弟 OCR (ddddocr) 的 Rust 移植版。高性能、低占用,支持多种验证码识别与检测。
|
||||||
|
|
||||||
|
🧩 滑块识别算法核心知识点总结
|
||||||
|
本项目实现了两种核心匹配模式,其底层逻辑与 OpenCV 的对齐情况如下:
|
||||||
|
|
||||||
|
1. 匹配模式对比 (Match Modes)
|
||||||
|
|**模式**|**算法原理**|**适用场景**|**备注**|
|
||||||
|
|---|---|---|---|
|
||||||
|
|**边缘模式** (Edge-based)|基于 **Canny 边缘检测** 提取轮廓后再进行匹配。|**推荐方案**
|
||||||
|
。适用于绝大多数拼图滑块。|天然免疫拼图周边的透明/黑色留白干扰,坐标最精准。|
|
||||||
|
|**简单模式** (Simple/Gray)|直接基于 **灰度像素值** 进行归一化互相关计算。|适用于无明显边缘、靠颜色差异识别的场景。|对背景和透明边框敏感,可能存在重心偏移。|
|
||||||
|
|
||||||
|
2. 数学公式差异 (NCC vs. CCOEFF)
|
||||||
|
在简单模式下,本项目采用的是 归一化互相关 (NCC),对应 OpenCV 中的 TM_CCORR_NORMED。
|
||||||
|
|
||||||
|
逻辑对齐:Rust 的 match_template 结果与 Python cv2.TM_CCORR_NORMED 完全一致。
|
||||||
|
|
||||||
|
关于偏移:若拼图原始图片(Target)四周包含大量的透明留白:
|
||||||
|
|
||||||
|
CCORR (本项目):会将留白视为图像的一部分,计算出的是整张图片框的中心。
|
||||||
|
|
||||||
|
CCOEFF (OpenCV 默认):会自动进行“均值中心化”,在一定程度上能削弱留白的影响。
|
||||||
|
|
||||||
|
最佳实践:若发现坐标有固定位移,建议优先切换至 边缘模式,或对滑块图进行 Bounding Box 裁剪 后再匹配。
|
||||||
|
|
||||||
|
3. 图像预处理一致性
|
||||||
|
|
||||||
|
为确保识别精度,本项目在 Rust 中完美复刻了 Python OpenCV 的预处理链路:
|
||||||
|
|
||||||
|
- **灰度化权重**:采用 OpenCV 标准感光公式 $0.299R + 0.587G + 0.114B$。
|
||||||
|
|
||||||
|
- **Alpha 处理**:在将 PNG 转为 RGB 时,自动将透明区域填充为黑色,确保与 PIL (Python Imaging Library) 行为一致。
|
||||||
|
|
||||||
|
- **坐标定义**:所有返回坐标均为匹配区域的 **几何中心点** $(x + w/2, y + h/2)$。
|
||||||
|
|
||||||
|
💡 开发者建议:
|
||||||
|
|
||||||
|
如果识别结果在 $X$ 轴上有大约 $10px$ 左右的固定误差,通常是因为滑块原图自带了透明边距(留白)。此时请确保
|
||||||
|
simple_target=false。该模式会通过 Canny 边缘检测 提取轮廓特征,能自动锁定拼图实体并忽略背景留白的像素干扰。
|
||||||
鸣谢 (Credits)
|
鸣谢 (Credits)
|
||||||
|
|
||||||
- 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。
|
- 本项目是 [ddddocr](https://github.com/sml2h3/ddddocr) 的 Rust 移植版本,原作者为 sml2h3。衷心感谢原作者对 OCR 社区做出的杰出贡献。
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
fn main() {
|
fn main() {
|
||||||
let ocr = ddddocr_rs::DdddOcr::new("model/common.onnx").unwrap();
|
let ocr = ddddocr_rs::DdddOcrBuilder::new().build().unwrap();
|
||||||
let img = image::open("samples/code3.png").unwrap();
|
let img = image::open("samples/code3.png").unwrap();
|
||||||
println!("Result: {}", ocr.classification(&img).unwrap());
|
println!("Result: {}", ocr.classification(&img).unwrap());
|
||||||
}
|
}
|
||||||
BIN
samples/det1.png
Normal file
|
After Width: | Height: | Size: 70 KiB |
BIN
samples/det2.png
Normal file
|
After Width: | Height: | Size: 95 KiB |
BIN
samples/det3.jpg
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
samples/hua.png
Normal file
|
After Width: | Height: | Size: 4.1 KiB |
BIN
samples/huatu.png
Normal file
|
After Width: | Height: | Size: 94 KiB |
BIN
samples/ken.jpg
Normal file
|
After Width: | Height: | Size: 8.3 KiB |
BIN
samples/kenyuan.jpg
Normal file
|
After Width: | Height: | Size: 8.0 KiB |
@@ -514,3 +514,6 @@ pub const CHARSET_BETA: &[&str] = &[
|
|||||||
"谬", "溝", "言", "哽", "婿", "猿", "跗", "獴", "俜", "呙", "弗", "凿", "窭", "铌", "友", "唉",
|
"谬", "溝", "言", "哽", "婿", "猿", "跗", "獴", "俜", "呙", "弗", "凿", "窭", "铌", "友", "唉",
|
||||||
"怫", "荘",
|
"怫", "荘",
|
||||||
];
|
];
|
||||||
|
pub fn get_default_charset() -> Vec<String> {
|
||||||
|
CHARSET_BETA.iter().map(|&s| s.to_string()).collect()
|
||||||
|
}
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
use anyhow::{Context, Result};
|
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
|
||||||
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use tract_onnx::prelude::tract_ndarray::Array3;
|
|
||||||
|
|
||||||
/// 定义支持的输入类型枚举
|
|
||||||
pub enum ImageInput {
|
|
||||||
Bytes(Vec<u8>),
|
|
||||||
Array(Array3<u8>),
|
|
||||||
Path(PathBuf),
|
|
||||||
Base64(String),
|
|
||||||
DynamicImage(DynamicImage),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 模拟 Python 的 load_image_from_input
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> {
|
|
||||||
match input {
|
|
||||||
ImageInput::DynamicImage(img) => Ok(img),
|
|
||||||
_ => todo!("后续补充"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 对应 Python 的 png_rgba_black_preprocess
|
|
||||||
/// 将带有透明通道的图片转换为白色背景的 RGB 图片
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
|
|
||||||
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回
|
|
||||||
if !img.color().has_alpha() {
|
|
||||||
return img.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
let (width, height) = img.dimensions();
|
|
||||||
|
|
||||||
// 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255)
|
|
||||||
let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8]));
|
|
||||||
|
|
||||||
// 3. 获取原图的 RGBA 视图
|
|
||||||
let rgba_img = img.to_rgba8();
|
|
||||||
|
|
||||||
// 4. 遍历像素并手动进行 Alpha 混合
|
|
||||||
// 对应 Python 的 image.paste(img, ..., mask=img)
|
|
||||||
for (x, y, pixel) in rgba_img.enumerate_pixels() {
|
|
||||||
let alpha = pixel[3] as f32 / 255.0;
|
|
||||||
|
|
||||||
if alpha >= 1.0 {
|
|
||||||
// 完全不透明,直接覆盖
|
|
||||||
background.put_pixel(x, y, Rgb([pixel[0], pixel[1], pixel[2]]));
|
|
||||||
} else if alpha > 0.0 {
|
|
||||||
// 半透明,执行 Alpha 混合公式: (src * alpha) + (dst * (1 - alpha))
|
|
||||||
let bg_pixel = background.get_pixel(x, y);
|
|
||||||
let r = (pixel[0] as f32 * alpha + bg_pixel[0] as f32 * (1.0 - alpha)) as u8;
|
|
||||||
let g = (pixel[1] as f32 * alpha + bg_pixel[1] as f32 * (1.0 - alpha)) as u8;
|
|
||||||
let b = (pixel[2] as f32 * alpha + bg_pixel[2] as f32 * (1.0 - alpha)) as u8;
|
|
||||||
background.put_pixel(x, y, Rgb([r, g, b]));
|
|
||||||
}
|
|
||||||
// alpha == 0 的情况不需要处理,因为背景已经是白色了
|
|
||||||
}
|
|
||||||
|
|
||||||
DynamicImage::ImageRgb8(background)
|
|
||||||
}
|
|
||||||
246
src/lib.rs
@@ -1,173 +1,111 @@
|
|||||||
mod charset;
|
mod charset;
|
||||||
mod image_io;
|
|
||||||
mod image_processor;
|
|
||||||
mod model;
|
|
||||||
mod utils;
|
|
||||||
|
|
||||||
use crate::image_io::png_rgba_white_preprocess;
|
pub mod models;
|
||||||
use crate::image_processor::{convert_to_grayscale, resize_image};
|
pub mod utils;
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use image::{DynamicImage, imageops::FilterType};
|
use anyhow::Result;
|
||||||
use tract_onnx::prelude::*;
|
use image::DynamicImage;
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
// 关键点:直接使用 tract 重导出的 ndarray
|
// 关键点:直接使用 tract 重导出的 ndarray
|
||||||
use tract_onnx::prelude::tract_ndarray::s;
|
use crate::charset::get_default_charset;
|
||||||
|
use models::det::Det;
|
||||||
|
use models::loader::ModelSession;
|
||||||
|
use models::ocr::Ocr;
|
||||||
|
pub enum ModelSpec {
|
||||||
|
/// 默认 OCR (使用内置路径)
|
||||||
|
OcrModel,
|
||||||
|
DetModel,
|
||||||
|
/// 自定义 OCR (路径由用户提供)
|
||||||
|
CustomOcrModel {
|
||||||
|
path: String,
|
||||||
|
charset: Vec<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
impl ModelSpec {
|
||||||
|
// 将默认路径定义为内部关联常量
|
||||||
|
const DEFAULT_OCR_PATH: &'static str = "models/common.onnx";
|
||||||
|
const DEFAULT_DET_PATH: &'static str = "models/common_det.onnx";
|
||||||
|
}
|
||||||
|
pub enum Runtime {
|
||||||
|
Ocr(Ocr),
|
||||||
|
Det(Det),
|
||||||
|
}
|
||||||
|
impl Runtime {
|
||||||
|
// 统一获取描述的方法
|
||||||
|
pub fn desc(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Runtime::Ocr(s) => s.desc(), // 调用 Ocr 结构体的方法
|
||||||
|
Runtime::Det(s) => s.desc(), // 调用 Det 结构体的方法
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub struct DdddOcrBuilder {
|
||||||
|
mode: ModelSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DdddOcrBuilder {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
mode: ModelSpec::OcrModel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 切换为检测模式
|
||||||
|
pub fn det(mut self) -> Self {
|
||||||
|
self.mode = ModelSpec::DetModel;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置自定义 OCR 路径
|
||||||
|
pub fn custom_ocr(mut self, path: String, charset: Vec<String>) -> Self {
|
||||||
|
// 直接重写枚举,替换掉之前的 Ocr 或 Det
|
||||||
|
self.mode = ModelSpec::CustomOcrModel { path, charset };
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 核心初始化逻辑
|
||||||
|
pub fn build(self) -> Result<DdddOcr> {
|
||||||
|
let runtime = match self.mode {
|
||||||
|
ModelSpec::OcrModel => Runtime::Ocr(Ocr::new(
|
||||||
|
ModelSpec::DEFAULT_OCR_PATH.into(),
|
||||||
|
get_default_charset(),
|
||||||
|
)?),
|
||||||
|
ModelSpec::DetModel => Runtime::Det(Det::new(ModelSpec::DEFAULT_DET_PATH.into())?),
|
||||||
|
ModelSpec::CustomOcrModel { path, charset } => Runtime::Ocr(Ocr::new(path, charset)?),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(DdddOcr { runtime })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct DdddOcr {
|
pub struct DdddOcr {
|
||||||
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
runtime: Runtime,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DdddOcr {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "DdddOcr(session: {})", self.runtime.desc())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DdddOcr {
|
impl DdddOcr {
|
||||||
pub fn new<P>(model_path: P) -> Result<Self>
|
|
||||||
where
|
|
||||||
P: AsRef<std::path::Path>,
|
|
||||||
{
|
|
||||||
let session = onnx()
|
|
||||||
.model_for_path(model_path)
|
|
||||||
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
|
||||||
.into_optimized()?
|
|
||||||
.into_runnable()?;
|
|
||||||
Ok(Self { session })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
pub fn classification(&self, img: &DynamicImage) -> Result<String> {
|
||||||
let tensor = self.preprocess_image(img, false)?;
|
match &self.runtime {
|
||||||
|
Runtime::Ocr(s) => s.predict(img, false),
|
||||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
Runtime::Det(_) => Err(anyhow::anyhow!("当前模型是检测模型,无法执行 OCR")),
|
||||||
// 3. 解析结果
|
|
||||||
// let output = result[0].to_array_view::<i64>()?;
|
|
||||||
let output = self.inference(tensor)?;
|
|
||||||
let output2 = self.process_text_output(&output)?;
|
|
||||||
Ok(Self::ctc_decode_indices(&output2))
|
|
||||||
}
|
|
||||||
/// 对应 Python 的 _preprocess_image
|
|
||||||
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
|
||||||
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> Result<Tensor> {
|
|
||||||
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
|
||||||
let _ = if png_fix && img.color().has_alpha() {
|
|
||||||
png_rgba_white_preprocess(img)
|
|
||||||
} else {
|
|
||||||
img.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let h = 64u32;
|
|
||||||
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
|
||||||
let gray_img = convert_to_grayscale(img);
|
|
||||||
let resized = resize_image(&gray_img, w, h);
|
|
||||||
// resized.save("debug_preprocessed.png").unwrap();
|
|
||||||
// 1. 预处理:转灰度 -> Resize -> 归一化
|
|
||||||
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
|
||||||
|
|
||||||
// 使用 tract_ndarray 构造,避免版本冲突
|
|
||||||
let array =
|
|
||||||
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
|
||||||
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
|
||||||
(pixel / 255.0 - 0.5) / 0.5
|
|
||||||
});
|
|
||||||
|
|
||||||
let tensor = Tensor::from(array);
|
|
||||||
|
|
||||||
Ok(tensor)
|
|
||||||
}
|
|
||||||
/// 对应 Python 的 _inference
|
|
||||||
fn inference(&self, tensor: Tensor) -> Result<Tensor> {
|
|
||||||
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
|
||||||
// let result = self.session.run(tvec!(tensor.into()))?;
|
|
||||||
let mut result = self
|
|
||||||
.session
|
|
||||||
.run(tvec!(tensor.into()))
|
|
||||||
.context("执行模型推理失败")?;
|
|
||||||
println!("模型输出原始数据: {:?}", result);
|
|
||||||
Ok(result.remove(0).into_tensor())
|
|
||||||
}
|
|
||||||
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
|
||||||
fn process_text_output(&self, raw_tensor: &Tensor) -> Result<Vec<i64>> {
|
|
||||||
let shape = raw_tensor.shape();
|
|
||||||
println!("模型输出shape数据: {:?}", shape);
|
|
||||||
let datum_type = raw_tensor.datum_type();
|
|
||||||
println!("模型输出datum_type数据: {:?}", datum_type);
|
|
||||||
|
|
||||||
match raw_tensor.datum_type() {
|
|
||||||
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
|
||||||
DatumType::I64 => {
|
|
||||||
let view = raw_tensor.to_array_view::<i64>()?;
|
|
||||||
Ok(view.iter().cloned().collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
|
||||||
DatumType::F32 => {
|
|
||||||
let view = raw_tensor.to_array_view::<f32>()?;
|
|
||||||
let (steps, classes, data_view) = match shape.len() {
|
|
||||||
3 => {
|
|
||||||
if shape[1] == 1 {
|
|
||||||
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
|
||||||
(shape[0], shape[2], view.into_dyn())
|
|
||||||
} else if shape[0] == 1 {
|
|
||||||
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
|
||||||
(shape[1], shape[2], view.into_dyn())
|
|
||||||
} else {
|
|
||||||
// 默认取第一个 batch: [Batch, Steps, Classes]
|
|
||||||
// 使用 slice 对应 Python 的 output[0, :, :]
|
|
||||||
let sliced = view.slice(s![0, .., ..]);
|
|
||||||
(shape[1], shape[2], sliced.into_dyn())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
2 => {
|
pub fn detection(&self, img: &[u8]) -> Result<Vec<Vec<i32>>> {
|
||||||
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
match &self.runtime {
|
||||||
(shape[0], shape[1], view.into_dyn())
|
Runtime::Det(s) => s.predict(img),
|
||||||
|
Runtime::Ocr(_) => Err(anyhow::anyhow!("当前模型是 OCR 模型,无法执行检测")),
|
||||||
}
|
}
|
||||||
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
|
||||||
};
|
|
||||||
let array_2d = data_view.to_shape((steps, classes))?;
|
|
||||||
//
|
|
||||||
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
|
||||||
let indices = array_2d
|
|
||||||
.outer_iter()
|
|
||||||
.map(|row| {
|
|
||||||
row.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|(_, a), (_, b)| {
|
|
||||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
})
|
|
||||||
.map(|(idx, _)| idx as i64)
|
|
||||||
.unwrap_or(0)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Ok(indices)
|
|
||||||
}
|
|
||||||
_ => Err(anyhow::anyhow!(
|
|
||||||
"不支持的模型输出数据类型: {:?}",
|
|
||||||
raw_tensor.datum_type()
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn ctc_decode_indices(predicted_indices: &[i64]) -> String {
|
|
||||||
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
|
||||||
|
|
||||||
use crate::charset::CHARSET_BETA;
|
|
||||||
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
|
||||||
let mut res = String::new();
|
|
||||||
let mut prev_idx: i64 = -1;
|
|
||||||
|
|
||||||
for &idx in predicted_indices {
|
|
||||||
// 1. 跳过连续重复的索引
|
|
||||||
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
|
||||||
if idx != prev_idx && idx != 0 {
|
|
||||||
if let Ok(u_idx) = usize::try_from(idx) {
|
|
||||||
if let Some(&char_str) = CHARSET_BETA.get(u_idx) {
|
|
||||||
res.push_str(char_str);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prev_idx = idx;
|
|
||||||
}
|
|
||||||
println!("最终识别出的验证码是: {}", res);
|
|
||||||
res
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ctc_decode_indices() {
|
fn test_ctc_decode_indices() {
|
||||||
// 模拟一个 DdddOcr 实例(如果 decode 不依赖 session,可以设为相关函数)
|
// 模拟一个 DdddOcr 实例(如果 decode 不依赖 session,可以设为相关函数)
|
||||||
|
|||||||
40
src/models/base.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
pub trait ModelArgs {
|
||||||
|
// 获取模型路径
|
||||||
|
fn model_path(&self) -> &str;
|
||||||
|
|
||||||
|
// 获取字符集(由于 Det 没有,所以返回 Option)
|
||||||
|
fn charset(&self) -> Option<&str>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HasCharset {
|
||||||
|
pub charset: String,
|
||||||
|
} // 给 Ocr 和 Custom 用
|
||||||
|
pub struct NoCharset; // 给 Det 用
|
||||||
|
|
||||||
|
pub struct Model<T> {
|
||||||
|
pub path: String,
|
||||||
|
pub metadata: T,
|
||||||
|
}
|
||||||
|
// 针对有字符集的模型 (Ocr / Custom)
|
||||||
|
impl ModelArgs for Model<HasCharset> {
|
||||||
|
fn model_path(&self) -> &str {
|
||||||
|
&self.path
|
||||||
|
}
|
||||||
|
fn charset(&self) -> Option<&str> {
|
||||||
|
Some(&self.metadata.charset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 针对没有字符集的模型 (Det)
|
||||||
|
impl ModelArgs for Model<NoCharset> {
|
||||||
|
fn model_path(&self) -> &str {
|
||||||
|
&self.path
|
||||||
|
}
|
||||||
|
fn charset(&self) -> Option<&str> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type OcrModel = Model<HasCharset>;
|
||||||
|
pub type DetModel = Model<NoCharset>;
|
||||||
|
pub type CustomModel = Model<HasCharset>; // Ocr 和 Custom 逻辑一致,可以复用
|
||||||
263
src/models/det.rs
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
use crate::models::loader::{ModelLoader, ModelSession, ModelType};
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use image::{DynamicImage, GenericImageView, imageops::FilterType};
|
||||||
|
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, Array4, Axis, prelude::*, s};
|
||||||
|
use tract_onnx::prelude::{Graph, RunnableModel, Tensor, TypedFact, TypedOp, tvec};
|
||||||
|
|
||||||
|
pub struct Det {
|
||||||
|
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||||
|
}
|
||||||
|
impl ModelSession for Det {
|
||||||
|
fn get_model_type(&self) -> ModelType {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
fn desc(&self) -> String {
|
||||||
|
"Detection Model 加载成功".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Det {
|
||||||
|
pub fn new(model_path: String) -> Result<Self, anyhow::Error> {
|
||||||
|
let session = ModelLoader::load_model(&model_path)?.session;
|
||||||
|
Ok(Self { session })
|
||||||
|
}
|
||||||
|
pub fn predict(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
|
||||||
|
// Rust 中通常在调用层处理文件/PIL转换,这里直接进入核心逻辑
|
||||||
|
self.get_bbox(image_bytes)
|
||||||
|
}
|
||||||
|
/// 2. preproc: 纯 Rust 实现 (替代 OpenCV)
|
||||||
|
fn preproc(&self, img: &DynamicImage, input_size: (u32, u32)) -> Result<(Tensor, f32)> {
|
||||||
|
let (target_h, target_w) = input_size;
|
||||||
|
let (img_w, img_h) = img.dimensions();
|
||||||
|
|
||||||
|
// 计算缩放比例 (Letterbox)
|
||||||
|
let r = (target_h as f32 / img_h as f32).min(target_w as f32 / img_w as f32);
|
||||||
|
let new_h = (img_h as f32 * r) as u32;
|
||||||
|
let new_w = (img_w as f32 * r) as u32;
|
||||||
|
|
||||||
|
// Resize 图像
|
||||||
|
let resized = img.resize_exact(new_w, new_h, FilterType::Triangle);
|
||||||
|
// 2. 关键:将 DynamicImage 显式转换为 RgbImage (Rgb<u8>)
|
||||||
|
let resized_rgb = resized.to_rgb8();
|
||||||
|
// 创建 114 灰度填充的背景
|
||||||
|
let mut base_img =
|
||||||
|
image::ImageBuffer::from_pixel(target_w, target_h, image::Rgb([114u8, 114, 114]));
|
||||||
|
|
||||||
|
// 将 resize 后的图像覆盖到左上角 (类似于原始代码中的 padded_img[:h, :w])
|
||||||
|
image::imageops::overlay(&mut base_img, &resized_rgb, 0, 0);
|
||||||
|
|
||||||
|
// 构造 NCHW Tensor
|
||||||
|
let mut array = Array4::<f32>::zeros((1, 3, target_h as usize, target_w as usize));
|
||||||
|
for (x, y, pixel) in base_img.enumerate_pixels() {
|
||||||
|
let x = x as usize;
|
||||||
|
let y = y as usize;
|
||||||
|
// 核心对标 Python 的 BGR 逻辑:
|
||||||
|
// pixel[0] 是 R, pixel[1] 是 G, pixel[2] 是 B
|
||||||
|
// 如果模型需要 BGR:
|
||||||
|
// array[[0, 0, y as usize, x as usize]] = pixel[0] as f32;
|
||||||
|
// array[[0, 1, y as usize, x as usize]] = pixel[1] as f32;
|
||||||
|
// array[[0, 2, y as usize, x as usize]] = pixel[2] as f32;
|
||||||
|
array[[0, 0, y, x]] = pixel[2] as f32; // B
|
||||||
|
array[[0, 1, y, x]] = pixel[1] as f32; // G
|
||||||
|
array[[0, 2, y, x]] = pixel[0] as f32; // R
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((array.into(), r))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 3. demo_postprocess (逻辑与 Python 一致)
|
||||||
|
fn demo_postprocess(&self, mut outputs: Array3<f32>, img_size: (i32, i32)) -> Array3<f32> {
|
||||||
|
let strides = [8, 16, 32];
|
||||||
|
|
||||||
|
// 遍历每一个 Batch(支持动态 Batch 推理)
|
||||||
|
for mut batch in outputs.axis_iter_mut(Axis(0)) {
|
||||||
|
let mut offset = 0;
|
||||||
|
|
||||||
|
for &stride in &strides {
|
||||||
|
// 计算当前特征图的尺寸
|
||||||
|
let h = img_size.0 / stride;
|
||||||
|
let w = img_size.1 / stride;
|
||||||
|
let f_stride = stride as f32;
|
||||||
|
|
||||||
|
for y in 0..h {
|
||||||
|
for x in 0..w {
|
||||||
|
// 计算当前格子在 25200 个锚点中的线性索引
|
||||||
|
let idx = offset + (y * w + x) as usize;
|
||||||
|
// 1. 还原中心点坐标 (cx, cy)
|
||||||
|
// 公式: (output + grid_offset) * stride
|
||||||
|
batch[[idx, 0]] = (batch[[idx, 0]] + x as f32) * f_stride;
|
||||||
|
batch[[idx, 1]] = (batch[[idx, 1]] + y as f32) * f_stride;
|
||||||
|
|
||||||
|
// 2. 还原宽高 (w, h)
|
||||||
|
// 公式: exp(output) * stride
|
||||||
|
batch[[idx, 2]] = batch[[idx, 2]].exp() * f_stride;
|
||||||
|
batch[[idx, 3]] = batch[[idx, 3]].exp() * f_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 移动到下一个步长的起始位置
|
||||||
|
offset += (h * w) as usize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 4. nms
|
||||||
|
fn nms(&self, boxes: &Array2<f32>, scores: &Array1<f32>, nms_thr: f32) -> Vec<usize> {
|
||||||
|
let mut keep = Vec::new();
|
||||||
|
let x1 = boxes.column(0);
|
||||||
|
let y1 = boxes.column(1);
|
||||||
|
let x2 = boxes.column(2);
|
||||||
|
let y2 = boxes.column(3);
|
||||||
|
// 在每一项前加上 &,并确保括号内的计算顺序
|
||||||
|
// 注意:ndarray 的 View 运算需要 &view1 - &view2
|
||||||
|
let areas = (&x2 - &x1 + 1.0) * (&y2 - &y1 + 1.0);
|
||||||
|
|
||||||
|
// 初始排序索引
|
||||||
|
let mut v: Vec<usize> = (0..scores.len()).collect();
|
||||||
|
v.sort_unstable_by(|&i, &j| {
|
||||||
|
scores[j]
|
||||||
|
.partial_cmp(&scores[i])
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
// 我们不使用 v.remove(0),而是直接通过索引池操作
|
||||||
|
let mut active_indices = v;
|
||||||
|
|
||||||
|
while !active_indices.is_empty() {
|
||||||
|
// 取出当前池子中得分最高的框(即第一个元素)
|
||||||
|
let i = active_indices[0];
|
||||||
|
keep.push(i);
|
||||||
|
|
||||||
|
// 如果池子里只剩一个了,直接结束
|
||||||
|
if active_indices.len() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 核心逻辑:使用 retain 一次性过滤掉:
|
||||||
|
// (a) 当前框自己 (idx == i)
|
||||||
|
// (b) 与当前框重叠度过高的框 (iou > nms_thr)
|
||||||
|
active_indices.retain(|&idx| {
|
||||||
|
// 如果是当前正在处理的框,不保留(因为它已经进入 keep 了)
|
||||||
|
if idx == i {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算 IoU
|
||||||
|
let xx1 = x1[i].max(x1[idx]);
|
||||||
|
let yy1 = y1[i].max(y1[idx]);
|
||||||
|
let xx2 = x2[i].min(x2[idx]);
|
||||||
|
let yy2 = y2[i].min(y2[idx]);
|
||||||
|
|
||||||
|
let w = (xx2 - xx1 + 1.0).max(0.0);
|
||||||
|
let h = (yy2 - yy1 + 1.0).max(0.0);
|
||||||
|
let inter = w * h;
|
||||||
|
|
||||||
|
let iou = inter / (areas[i] + areas[idx] - inter);
|
||||||
|
|
||||||
|
// 只保留 IoU 小于阈值的框
|
||||||
|
iou <= nms_thr
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
keep
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 5. multiclass_nms
|
||||||
|
pub fn multiclass_nms(
|
||||||
|
&self,
|
||||||
|
boxes: &Array2<f32>, // [25200, 4] -> xyxy 格式
|
||||||
|
scores: &Array2<f32>, // [25200, 80] -> 已经乘以 objectness 的得分
|
||||||
|
nms_thr: f32,
|
||||||
|
score_thr: f32,
|
||||||
|
) -> Vec<Vec<f32>> {
|
||||||
|
let mut candidates = Vec::new();
|
||||||
|
|
||||||
|
// 1. 筛选高分框 (单次遍历完成 Argmax 和 Threshold 过滤)
|
||||||
|
for i in 0..scores.nrows() {
|
||||||
|
let row = scores.row(i);
|
||||||
|
|
||||||
|
// 找到当前行(即当前锚点)得分最高的类别
|
||||||
|
let mut max_score = 0.0;
|
||||||
|
let mut cls_id = 0;
|
||||||
|
for (j, &s) in row.iter().enumerate() {
|
||||||
|
if s > max_score {
|
||||||
|
max_score = s;
|
||||||
|
cls_id = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 仅保留超过阈值的候选框
|
||||||
|
if max_score > score_thr {
|
||||||
|
// 暂时存储索引和元数据,避免频繁创建大数组
|
||||||
|
candidates.push((i, max_score, cls_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if candidates.is_empty() {
|
||||||
|
return vec![];
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 准备 NMS 输入
|
||||||
|
// 构造 NMS 需要的子集数组
|
||||||
|
let mut b_subset = Array2::<f32>::zeros((candidates.len(), 4));
|
||||||
|
let mut s_subset = Array1::<f32>::zeros(candidates.len());
|
||||||
|
|
||||||
|
for (new_idx, &(orig_idx, score, _)) in candidates.iter().enumerate() {
|
||||||
|
b_subset.row_mut(new_idx).assign(&boxes.row(orig_idx));
|
||||||
|
s_subset[new_idx] = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 执行 NMS (返回保留下来的子集索引)
|
||||||
|
let keep = self.nms(&b_subset, &s_subset, nms_thr);
|
||||||
|
|
||||||
|
// 4. 组装最终结果 [x1, y1, x2, y2, score, class_id]
|
||||||
|
keep.into_iter()
|
||||||
|
.map(|k_idx| {
|
||||||
|
let (orig_idx, score, cls_id) = candidates[k_idx];
|
||||||
|
let b = boxes.row(orig_idx);
|
||||||
|
vec![b[0], b[1], b[2], b[3], score, cls_id as f32]
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
/// 6. get_bbox (完全解耦 OpenCV)
|
||||||
|
pub fn get_bbox(&self, image_bytes: &[u8]) -> Result<Vec<Vec<i32>>> {
|
||||||
|
// 使用 utils crate 解码
|
||||||
|
let dynamic_img = image::load_from_memory(image_bytes).context("Failed to decode utils")?;
|
||||||
|
let (orig_w, orig_h) = dynamic_img.dimensions();
|
||||||
|
|
||||||
|
let (input_tensor, ratio) = self.preproc(&dynamic_img, (416, 416))?;
|
||||||
|
|
||||||
|
// tract 推理
|
||||||
|
let outputs = self.session.run(tvec!(input_tensor.into()))?;
|
||||||
|
let output_array = outputs[0]
|
||||||
|
.to_array_view::<f32>()?
|
||||||
|
.to_owned()
|
||||||
|
.into_dimensionality::<Ix3>()?;
|
||||||
|
|
||||||
|
let predictions = self.demo_postprocess(output_array, (416, 416));
|
||||||
|
let pred = predictions.slice(s![0, .., ..]);
|
||||||
|
|
||||||
|
let boxes = pred.slice(s![.., 0..4]);
|
||||||
|
let scores = &pred.slice(s![.., 4..5]) * &pred.slice(s![.., 5..]);
|
||||||
|
|
||||||
|
let mut boxes_xyxy = Array2::<f32>::zeros(boxes.raw_dim());
|
||||||
|
for i in 0..boxes.nrows() {
|
||||||
|
boxes_xyxy[[i, 0]] = (boxes[[i, 0]] - boxes[[i, 2]] / 2.0) / ratio;
|
||||||
|
boxes_xyxy[[i, 1]] = (boxes[[i, 1]] - boxes[[i, 3]] / 2.0) / ratio;
|
||||||
|
boxes_xyxy[[i, 2]] = (boxes[[i, 0]] + boxes[[i, 2]] / 2.0) / ratio;
|
||||||
|
boxes_xyxy[[i, 3]] = (boxes[[i, 1]] + boxes[[i, 3]] / 2.0) / ratio;
|
||||||
|
}
|
||||||
|
|
||||||
|
let detections = self.multiclass_nms(&boxes_xyxy, &scores, 0.45, 0.1);
|
||||||
|
|
||||||
|
Ok(detections
|
||||||
|
.into_iter()
|
||||||
|
.map(|d| {
|
||||||
|
vec![
|
||||||
|
(d[0] as i32).max(0).min(orig_w as i32),
|
||||||
|
(d[1] as i32).max(0).min(orig_h as i32),
|
||||||
|
(d[2] as i32).max(0).min(orig_w as i32),
|
||||||
|
(d[3] as i32).max(0).min(orig_h as i32),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
40
src/models/loader.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use anyhow::Context;
|
||||||
|
use image::DynamicImage;
|
||||||
|
use tract_onnx::onnx;
|
||||||
|
use tract_onnx::prelude::*;
|
||||||
|
// 关键点:直接使用 tract 重导出的 ndarray
|
||||||
|
use crate::utils::image_io::png_rgba_white_preprocess;
|
||||||
|
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tract_onnx::prelude::tract_ndarray::s;
|
||||||
|
|
||||||
|
/// OCR 模型:包含路径和字符集
|
||||||
|
|
||||||
|
pub enum ModelType {
|
||||||
|
Ocr,
|
||||||
|
Det,
|
||||||
|
Custom,
|
||||||
|
}
|
||||||
|
// 定义统一的 trait
|
||||||
|
pub trait ModelSession {
|
||||||
|
fn get_model_type(&self) -> ModelType;
|
||||||
|
fn desc(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModelLoader {
|
||||||
|
pub session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelLoader {
|
||||||
|
pub fn load_model<P>(model_path: P) -> anyhow::Result<Self>
|
||||||
|
where
|
||||||
|
P: AsRef<std::path::Path>,
|
||||||
|
{
|
||||||
|
let session = onnx()
|
||||||
|
.model_for_path(model_path)
|
||||||
|
.with_context(|| "加载 ONNX 模型失败,请检查路径是否正确")?
|
||||||
|
.into_optimized()?
|
||||||
|
.into_runnable()?;
|
||||||
|
Ok(Self { session })
|
||||||
|
}
|
||||||
|
}
|
||||||
5
src/models/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod base;
|
||||||
|
pub mod loader;
|
||||||
|
pub mod ocr;
|
||||||
|
pub mod det;
|
||||||
|
pub mod slide;
|
||||||
251
src/models/ocr.rs
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
use crate::models::base::ModelArgs;
|
||||||
|
use crate::utils::image_io::png_rgba_white_preprocess;
|
||||||
|
use crate::utils::image_processor::{convert_to_grayscale, resize_image};
|
||||||
|
use crate::models::loader::{ModelLoader, ModelSession, ModelType};
|
||||||
|
use anyhow::Context;
|
||||||
|
use image::DynamicImage;
|
||||||
|
use tract_onnx::prelude::tract_ndarray::s;
|
||||||
|
use tract_onnx::prelude::{
|
||||||
|
DatumType, Graph, IntoTensor, RunnableModel, Tensor, TypedFact, TypedOp, tract_ndarray, tvec,
|
||||||
|
};
|
||||||
|
|
||||||
|
// 颜色过滤的自定义范围:(低值RGB, 高值RGB)
|
||||||
|
pub type ColorRange = ((u8, u8, u8), (u8, u8, u8));
|
||||||
|
|
||||||
|
// 字符集范围类型
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CharsetRange {
|
||||||
|
All, // 所有字符
|
||||||
|
Digit, // 数字
|
||||||
|
Letter, // 字母
|
||||||
|
Alphanumeric, // 字母数字
|
||||||
|
Single(String), // 单字符串
|
||||||
|
Multiple(Vec<String>), // 多个字符串
|
||||||
|
Range(char, char), // 字符范围
|
||||||
|
Custom(Vec<char>), // 自定义字符列表
|
||||||
|
}
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PredictArgs {
|
||||||
|
/// 是否修复PNG格式问题
|
||||||
|
pub png_fix: bool,
|
||||||
|
/// 是否返回概率信息
|
||||||
|
pub probability: bool,
|
||||||
|
/// 颜色过滤:保留的颜色列表
|
||||||
|
pub color_filter_colors: Option<Vec<String>>,
|
||||||
|
/// 颜色过滤:自定义RGB范围
|
||||||
|
pub color_filter_custom_ranges: Option<Vec<ColorRange>>,
|
||||||
|
/// 字符集范围
|
||||||
|
pub charset_range: Option<CharsetRange>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PredictArgs {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
png_fix: false,
|
||||||
|
probability: false,
|
||||||
|
color_filter_colors: None,
|
||||||
|
color_filter_custom_ranges: None,
|
||||||
|
charset_range: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PredictArgs {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builder 模式方法
|
||||||
|
pub fn png_fix(mut self, enabled: bool) -> Self {
|
||||||
|
self.png_fix = enabled;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn probability(mut self, enabled: bool) -> Self {
|
||||||
|
self.probability = enabled;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn color_filter_colors(mut self, colors: Vec<String>) -> Self {
|
||||||
|
self.color_filter_colors = Some(colors);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn color_filter_custom_ranges(mut self, ranges: Vec<ColorRange>) -> Self {
|
||||||
|
self.color_filter_custom_ranges = Some(ranges);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn charset_range(mut self, range: CharsetRange) -> Self {
|
||||||
|
self.charset_range = Some(range);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// 便捷构造方法
|
||||||
|
pub fn quick() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_probability() -> Self {
|
||||||
|
Self::default().probability(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_png_fix() -> Self {
|
||||||
|
Self::default().png_fix(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub struct Ocr {
|
||||||
|
session: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
|
||||||
|
charset: Vec<String>,
|
||||||
|
}
|
||||||
|
impl ModelSession for Ocr {
    /// NOTE(review): not implemented yet — calling this panics via `todo!()`.
    /// Should return the model-type tag for this OCR session.
    fn get_model_type(&self) -> ModelType {
        todo!()
    }

    /// Human-readable load-success description.
    fn desc(&self) -> String {
        "Ocr Model 加载成功".to_string()
    }
}
|
||||||
|
impl Ocr {
|
||||||
|
pub fn new(model_path: String, charset: Vec<String>) -> Result<Self, anyhow::Error> {
|
||||||
|
let session = ModelLoader::load_model(&model_path)?.session;
|
||||||
|
Ok(Self { session, charset })
|
||||||
|
}
|
||||||
|
pub fn predict(&self, image: &DynamicImage, png_fix: bool) -> Result<String, anyhow::Error> {
|
||||||
|
let tensor = self.preprocess_image(image, png_fix)?;
|
||||||
|
//
|
||||||
|
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||||
|
// // 3. 解析结果
|
||||||
|
// // let output = result[0].to_array_view::<i64>()?;
|
||||||
|
let output = self.inference(tensor)?;
|
||||||
|
let output2 = self.process_text_output(&output)?;
|
||||||
|
Ok(self.ctc_decode_indices(&output2))
|
||||||
|
// Ok("ocr result".to_string())
|
||||||
|
}
|
||||||
|
/// 对应 Python 的 _preprocess_image
|
||||||
|
/// 负责:透明背景修复 -> 灰度化 -> 按比例 Resize -> 归一化 -> 4维张量转换
|
||||||
|
fn preprocess_image(&self, img: &DynamicImage, png_fix: bool) -> anyhow::Result<Tensor> {
|
||||||
|
// A. 修复 PNG 透明背景 (内部逻辑你之前已实现)
|
||||||
|
let _ = if png_fix && img.color().has_alpha() {
|
||||||
|
png_rgba_white_preprocess(img)
|
||||||
|
} else {
|
||||||
|
img.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let h = 64u32;
|
||||||
|
let w = (img.width() as f32 * (h as f32 / img.height() as f32)) as u32;
|
||||||
|
let gray_img = convert_to_grayscale(img);
|
||||||
|
let resized = resize_image(&gray_img, w, h);
|
||||||
|
// resized.save("debug_preprocessed.png").unwrap();
|
||||||
|
// 1. 预处理:转灰度 -> Resize -> 归一化
|
||||||
|
// let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_luma8();
|
||||||
|
|
||||||
|
// 使用 tract_ndarray 构造,避免版本冲突
|
||||||
|
let array =
|
||||||
|
tract_ndarray::Array4::from_shape_fn((1, 1, h as usize, w as usize), |(_, _, y, x)| {
|
||||||
|
let pixel = resized.get_pixel(x as u32, y as u32)[0] as f32;
|
||||||
|
(pixel / 255.0 - 0.5) / 0.5
|
||||||
|
});
|
||||||
|
|
||||||
|
let tensor = Tensor::from(array);
|
||||||
|
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
/// 对应 Python 的 _inference
|
||||||
|
fn inference(&self, tensor: Tensor) -> anyhow::Result<Tensor> {
|
||||||
|
// tract 的 run 会返回一个 Vec<TValue>,我们通常只需要第一个输出
|
||||||
|
// let result = self.session.run(tvec!(tensor.into()))?;
|
||||||
|
let mut result = self
|
||||||
|
.session
|
||||||
|
.run(tvec!(tensor.into()))
|
||||||
|
.context("执行模型推理失败")?;
|
||||||
|
println!("模型输出原始数据: {:?}", result);
|
||||||
|
Ok(result.remove(0).into_tensor())
|
||||||
|
}
|
||||||
|
/// 核心解析逻辑:将模型输出的各种维度/类型的 Tensor 转为字符索引序列
|
||||||
|
fn process_text_output(&self, raw_tensor: &Tensor) -> anyhow::Result<Vec<i64>> {
|
||||||
|
let shape = raw_tensor.shape();
|
||||||
|
println!("模型输出shape数据: {:?}", shape);
|
||||||
|
let datum_type = raw_tensor.datum_type();
|
||||||
|
println!("模型输出datum_type数据: {:?}", datum_type);
|
||||||
|
|
||||||
|
match raw_tensor.datum_type() {
|
||||||
|
// 情况 1: huashi666 式模型,直接输出 i64 索引 (通常是模型内部做好了 Argmax)
|
||||||
|
DatumType::I64 => {
|
||||||
|
let view = raw_tensor.to_array_view::<i64>()?;
|
||||||
|
Ok(view.iter().cloned().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 情况 2: sml2h3 原版模型,输出 F32 概率矩阵
|
||||||
|
DatumType::F32 => {
|
||||||
|
let view = raw_tensor.to_array_view::<f32>()?;
|
||||||
|
let (steps, classes, data_view) = match shape.len() {
|
||||||
|
3 => {
|
||||||
|
if shape[1] == 1 {
|
||||||
|
// 形状: [Steps, 1, Classes] -> 你的原有逻辑
|
||||||
|
(shape[0], shape[2], view.into_dyn())
|
||||||
|
} else if shape[0] == 1 {
|
||||||
|
// 形状: [1, Steps, Classes] -> 另一种常见导出格式
|
||||||
|
(shape[1], shape[2], view.into_dyn())
|
||||||
|
} else {
|
||||||
|
// 默认取第一个 batch: [Batch, Steps, Classes]
|
||||||
|
// 使用 slice 对应 Python 的 output[0, :, :]
|
||||||
|
let sliced = view.slice(s![0, .., ..]);
|
||||||
|
(shape[1], shape[2], sliced.into_dyn())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
// 形状: [Steps, Classes] -> 已经剥离了 Batch 维度
|
||||||
|
(shape[0], shape[1], view.into_dyn())
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow::anyhow!("不支持的输出维度: {:?}", shape)),
|
||||||
|
};
|
||||||
|
let array_2d = data_view.to_shape((steps, classes))?;
|
||||||
|
//
|
||||||
|
// 对每一行执行 Argmax (寻找概率最大的字符索引)
|
||||||
|
let indices = array_2d
|
||||||
|
.outer_iter()
|
||||||
|
.map(|row| {
|
||||||
|
row.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| {
|
||||||
|
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.map(|(idx, _)| idx as i64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(indices)
|
||||||
|
}
|
||||||
|
_ => Err(anyhow::anyhow!(
|
||||||
|
"不支持的模型输出数据类型: {:?}",
|
||||||
|
raw_tensor.datum_type()
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn ctc_decode_indices(&self, predicted_indices: &[i64]) -> String {
|
||||||
|
println!("indices模型输出原始数据: {:?}", predicted_indices);
|
||||||
|
|
||||||
|
// 对应 _ctc_decode_indices 的逻辑:去重、去 blank (0)
|
||||||
|
let mut res = String::new();
|
||||||
|
let mut prev_idx: i64 = -1;
|
||||||
|
|
||||||
|
for &idx in predicted_indices {
|
||||||
|
// 1. 跳过连续重复的索引
|
||||||
|
// 2. 跳过 blank 字符 (假设索引 0 是 blank)
|
||||||
|
if idx != prev_idx && idx != 0 {
|
||||||
|
if let Ok(u_idx) = usize::try_from(idx) {
|
||||||
|
if let Some(char_str) = self.charset.get(u_idx) {
|
||||||
|
res.push_str(char_str);
|
||||||
|
} else {
|
||||||
|
// 保护逻辑:如果模型预测的索引超出了字符集范围
|
||||||
|
eprintln!("警告: 预测索引 {} 超出字符集范围", u_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prev_idx = idx;
|
||||||
|
}
|
||||||
|
println!("最终识别出的验证码是: {}", res);
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
236
src/models/slide.rs
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
use crate::utils::cv_ops;
|
||||||
|
use crate::utils::cv_ops::{abs_diff, min_max_loc, ndarray_to_luma8, rgb_to_gray};
|
||||||
|
use crate::utils::image_io::image_to_ndarray;
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use image::{DynamicImage, GenericImageView};
|
||||||
|
use image::{ImageBuffer, Luma};
|
||||||
|
use imageproc::contrast::{ThresholdType, threshold};
|
||||||
|
use imageproc::distance_transform::Norm;
|
||||||
|
use imageproc::edges::canny;
|
||||||
|
use imageproc::morphology::{close, open};
|
||||||
|
use imageproc::region_labelling::{Connectivity, connected_components};
|
||||||
|
use imageproc::template_matching::{MatchTemplateMethod, match_template};
|
||||||
|
use std::cmp::{max, min};
|
||||||
|
use tract_onnx::prelude::tract_ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis, s};
|
||||||
|
|
||||||
|
/// Result of a slide-captcha match. `target` duplicates
/// (`target_x`, `target_y`) as an `[x, y]` pair for convenience.
pub struct SlideResult {
    // Center of the matched region as [x, y] (pixels).
    pub target: [i32; 2],
    // Center x of the matched region.
    pub target_x: i32,
    // Center y of the matched region.
    pub target_y: i32,
    // Template-match peak score; fixed 1.0 / 0.0 in comparison mode.
    pub confidence: f64,
}
|
||||||
|
|
||||||
|
/// Stateless slide-captcha matching engine.
pub struct Slide;
|
||||||
|
|
||||||
|
impl Slide {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python: slide_match
|
||||||
|
pub fn slide_match(
|
||||||
|
&self,
|
||||||
|
target_image: &DynamicImage,
|
||||||
|
background_image: &DynamicImage,
|
||||||
|
simple_target: bool,
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
let target_array = image_to_ndarray(target_image);
|
||||||
|
let background_array = image_to_ndarray(background_image);
|
||||||
|
|
||||||
|
self.perform_slide_match(target_array.view(), background_array.view(), simple_target)
|
||||||
|
.map_err(|e| anyhow!("滑块匹配失败: {}", e))
|
||||||
|
}
|
||||||
|
/// 对应 Python: slide_comparison
|
||||||
|
/// 用于比较带坑位的图片与原始背景图,定位差异点
|
||||||
|
pub fn slide_comparison(
|
||||||
|
&self,
|
||||||
|
target_image: &DynamicImage,
|
||||||
|
background_image: &DynamicImage,
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
// 1. 转换为 ndarray (HWC RGB)
|
||||||
|
let target_array = image_to_ndarray(target_image);
|
||||||
|
let background_array = image_to_ndarray(background_image);
|
||||||
|
|
||||||
|
// 2. 执行比较逻辑 (对应 _perform_slide_comparison)
|
||||||
|
self.perform_slide_comparison(target_array.view(), background_array.view())
|
||||||
|
.map_err(|e| anyhow!("滑块比较执行失败: {}", e))
|
||||||
|
}
|
||||||
|
/// 对应 Python: _perform_slide_comparison
|
||||||
|
pub fn perform_slide_comparison(
|
||||||
|
&self,
|
||||||
|
target: ArrayView3<u8>,
|
||||||
|
background: ArrayView3<u8>,
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
let (h, w, _) = target.dim();
|
||||||
|
|
||||||
|
// 1. 计算图像差异并灰度化 (对应 cv2.absdiff + cv2.cvtColor)
|
||||||
|
// 使用 OpenCV 标准权重公式:0.299R + 0.587G + 0.114B
|
||||||
|
// let mut diff_buffer = ImageBuffer::new(w as u32, h as u32);
|
||||||
|
// for y in 0..h {
|
||||||
|
// for x in 0..w {
|
||||||
|
// let r_diff = (target[[y, x, 0]] as i16 - background[[y, x, 0]] as i16).abs() as f32;
|
||||||
|
// let g_diff = (target[[y, x, 1]] as i16 - background[[y, x, 1]] as i16).abs() as f32;
|
||||||
|
// let b_diff = (target[[y, x, 2]] as i16 - background[[y, x, 2]] as i16).abs() as f32;
|
||||||
|
//
|
||||||
|
// let gray_diff = (0.299 * r_diff + 0.587 * g_diff + 0.114 * b_diff) as u8;
|
||||||
|
// diff_buffer.put_pixel(x as u32, y as u32, Luma([gray_diff]));
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// 1. 计算差异数组 (复用 cv2::absdiff)
|
||||||
|
let diff_array = abs_diff(&target, &background);
|
||||||
|
|
||||||
|
// 2. 转换为灰度数组 (复用你的 cv2.cvtColor)
|
||||||
|
let gray_array = rgb_to_gray(diff_array.view());
|
||||||
|
// 3. 转为 ImageBuffer 以使用 imageproc 的高级功能
|
||||||
|
let gray_buffer = ndarray_to_luma8(gray_array.view());
|
||||||
|
|
||||||
|
// 2. 二值化 (对应 cv2.threshold(..., 30, 255, cv2.THRESH_BINARY))
|
||||||
|
let binary = threshold(&gray_buffer, 30, ThresholdType::Binary);
|
||||||
|
// 3. 形态学操作去噪 (对应 cv2.morphologyEx)
|
||||||
|
// 闭运算 (Close): 先膨胀后腐蚀,用于填补缺口内的细小黑色空洞
|
||||||
|
// 开运算 (Open): 先腐蚀后膨胀,用于消除背景中的白色噪点点
|
||||||
|
let norm = Norm::LInf; // 对应 3x3 的矩形内核
|
||||||
|
let radius = 1u8; // 1 表示 3x3 的范围,2 表示 5x5 的范围
|
||||||
|
let closed = close(&binary, norm, radius);
|
||||||
|
let cleaned = open(&closed, norm, radius);
|
||||||
|
|
||||||
|
// connected_components 会给每个独立的白色区域打上不同的标签 (ID)
|
||||||
|
let background_label = Luma([0u8]);
|
||||||
|
let labelled = connected_components(&cleaned, Connectivity::Eight, background_label);
|
||||||
|
|
||||||
|
// // 统计每个标签出现的频率(即面积)
|
||||||
|
// 4. 寻找最大连通区域 (对应 findContours + max area)
|
||||||
|
if let Some(max_label) = cv_ops::find_contours_and_max(&labelled) {
|
||||||
|
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
|
||||||
|
let (x, y, w, h) = cv_ops::bounding_rect(&labelled, max_label);
|
||||||
|
// 6. 计算中心点 (调用之前封装的 calculate_center)
|
||||||
|
let (center_x, center_y) = cv_ops::calculate_center((x, y), w as usize, h as usize);
|
||||||
|
|
||||||
|
Ok(SlideResult {
|
||||||
|
target: [center_x, center_y],
|
||||||
|
target_x: center_x,
|
||||||
|
target_y: center_y,
|
||||||
|
confidence: 1.0, // Comparison 模式下通常认为找到即为 1.0
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(SlideResult {
|
||||||
|
target: [0, 0],
|
||||||
|
target_x: 0,
|
||||||
|
target_y: 0,
|
||||||
|
confidence: 0.0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python: _perform_slide_match
|
||||||
|
// 在 SlideEngine 中修改此入口进行测试
|
||||||
|
fn perform_slide_match(
|
||||||
|
&self,
|
||||||
|
target: ArrayView3<u8>,
|
||||||
|
background: ArrayView3<u8>,
|
||||||
|
simple_target: bool, // 增加这个参数
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
// 1. 统一灰度化
|
||||||
|
let target_gray = rgb_to_gray(target);
|
||||||
|
let background_gray = rgb_to_gray(background);
|
||||||
|
|
||||||
|
if simple_target {
|
||||||
|
// 2a. 简单模式:直接在灰度图上匹配
|
||||||
|
self.simple_template_match(target_gray.view(), background_gray.view())
|
||||||
|
} else {
|
||||||
|
// 2b. 复杂模式:先提取边缘,再匹配
|
||||||
|
|
||||||
|
self.edge_based_match(target_gray.view(), background_gray.view())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// 对应 Python: _simple_template_match
|
||||||
|
/// 使用 SAD (Sum of Absolute Differences) 算法
|
||||||
|
/// 核心模板匹配:SAD + 有效像素过滤
|
||||||
|
fn simple_template_match(
|
||||||
|
&self,
|
||||||
|
target: ArrayView2<u8>,
|
||||||
|
background: ArrayView2<u8>,
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
// 1. 将 ndarray 转换为 imageproc 需要的 ImageBuffer (无拷贝或轻量转换)
|
||||||
|
|
||||||
|
// let (bh, bw) = background.dim();
|
||||||
|
|
||||||
|
// 转换逻辑 (假设你已经有方法转回 ImageBuffer)
|
||||||
|
let t_buf = ndarray_to_luma8(target);
|
||||||
|
let b_buf = ndarray_to_luma8(background);
|
||||||
|
// t_buf.save("debug_rust_target.png").unwrap();
|
||||||
|
|
||||||
|
// 2. 调用 imageproc 的 NCC 算法 (等价于 cv2.TM_CCOEFF_NORMED)
|
||||||
|
// 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
|
||||||
|
let result = match_template(
|
||||||
|
&b_buf,
|
||||||
|
&t_buf,
|
||||||
|
MatchTemplateMethod::CrossCorrelationNormalized,
|
||||||
|
);
|
||||||
|
// save_rust_result(&result, "debug_rust_target2.png");
|
||||||
|
// 3. 寻找最大值 (等价于 cv2.minMaxLoc)
|
||||||
|
let (max_val, max_loc) = min_max_loc(&result);
|
||||||
|
|
||||||
|
// 4. 计算中心点 (与 Python 逻辑完全一致)
|
||||||
|
let (th, tw) = target.dim();
|
||||||
|
|
||||||
|
let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize);
|
||||||
|
// println!("Rust Target Width (tw): {}", tw);
|
||||||
|
// println!("Rust Best Max Loc X: {}", max_loc.0);
|
||||||
|
// println!("Rust Final Center X: {}", center_x);
|
||||||
|
Ok(SlideResult {
|
||||||
|
target: [center_x, center_y],
|
||||||
|
target_x: center_x,
|
||||||
|
target_y: center_y,
|
||||||
|
confidence: max_val as f64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python: _edge_based_match
|
||||||
|
/// 基于边缘检测的滑块匹配 (对齐 Python _edge_based_match)
|
||||||
|
pub fn edge_based_match(
|
||||||
|
&self,
|
||||||
|
target: ArrayView2<u8>,
|
||||||
|
background: ArrayView2<u8>,
|
||||||
|
) -> Result<SlideResult> {
|
||||||
|
// 1. 将 ndarray 转换为 ImageBuffer
|
||||||
|
// 注意:Canny 和 match_template 需要 ImageBuffer 格式
|
||||||
|
let t_buf = ndarray_to_luma8(target);
|
||||||
|
let b_buf = ndarray_to_luma8(background);
|
||||||
|
|
||||||
|
// 2. 边缘检测 (完全对齐 cv2.Canny(50, 150))
|
||||||
|
// 这步会生成黑底白线的二值化边缘图
|
||||||
|
let target_edges = canny(&t_buf, 50.0, 150.0);
|
||||||
|
let background_edges = canny(&b_buf, 50.0, 150.0);
|
||||||
|
|
||||||
|
// target_edges.save("debug_target_edges.png").ok();
|
||||||
|
// background_edges.save("debug_bg_edges.png").ok();
|
||||||
|
|
||||||
|
// 3. 模板匹配 (完全对齐 cv2.matchTemplate(..., cv2.TM_CCOEFF_NORMED))
|
||||||
|
// 在边缘图上计算归一化互相关系数
|
||||||
|
let result = match_template(
|
||||||
|
&background_edges,
|
||||||
|
&target_edges,
|
||||||
|
MatchTemplateMethod::CrossCorrelationNormalized,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc)
|
||||||
|
let (max_val, max_loc) = min_max_loc(&result);
|
||||||
|
// 5. 计算中心位置 (对齐 Python 逻辑)
|
||||||
|
// target_w, target_h 来自输入数组的维度
|
||||||
|
let (th, tw) = target.dim();
|
||||||
|
let (center_x, center_y) = cv_ops::calculate_center(max_loc, tw as usize, th as usize);
|
||||||
|
|
||||||
|
// 打印调试信息,方便与 Python 对比
|
||||||
|
// println!("Edge Match: max_val: {}, max_loc: {:?}", max_val, max_loc);
|
||||||
|
println!("-Rust Target Width (tw): {}", tw);
|
||||||
|
println!("-Rust Best Max Loc X: {}", max_loc.0);
|
||||||
|
println!("-Rust Final Center X: {}", center_x);
|
||||||
|
Ok(SlideResult {
|
||||||
|
target: [center_x, center_y],
|
||||||
|
target_x: center_x,
|
||||||
|
target_y: center_y,
|
||||||
|
confidence: max_val as f64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
107
src/utils/cv_ops.rs
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
use std::cmp::{max, min};
|
||||||
|
use image::{ImageBuffer, Luma};
|
||||||
|
use tract_onnx::prelude::tract_ndarray::{azip, Array2, Array3, ArrayView2, ArrayView3};
|
||||||
|
|
||||||
|
/// 1. 计算两个数组的绝对差值 (对应 cv2.absdiff)
|
||||||
|
pub fn abs_diff(a: &ArrayView3<u8>, b: &ArrayView3<u8>) -> Array3<u8> {
|
||||||
|
// 利用 ndarray 的 map_collect,生成差值的绝对值数组
|
||||||
|
// 或者直接使用 zip_mut_with 处理以减少内存分配
|
||||||
|
let mut diff = Array3::zeros(a.dim());
|
||||||
|
azip!((res in &mut diff, &va in a, &vb in b) {
|
||||||
|
*res = (va as i16 - vb as i16).abs() as u8;
|
||||||
|
});
|
||||||
|
diff
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// RGB 到灰度转换
|
||||||
|
pub fn rgb_to_gray(rgb: ArrayView3<u8>) -> Array2<u8> {
|
||||||
|
let (h, w, _) = rgb.dim();
|
||||||
|
Array2::from_shape_fn((h, w), |(y, x)| {
|
||||||
|
let r = rgb[[y, x, 0]] as f32;
|
||||||
|
let g = rgb[[y, x, 1]] as f32;
|
||||||
|
let b = rgb[[y, x, 2]] as f32;
|
||||||
|
// 完全忽略 a,只按权重计算
|
||||||
|
(0.299 * r + 0.587 * g + 0.114 * b) as u8
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 寻找匹配结果图中的最大值及其坐标 (模拟 cv2.minMaxLoc 的一部分)
|
||||||
|
pub fn min_max_loc(result_map: &ImageBuffer<Luma<f32>, Vec<f32>>) -> (f32, (u32, u32)) {
|
||||||
|
// 4. 找到最佳匹配位置 (对齐 cv2.minMaxLoc)
|
||||||
|
let mut max_val: f32 = -1.0;
|
||||||
|
let mut max_loc = (0, 0);
|
||||||
|
|
||||||
|
// 遍历匹配得分图
|
||||||
|
for (x, y, score) in result_map.enumerate_pixels() {
|
||||||
|
let s = score.0[0];
|
||||||
|
|
||||||
|
// 可以在此处加入你之前验证过的起始位过滤
|
||||||
|
// if x < 15 { continue; }
|
||||||
|
|
||||||
|
if s > max_val {
|
||||||
|
max_val = s;
|
||||||
|
max_loc = (x, y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(max_val, max_loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 1. 模拟 findContours 并获取最大面积区域的 Label
|
||||||
|
/// 返回 Option<u32>,如果找不到任何区域则返回 None
|
||||||
|
pub fn find_contours_and_max(labelled: &ImageBuffer<Luma<u32>, Vec<u32>>) -> Option<u32> {
|
||||||
|
// 统计每个标签出现的频率(即面积)
|
||||||
|
let mut max_label = 0;
|
||||||
|
let mut max_area = 0;
|
||||||
|
let mut areas = std::collections::HashMap::new();
|
||||||
|
|
||||||
|
for pixel in labelled.pixels() {
|
||||||
|
let label = pixel.0[0];
|
||||||
|
if label == 0 {
|
||||||
|
continue;
|
||||||
|
} // 跳过背景
|
||||||
|
let count = areas.entry(label).or_insert(0);
|
||||||
|
*count += 1;
|
||||||
|
if *count > max_area {
|
||||||
|
max_area = *count;
|
||||||
|
max_label = label;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if max_label == 0 { None } else { Some(max_label) }
|
||||||
|
}
|
||||||
|
pub fn bounding_rect(labelled: &ImageBuffer<Luma<u32>, Vec<u32>>,max_label: u32) -> (u32, u32, u32, u32) {
|
||||||
|
// 5. 计算最大区域的边界框 (对应 cv2.boundingRect)
|
||||||
|
let mut min_x = labelled.width();
|
||||||
|
let mut max_x = 0;
|
||||||
|
let mut min_y = labelled.height();
|
||||||
|
let mut max_y = 0;
|
||||||
|
|
||||||
|
for (x, y, pixel) in labelled.enumerate_pixels() {
|
||||||
|
if pixel.0[0] == max_label {
|
||||||
|
min_x = min(min_x, x);
|
||||||
|
max_x = max(max_x, x);
|
||||||
|
min_y = min(min_y, y);
|
||||||
|
max_y = max(max_y, y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
let w = max_x - min_x;
|
||||||
|
let h = max_y - min_y;
|
||||||
|
(min_x, min_y, w, h)
|
||||||
|
}
|
||||||
|
/// Geometric center of a (tw x th) region whose top-left corner is
/// `max_loc`: (x + w/2, y + h/2) with integer (floor) division, matching
/// the Python port's coordinate convention.
pub fn calculate_center(max_loc: (u32, u32), tw: usize, th: usize) -> (i32, i32) {
    let (left, top) = (max_loc.0 as i32, max_loc.1 as i32);
    (left + tw as i32 / 2, top + th as i32 / 2)
}
|
||||||
|
pub fn ndarray_to_luma8(array: ArrayView2<u8>) -> ImageBuffer<Luma<u8>, Vec<u8>> {
|
||||||
|
let (height, width) = array.dim();
|
||||||
|
let mut buffer = ImageBuffer::new(width as u32, height as u32);
|
||||||
|
for y in 0..height {
|
||||||
|
for x in 0..width {
|
||||||
|
buffer.put_pixel(x as u32, y as u32, Luma([array[[y, x]]]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buffer
|
||||||
|
}
|
||||||
264
src/utils/image_io.rs
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
|
use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD};
|
||||||
|
/// Pixel layout of an image buffer, mirroring PIL mode strings.
#[derive(Debug)]
pub enum ColorMode {
    /// Three channels: red, green, blue.
    RGB,
    /// Four channels: red, green, blue, alpha.
    RGBA,
    /// One channel: 8-bit luminance ("L" in PIL).
    L,
}
|
||||||
|
/// 定义支持的输入类型枚举
|
||||||
|
pub enum ImageInput {
|
||||||
|
Bytes(Vec<u8>),
|
||||||
|
Array(ArrayD<u8>), // 对应 numpy 数组
|
||||||
|
Path(PathBuf),
|
||||||
|
Base64(String),
|
||||||
|
DynamicImage(DynamicImage),
|
||||||
|
}
|
||||||
|
/// 模拟 Python 的 load_image_from_input
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
|
||||||
|
match img_input {
|
||||||
|
// 2. 处理字节流 (Bytes)
|
||||||
|
ImageInput::Bytes(bytes) => {
|
||||||
|
image::load_from_memory(&bytes).context("Failed to load utils from bytes")
|
||||||
|
}
|
||||||
|
// 1. 已经是 DynamicImage
|
||||||
|
ImageInput::DynamicImage(i) => Ok(i),
|
||||||
|
// 5. 处理 ndarray (Numpy-like)
|
||||||
|
// 假设输入是 HWC 格式的 Array3<u8>
|
||||||
|
ImageInput::Array(a) => numpy_to_pil_image(a.view()),
|
||||||
|
// 4. 处理 Base64 字符串
|
||||||
|
ImageInput::Base64(b) => base64_to_image(&b),
|
||||||
|
// 3. 处理文件路径 (Path)
|
||||||
|
ImageInput::Path(p) => image::open(p).context("Failed to open utils from path"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn base64_to_image(b64_str: &str) -> Result<DynamicImage> {
|
||||||
|
// 过滤掉可能存在的 base64 前缀,例如 "data:utils/png;base64,"
|
||||||
|
let clean_b64 = if let Some(pos) = b64_str.find(",") {
|
||||||
|
&b64_str[pos + 1..]
|
||||||
|
} else {
|
||||||
|
&b64_str
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = general_purpose::STANDARD
|
||||||
|
.decode(clean_b64.trim())
|
||||||
|
.map_err(|e| anyhow!("Base64 decode error: {}", e))?;
|
||||||
|
|
||||||
|
image::load_from_memory(&bytes).context("Failed to load utils from decoded base64")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 读取图片文件并转换为 base64 编码字符串
|
||||||
|
/// 对应 Python 版 get_img_base64
|
||||||
|
pub fn get_img_base64<P: AsRef<Path>>(image_path: P) -> Result<String> {
|
||||||
|
// 1. 读取文件原始字节流
|
||||||
|
// 使用 AsRef<Path> 泛型可以让函数同时支持 String, &str, PathBuf 等类型
|
||||||
|
let image_data = fs::read(&image_path)
|
||||||
|
.with_context(|| format!("Failed to read utils file: {:?}", image_path.as_ref()))?;
|
||||||
|
|
||||||
|
// 2. 进行 Base64 编码
|
||||||
|
// 使用 STANDARD 引擎对齐 Python 的 base64.b64encode
|
||||||
|
let b64_string = general_purpose::STANDARD.encode(image_data);
|
||||||
|
|
||||||
|
Ok(b64_string)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image
|
||||||
|
fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
|
||||||
|
let shape = array.shape();
|
||||||
|
let dim = shape.len();
|
||||||
|
|
||||||
|
// 1. 确保数据在内存中是连续的 (C order / Standard Layout)
|
||||||
|
// 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝
|
||||||
|
let standard = array.as_standard_layout();
|
||||||
|
let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset();
|
||||||
|
|
||||||
|
match dim {
|
||||||
|
// 对应 Python: len(array.shape) == 2 (灰度图 H, W)
|
||||||
|
2 => {
|
||||||
|
let (h, w) = (shape[0], shape[1]);
|
||||||
|
ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageLuma8)
|
||||||
|
.ok_or_else(|| anyhow!("Failed to create Luma utils from 2D array"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对应 Python: len(array.shape) == 3 (H, W, C)
|
||||||
|
3 => {
|
||||||
|
let (h, w, c) = (shape[0], shape[1], shape[2]);
|
||||||
|
match c {
|
||||||
|
// 对应 Python: array.shape[2] == 1 (单通道 H, W, 1)
|
||||||
|
1 => ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageLuma8),
|
||||||
|
|
||||||
|
// 对应 Python: array.shape[2] == 3 (RGB H, W, 3)
|
||||||
|
3 => ImageBuffer::<Rgb<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageRgb8),
|
||||||
|
|
||||||
|
// 对应 Python: array.shape[2] == 4 (RGBA H, W, 4)
|
||||||
|
4 => ImageBuffer::<Rgba<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageRgba8),
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!("不支持的通道数: {}", c));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.ok_or_else(|| anyhow!("转换彩色图失败"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对应 Python 的 png_rgba_black_preprocess
|
||||||
|
/// 将带有透明通道的图片转换为白色背景的 RGB 图片
|
||||||
|
|
||||||
|
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
|
||||||
|
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回
|
||||||
|
if !img.color().has_alpha() {
|
||||||
|
return DynamicImage::ImageRgb8(img.to_rgb8());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (width, height) = img.dimensions();
|
||||||
|
|
||||||
|
// 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255)
|
||||||
|
let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8]));
|
||||||
|
|
||||||
|
// 3. 获取原图的 RGBA 视图
|
||||||
|
let rgba_img = img.to_rgba8();
|
||||||
|
|
||||||
|
// 4. 遍历像素并手动进行 Alpha 混合
|
||||||
|
// 对应 Python 的 utils.paste(img, ..., mask=img)
|
||||||
|
// 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销
|
||||||
|
for (x, y, bg_pixel) in background.enumerate_pixels_mut() {
|
||||||
|
// 安全性说明:x, y 源自 background 尺寸,与 rgba_img 一致,get_pixel 是安全的
|
||||||
|
let src_pixel = rgba_img.get_pixel(x, y);
|
||||||
|
let alpha_u8 = src_pixel[3];
|
||||||
|
|
||||||
|
match alpha_u8 {
|
||||||
|
// 情况 A:完全不透明,直接覆盖背景色
|
||||||
|
255 => {
|
||||||
|
bg_pixel.0 = [src_pixel[0], src_pixel[1], src_pixel[2]];
|
||||||
|
}
|
||||||
|
// 情况 B:完全透明,保持背景色(白色),无需操作
|
||||||
|
0 => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// 情况 C:半透明,进行 Alpha 混合计算
|
||||||
|
_ => {
|
||||||
|
let alpha = alpha_u8 as f32 / 255.0;
|
||||||
|
let inv_alpha = 1.0 - alpha;
|
||||||
|
|
||||||
|
bg_pixel[0] = (src_pixel[0] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||||||
|
bg_pixel[1] = (src_pixel[1] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||||||
|
bg_pixel[2] = (src_pixel[2] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DynamicImage::ImageRgb8(background)
|
||||||
|
}
|
||||||
|
/// Converts a `DynamicImage` into an (H, W, C) u8 array in the requested
/// mode — the counterpart of Python's `np.array(img.convert(mode))`.
/// Kept for now; may later be merged with `image_to_ndarray`.
pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result<Array3<u8>> {
    let (width, height) = image.dimensions();

    // Mode conversion via the image crate's to_rgb8 / to_luma8 / to_rgba8.
    let (channels, raw) = match mode {
        ColorMode::RGB => (3, image.to_rgb8().into_raw()),
        ColorMode::L => (1, image.to_luma8().into_raw()),
        ColorMode::RGBA => (4, image.to_rgba8().into_raw()),
    };

    // raw has exactly height * width * channels bytes by construction.
    Array3::from_shape_vec((height as usize, width as usize, channels), raw)
        .map_err(|e| anyhow!("Failed to build ndarray: {}", e))
}
|
||||||
|
|
||||||
|
/// Converts an (H, W, C) u8 array back into a `DynamicImage`, validating
/// that the channel count matches the requested `ColorMode`.
pub fn numpy_to_image(array: ArrayViewD<u8>, mode: ColorMode) -> Result<DynamicImage> {
    let shape = array.shape();
    // Only (H, W, C) arrays are accepted here; 2-D grayscale goes through
    // numpy_to_pil_image instead.
    if shape.len() != 3 {
        bail!("Expected a 3D array (H, W, C), but got {}D", shape.len());
    }

    let height = shape[0] as u32;
    let width = shape[1] as u32;
    let channels = shape[2];
    // The channel count must agree with the target mode.
    let expected_channels = match mode {
        ColorMode::L => 1,
        ColorMode::RGB => 3,
        ColorMode::RGBA => 4,
    };
    if channels != expected_channels {
        bail!(
            "Mode {:?} expects {} channels, but array has {}",
            mode,
            expected_channels,
            channels
        );
    }
    // Ensure C-order contiguous memory before handing the raw vec to image.
    let standard = array.as_standard_layout();
    let (raw_data, _) = standard.to_owned().into_raw_vec_and_offset();

    match mode {
        ColorMode::L => ImageBuffer::<Luma<u8>, _>::from_raw(width, height, raw_data)
            .map(DynamicImage::ImageLuma8),
        ColorMode::RGB => ImageBuffer::<Rgb<u8>, _>::from_raw(width, height, raw_data)
            .map(DynamicImage::ImageRgb8),
        ColorMode::RGBA => ImageBuffer::<Rgba<u8>, _>::from_raw(width, height, raw_data)
            .map(DynamicImage::ImageRgba8),
    }
    .ok_or_else(|| anyhow!("Failed to construct ImageBuffer. Buffer size might be incorrect."))
}
|
||||||
|
pub fn image_to_ndarray(img: &DynamicImage) -> Array3<u8> {
|
||||||
|
let (width, height) = img.dimensions();
|
||||||
|
|
||||||
|
// 1. 强制转为 RGB8 (丢弃 Alpha 通道,与 Python 的 target_mode='RGB' 对齐)
|
||||||
|
let rgb_img = img.to_rgb8();
|
||||||
|
|
||||||
|
// 2. 获取原始像素数据
|
||||||
|
let raw_data = rgb_img.into_raw();
|
||||||
|
|
||||||
|
// 3. 构造数组 (通道数改为 3)
|
||||||
|
Array3::from_shape_vec((height as usize, width as usize, 3), raw_data)
|
||||||
|
.expect("Failed to construct ndarray from utils") // 建议显式报错,而不是返回全黑图
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn save_rust_result(result: &ImageBuffer<Luma<f32>, Vec<f32>>, filename: &str) {
|
||||||
|
let (width, height) = result.dimensions();
|
||||||
|
|
||||||
|
// 1. 寻找最值进行归一化
|
||||||
|
let mut max_val = f32::MIN;
|
||||||
|
let mut min_val = f32::MAX;
|
||||||
|
for p in result.pixels() {
|
||||||
|
if p.0[0] > max_val {
|
||||||
|
max_val = p.0[0];
|
||||||
|
}
|
||||||
|
if p.0[0] < min_val {
|
||||||
|
min_val = p.0[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 创建 8 位灰度图
|
||||||
|
let mut out_buf = ImageBuffer::new(width, height);
|
||||||
|
for y in 0..height {
|
||||||
|
for x in 0..width {
|
||||||
|
let val = result.get_pixel(x, y).0[0];
|
||||||
|
let normalized = if max_val > min_val {
|
||||||
|
((val - min_val) / (max_val - min_val) * 255.0) as u8
|
||||||
|
} else {
|
||||||
|
0u8
|
||||||
|
};
|
||||||
|
out_buf.put_pixel(x, y, Luma([normalized]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 保存
|
||||||
|
DynamicImage::ImageLuma8(out_buf).save(filename).unwrap();
|
||||||
|
println!("Rust 结果热力图已保存至: {}", filename);
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ use anyhow::Result;
|
|||||||
/// 对应 Python 的 convert_to_grayscale
|
/// 对应 Python 的 convert_to_grayscale
|
||||||
/// 将图像转换为灰度图 (L模式)
|
/// 将图像转换为灰度图 (L模式)
|
||||||
pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage {
|
pub fn convert_to_grayscale(image: &DynamicImage) -> GrayImage {
|
||||||
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
|
// Rust image 库的 to_luma8 会根据标准的亮度公式进行转换
|
||||||
image.to_luma8()
|
image.to_luma8()
|
||||||
}
|
}
|
||||||
|
|
||||||
3
src/utils/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod image_io;
|
||||||
|
pub mod image_processor;
|
||||||
|
pub mod cv_ops;
|
||||||
@@ -1,12 +1,72 @@
|
|||||||
use ddddocr_rs::DdddOcr; // 假设你的包名是这个
|
use ddddocr_rs::models::slide::Slide;
|
||||||
|
use ddddocr_rs::{DdddOcr, DdddOcrBuilder}; // 假设你的包名是这个
|
||||||
|
use image::Rgb;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
|
/// Test helper: loads an image from disk, wrapping failures with the path.
fn load_image<P: AsRef<Path>>(path: P) -> anyhow::Result<image::DynamicImage> {
    // 1. Bind the generic to a concrete &Path reference.
    let path_ref = path.as_ref();

    // 2. image::open accepts AsRef<Path>; path_ref remains valid inside the
    //    error-mapping closure below.
    image::open(path_ref).map_err(|e| {
        // 3. Include the offending path in the error for easier debugging.
        anyhow::anyhow!("无法加载图片 {:?}: {}", path_ref, e)
    })
}
||||||
|
/// 将检测结果绘制在图像上并保存
|
||||||
|
/// Test helper: draws each bbox as a ~2px red rectangle on the image decoded
/// from `image_bytes`, then writes the annotated image to `output_path`.
/// NOTE(review): assumes each bbox is laid out [x1, y1, x2, y2] with
/// x1 <= x2 and y1 <= y2 — confirm against the detector's output format.
fn save_debug_image(
    image_bytes: &[u8],
    bboxes: &Vec<Vec<i32>>,
    output_path: &str,
) -> anyhow::Result<()> {
    let dynamic_img = image::load_from_memory(image_bytes)?;
    let mut img = dynamic_img.to_rgb8();
    let (width, height) = img.dimensions();
    let red = Rgb([255u8, 0, 0]);

    for bbox in bboxes {
        // Clamp all coordinates into the image so put_pixel cannot panic.
        let x1 = bbox[0].max(0).min(width as i32 - 1) as u32;
        let y1 = bbox[1].max(0).min(height as i32 - 1) as u32;
        let x2 = bbox[2].max(0).min(width as i32 - 1) as u32;
        let y2 = bbox[3].max(0).min(height as i32 - 1) as u32;

        // Top and bottom edges.
        for x in x1..=x2 {
            img.put_pixel(x, y1, red);
            img.put_pixel(x, y2, red);
            // Extra row inside the box to thicken the line.
            if y1 + 1 < height {
                img.put_pixel(x, y1 + 1, red);
            }
            if y2.saturating_sub(1) > 0 {
                img.put_pixel(x, y2 - 1, red);
            }
        }
        // Left and right edges.
        for y in y1..=y2 {
            img.put_pixel(x1, y, red);
            img.put_pixel(x2, y, red);
            // Extra column inside the box to thicken the line.
            if x1 + 1 < width {
                img.put_pixel(x1 + 1, y, red);
            }
            if x2.saturating_sub(1) > 0 {
                img.put_pixel(x2 - 1, y, red);
            }
        }
    }

    img.save(output_path)?;
    Ok(())
}
|
||||||
#[test]
|
#[test]
|
||||||
fn test_full_classification() {
|
fn test_full_classification() {
|
||||||
// 1. 初始化模型
|
// 1. 初始化模型
|
||||||
let ocr = DdddOcr::new("model/common.onnx").expect("模型加载失败");
|
let ocr = DdddOcrBuilder::new().build().expect("模型加载失败");
|
||||||
|
|
||||||
// 2. 加载测试图片
|
// 2. 加载测试图片
|
||||||
let img = image::open("samples/code3.png").expect("测试图片不存在");
|
let img = image::open("samples/code2.png").expect("测试图片不存在");
|
||||||
|
|
||||||
// 3. 执行识别
|
// 3. 执行识别
|
||||||
let result = ocr.classification(&img).expect("识别过程出错");
|
let result = ocr.classification(&img).expect("识别过程出错");
|
||||||
@@ -14,3 +74,89 @@ fn test_full_classification() {
|
|||||||
println!("识别结果: {}", result);
|
println!("识别结果: {}", result);
|
||||||
assert!(!result.is_empty());
|
assert!(!result.is_empty());
|
||||||
}
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_det_load() -> anyhow::Result<()> {
|
||||||
|
let det = DdddOcrBuilder::new().det().build()?;
|
||||||
|
let image_path = "samples/det1.png";
|
||||||
|
let image_bytes =
|
||||||
|
fs::read(image_path).map_err(|e| anyhow::anyhow!("无法读取图片 {}: {}", image_path, e))?;
|
||||||
|
|
||||||
|
println!("图片读取成功,字节大小: {}", image_bytes.len());
|
||||||
|
let bboxes = det.detection(&image_bytes)?;
|
||||||
|
println!(":?{}", det);
|
||||||
|
println!("检测到的目标数量: {}", bboxes.len());
|
||||||
|
if bboxes.is_empty() {
|
||||||
|
println!("未检测到任何目标。");
|
||||||
|
} else {
|
||||||
|
save_debug_image(&image_bytes, &bboxes, "samples/result.jpg")?;
|
||||||
|
for (i, bbox) in bboxes.iter().enumerate() {
|
||||||
|
println!(
|
||||||
|
"目标 [{}]: x1={}, y1={}, x2={}, y2={}",
|
||||||
|
i, bbox[0], bbox[1], bbox[2], bbox[3]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_real_slide_match() {
|
||||||
|
let engine = Slide::new();
|
||||||
|
|
||||||
|
// 1. 加载你准备好的测试图
|
||||||
|
// 假设图片放在项目根目录下的 assets 文件夹
|
||||||
|
let target_img = load_image("samples/hua.png").expect("请确保 samples/hua.png 存在");
|
||||||
|
let bg_img = load_image("samples/huatu.png").expect("请确保 samples/huatu.png 存在");
|
||||||
|
|
||||||
|
// 2. 执行匹配
|
||||||
|
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let result = engine
|
||||||
|
.slide_match(&target_img, &bg_img, false)
|
||||||
|
.expect("Slide match 执行失败");
|
||||||
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
// 3. 打印结果
|
||||||
|
println!("-------------------------------------------");
|
||||||
|
println!("滑块匹配测试结果:");
|
||||||
|
println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y);
|
||||||
|
println!("置信度: {:.4}", result.confidence);
|
||||||
|
println!("耗时: {:?}", duration);
|
||||||
|
println!("-------------------------------------------");
|
||||||
|
|
||||||
|
// 验证基本逻辑:坐标不应为 0 (除非匹配失败)
|
||||||
|
assert_eq!(result.target_x, 237);
|
||||||
|
assert_eq!(result.target_y, 77);
|
||||||
|
assert!(result.confidence > 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_real_slide_comparison() {
|
||||||
|
let engine = Slide::new();
|
||||||
|
|
||||||
|
// 1. 加载你准备好的测试图
|
||||||
|
// 假设图片放在项目根目录下的 assets 文件夹
|
||||||
|
let target_img = load_image("samples/ken.jpg").expect("请确保 samples/ken.jpg 存在");
|
||||||
|
let bg_img = load_image("samples/kenyuan.jpg").expect("请确保 samples/kenyuan.jpg 存在");
|
||||||
|
|
||||||
|
// 2. 执行匹配
|
||||||
|
// 如果是那种带有明显阴影边缘的复杂滑块,建议 simple_target 传 false
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let result = engine
|
||||||
|
.slide_comparison(&target_img, &bg_img)
|
||||||
|
.expect("Slide match 执行失败");
|
||||||
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
// 3. 打印结果
|
||||||
|
println!("-------------------------------------------");
|
||||||
|
println!("滑块匹配测试结果:");
|
||||||
|
println!("检测坐标: [x: {}, y: {}]", result.target_x, result.target_y);
|
||||||
|
println!("置信度: {:.4}", result.confidence);
|
||||||
|
println!("耗时: {:?}", duration);
|
||||||
|
println!("-------------------------------------------");
|
||||||
|
|
||||||
|
// 验证基本逻辑:坐标不应为 0 (除非匹配失败)
|
||||||
|
assert_eq!(result.target_x, 171);
|
||||||
|
assert_eq!(result.target_y, 90);
|
||||||
|
assert!(result.confidence > 0.0);
|
||||||
|
}
|
||||||
|
|||||||