refactor: 完成图像加载模块重构,对齐 ddddocr Python 原版 IO 逻辑
This commit is contained in:
@@ -1,24 +1,103 @@
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result, anyhow};
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
use image::{DynamicImage, GenericImageView, ImageBuffer, Luma, Rgb, RgbImage};
|
use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba};
|
||||||
|
use std::fs;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use tract_onnx::prelude::tract_ndarray::Array3;
|
use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD};
|
||||||
|
|
||||||
/// 定义支持的输入类型枚举
|
/// 定义支持的输入类型枚举
|
||||||
pub enum ImageInput {
|
pub enum ImageInput {
|
||||||
Bytes(Vec<u8>),
|
Bytes(Vec<u8>),
|
||||||
Array(Array3<u8>),
|
Array(ArrayD<u8>), // 对应 numpy 数组
|
||||||
Path(PathBuf),
|
Path(PathBuf),
|
||||||
Base64(String),
|
Base64(String),
|
||||||
DynamicImage(DynamicImage),
|
DynamicImage(DynamicImage),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ImageInput {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
if Path::new(s).exists() {
|
||||||
|
ImageInput::Path(s.into())
|
||||||
|
} else {
|
||||||
|
ImageInput::Base64(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/// 模拟 Python 的 load_image_from_input
|
/// 模拟 Python 的 load_image_from_input
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn load_image_from_input(input: ImageInput) -> Result<DynamicImage> {
|
pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
|
||||||
match input {
|
match img_input {
|
||||||
|
// 2. 处理字节流 (Bytes)
|
||||||
|
ImageInput::Bytes(bytes) => {
|
||||||
|
image::load_from_memory(&bytes).context("Failed to load image from bytes")
|
||||||
|
}
|
||||||
|
// 1. 已经是 DynamicImage
|
||||||
ImageInput::DynamicImage(img) => Ok(img),
|
ImageInput::DynamicImage(img) => Ok(img),
|
||||||
_ => todo!("后续补充"),
|
// 5. 处理 ndarray (Numpy-like)
|
||||||
|
// 假设输入是 HWC 格式的 Array3<u8>
|
||||||
|
ImageInput::Array(arr) => numpy_to_pil_image(arr.view()),
|
||||||
|
// 4. 处理 Base64 字符串
|
||||||
|
ImageInput::Base64(b64_str) => base64_to_image(&b64_str),
|
||||||
|
// 3. 处理文件路径 (Path)
|
||||||
|
ImageInput::Path(path) => image::open(path).context("Failed to open image from path"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn base64_to_image(img_base64: &str) -> Result<DynamicImage> {
|
||||||
|
// 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64,"
|
||||||
|
let clean_b64 = if let Some(pos) = img_base64.find(",") {
|
||||||
|
&img_base64[pos + 1..]
|
||||||
|
} else {
|
||||||
|
&img_base64
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = general_purpose::STANDARD
|
||||||
|
.decode(clean_b64.trim())
|
||||||
|
.map_err(|e| anyhow!("Base64 decode error: {}", e))?;
|
||||||
|
|
||||||
|
image::load_from_memory(&bytes).context("Failed to load image from decoded base64")
|
||||||
|
}
|
||||||
|
/// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image
|
||||||
|
fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
|
||||||
|
let shape = array.shape();
|
||||||
|
let dim = shape.len();
|
||||||
|
|
||||||
|
// 1. 确保数据在内存中是连续的 (C order / Standard Layout)
|
||||||
|
// 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝
|
||||||
|
let standard = array.as_standard_layout();
|
||||||
|
let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset();
|
||||||
|
|
||||||
|
match dim {
|
||||||
|
// 对应 Python: len(array.shape) == 2 (灰度图 H, W)
|
||||||
|
2 => {
|
||||||
|
let (h, w) = (shape[0], shape[1]);
|
||||||
|
ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageLuma8)
|
||||||
|
.ok_or_else(|| anyhow!("Failed to create Luma image from 2D array"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对应 Python: len(array.shape) == 3 (H, W, C)
|
||||||
|
3 => {
|
||||||
|
let (h, w, c) = (shape[0], shape[1], shape[2]);
|
||||||
|
match c {
|
||||||
|
// 对应 Python: array.shape[2] == 1 (单通道 H, W, 1)
|
||||||
|
1 => ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageLuma8),
|
||||||
|
|
||||||
|
// 对应 Python: array.shape[2] == 3 (RGB H, W, 3)
|
||||||
|
3 => ImageBuffer::<Rgb<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageRgb8),
|
||||||
|
|
||||||
|
// 对应 Python: array.shape[2] == 4 (RGBA H, W, 4)
|
||||||
|
4 => ImageBuffer::<Rgba<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||||||
|
.map(DynamicImage::ImageRgba8),
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!("不支持的通道数: {}", c));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.ok_or_else(|| anyhow!("转换彩色图失败"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user