refactor: 完成图像加载模块重构,对齐 ddddocr Python 原版 IO 逻辑

This commit is contained in:
2026-05-08 17:59:42 +08:00
parent 21bd1c93bf
commit f0db625bd1

View File

@@ -1,24 +1,103 @@
use anyhow::{Context, Result};
use anyhow::{Context, Result, anyhow};
use base64::{Engine as _, engine::general_purpose};
use image::{DynamicImage, GenericImageView, ImageBuffer, Luma, Rgb, RgbImage};
use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba};
use std::fs;
use std::path::{Path, PathBuf};
use tract_onnx::prelude::tract_ndarray::Array3;
use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD};
/// 定义支持的输入类型枚举
pub enum ImageInput {
Bytes(Vec<u8>),
Array(Array3<u8>),
Array(ArrayD<u8>), // 对应 numpy 数组
Path(PathBuf),
Base64(String),
DynamicImage(DynamicImage),
}
impl From<&str> for ImageInput {
fn from(s: &str) -> Self {
if Path::new(s).exists() {
ImageInput::Path(s.into())
} else {
ImageInput::Base64(s.to_string())
}
}
}
/// 模拟 Python 的 load_image_from_input
#[allow(dead_code)]
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> {
match input {
pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
match img_input {
// 2. 处理字节流 (Bytes)
ImageInput::Bytes(bytes) => {
image::load_from_memory(&bytes).context("Failed to load image from bytes")
}
// 1. 已经是 DynamicImage
ImageInput::DynamicImage(img) => Ok(img),
_ => todo!("后续补充"),
// 5. 处理 ndarray (Numpy-like)
// 假设输入是 HWC 格式的 Array3<u8>
ImageInput::Array(arr) => numpy_to_pil_image(arr.view()),
// 4. 处理 Base64 字符串
ImageInput::Base64(b64_str) => base64_to_image(&b64_str),
// 3. 处理文件路径 (Path)
ImageInput::Path(path) => image::open(path).context("Failed to open image from path"),
}
}
fn base64_to_image(img_base64: &str) -> Result<DynamicImage> {
// 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64,"
let clean_b64 = if let Some(pos) = img_base64.find(",") {
&img_base64[pos + 1..]
} else {
&img_base64
};
let bytes = general_purpose::STANDARD
.decode(clean_b64.trim())
.map_err(|e| anyhow!("Base64 decode error: {}", e))?;
image::load_from_memory(&bytes).context("Failed to load image from decoded base64")
}
/// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image
fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
let shape = array.shape();
let dim = shape.len();
// 1. 确保数据在内存中是连续的 (C order / Standard Layout)
// 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝
let standard = array.as_standard_layout();
let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset();
match dim {
// 对应 Python: len(array.shape) == 2 (灰度图 H, W)
2 => {
let (h, w) = (shape[0], shape[1]);
ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
.map(DynamicImage::ImageLuma8)
.ok_or_else(|| anyhow!("Failed to create Luma image from 2D array"))
}
// 对应 Python: len(array.shape) == 3 (H, W, C)
3 => {
let (h, w, c) = (shape[0], shape[1], shape[2]);
match c {
// 对应 Python: array.shape[2] == 1 (单通道 H, W, 1)
1 => ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
.map(DynamicImage::ImageLuma8),
// 对应 Python: array.shape[2] == 3 (RGB H, W, 3)
3 => ImageBuffer::<Rgb<u8>, _>::from_raw(w as u32, h as u32, raw_data)
.map(DynamicImage::ImageRgb8),
// 对应 Python: array.shape[2] == 4 (RGBA H, W, 4)
4 => ImageBuffer::<Rgba<u8>, _>::from_raw(w as u32, h as u32, raw_data)
.map(DynamicImage::ImageRgba8),
_ => {
return Err(anyhow!("不支持的通道数: {}", c));
}
}
.ok_or_else(|| anyhow!("转换彩色图失败"))
}
_ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)),
}
}