diff --git a/src/image_io.rs b/src/image_io.rs index e3e1fd8..1373286 100644 --- a/src/image_io.rs +++ b/src/image_io.rs @@ -1,24 +1,103 @@ -use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use base64::{Engine as _, engine::general_purpose}; -use image::{DynamicImage, GenericImageView, ImageBuffer, Luma, Rgb, RgbImage}; +use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba}; +use std::fs; use std::path::{Path, PathBuf}; -use tract_onnx::prelude::tract_ndarray::Array3; - +use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD}; /// 定义支持的输入类型枚举 pub enum ImageInput { Bytes(Vec), - Array(Array3), + Array(ArrayD), // 对应 numpy 数组 Path(PathBuf), Base64(String), DynamicImage(DynamicImage), } +impl From<&str> for ImageInput { + fn from(s: &str) -> Self { + if Path::new(s).exists() { + ImageInput::Path(s.into()) + } else { + ImageInput::Base64(s.to_string()) + } + } +} /// 模拟 Python 的 load_image_from_input #[allow(dead_code)] -pub fn load_image_from_input(input: ImageInput) -> Result { - match input { +pub fn load_image_from_input(img_input: ImageInput) -> Result { + match img_input { + // 2. 处理字节流 (Bytes) + ImageInput::Bytes(bytes) => { + image::load_from_memory(&bytes).context("Failed to load image from bytes") + } + // 1. 已经是 DynamicImage ImageInput::DynamicImage(img) => Ok(img), - _ => todo!("后续补充"), + // 5. 处理 ndarray (Numpy-like) + // 假设输入是 HWC 格式的 Array3 + ImageInput::Array(arr) => numpy_to_pil_image(arr.view()), + // 4. 处理 Base64 字符串 + ImageInput::Base64(b64_str) => base64_to_image(&b64_str), + // 3. 处理文件路径 (Path) + ImageInput::Path(path) => image::open(path).context("Failed to open image from path"), + } +} +fn base64_to_image(img_base64: &str) -> Result { + // 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64," + let clean_b64 = if let Some(pos) = img_base64.find(",") { + &img_base64[pos + 1..] + } else { + &img_base64 + }; + + let bytes = general_purpose::STANDARD + .decode(clean_b64.trim()) + .map_err(|e| anyhow!("Base64 decode error: {}", e))?; + + image::load_from_memory(&bytes).context("Failed to load image from decoded base64") +} +/// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image +fn numpy_to_pil_image(array: ArrayViewD) -> Result { + let shape = array.shape(); + let dim = shape.len(); + + // 1. 确保数据在内存中是连续的 (C order / Standard Layout) + // 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝 + let standard = array.as_standard_layout(); + let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset(); + + match dim { + // 对应 Python: len(array.shape) == 2 (灰度图 H, W) + 2 => { + let (h, w) = (shape[0], shape[1]); + ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) + .map(DynamicImage::ImageLuma8) + .ok_or_else(|| anyhow!("Failed to create Luma image from 2D array")) + } + + // 对应 Python: len(array.shape) == 3 (H, W, C) + 3 => { + let (h, w, c) = (shape[0], shape[1], shape[2]); + match c { + // 对应 Python: array.shape[2] == 1 (单通道 H, W, 1) + 1 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) + .map(DynamicImage::ImageLuma8), + + // 对应 Python: array.shape[2] == 3 (RGB H, W, 3) + 3 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) + .map(DynamicImage::ImageRgb8), + + // 对应 Python: array.shape[2] == 4 (RGBA H, W, 4) + 4 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) + .map(DynamicImage::ImageRgba8), + + _ => { + return Err(anyhow!("不支持的通道数: {}", c)); + } + } + .ok_or_else(|| anyhow!("转换彩色图失败")) + } + + _ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)), } }