use anyhow::{Context, Result, anyhow, bail}; use base64::{Engine as _, engine::general_purpose}; use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba}; use std::fs; use std::path::{Path, PathBuf}; use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD}; #[derive(Debug)] pub enum ColorMode { RGB, RGBA, L, } /// 定义支持的输入类型枚举 pub enum ImageInput { Bytes(Vec), Array(ArrayD), // 对应 numpy 数组 Path(PathBuf), Base64(String), DynamicImage(DynamicImage), } /// 模拟 Python 的 load_image_from_input #[allow(dead_code)] pub fn load_image_from_input(img_input: ImageInput) -> Result { match img_input { // 2. 处理字节流 (Bytes) ImageInput::Bytes(bytes) => { image::load_from_memory(&bytes).context("Failed to load image from bytes") } // 1. 已经是 DynamicImage ImageInput::DynamicImage(i) => Ok(i), // 5. 处理 ndarray (Numpy-like) // 假设输入是 HWC 格式的 Array3 ImageInput::Array(a) => numpy_to_pil_image(a.view()), // 4. 处理 Base64 字符串 ImageInput::Base64(b) => base64_to_image(&b), // 3. 处理文件路径 (Path) ImageInput::Path(p) => image::open(p).context("Failed to open image from path"), } } fn base64_to_image(b64_str: &str) -> Result { // 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64," let clean_b64 = if let Some(pos) = b64_str.find(",") { &b64_str[pos + 1..] } else { &b64_str }; let bytes = general_purpose::STANDARD .decode(clean_b64.trim()) .map_err(|e| anyhow!("Base64 decode error: {}", e))?; image::load_from_memory(&bytes).context("Failed to load image from decoded base64") } /// 读取图片文件并转换为 base64 编码字符串 /// 对应 Python 版 get_img_base64 pub fn get_img_base64>(image_path: P) -> Result { // 1. 读取文件原始字节流 // 使用 AsRef 泛型可以让函数同时支持 String, &str, PathBuf 等类型 let image_data = fs::read(&image_path) .with_context(|| format!("Failed to read image file: {:?}", image_path.as_ref()))?; // 2. 进行 Base64 编码 // 使用 STANDARD 引擎对齐 Python 的 base64.b64encode let b64_string = general_purpose::STANDARD.encode(image_data); Ok(b64_string) } /// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image fn numpy_to_pil_image(array: ArrayViewD) -> Result { let shape = array.shape(); let dim = shape.len(); // 1. 确保数据在内存中是连续的 (C order / Standard Layout) // 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝 let standard = array.as_standard_layout(); let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset(); match dim { // 对应 Python: len(array.shape) == 2 (灰度图 H, W) 2 => { let (h, w) = (shape[0], shape[1]); ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) .map(DynamicImage::ImageLuma8) .ok_or_else(|| anyhow!("Failed to create Luma image from 2D array")) } // 对应 Python: len(array.shape) == 3 (H, W, C) 3 => { let (h, w, c) = (shape[0], shape[1], shape[2]); match c { // 对应 Python: array.shape[2] == 1 (单通道 H, W, 1) 1 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) .map(DynamicImage::ImageLuma8), // 对应 Python: array.shape[2] == 3 (RGB H, W, 3) 3 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) .map(DynamicImage::ImageRgb8), // 对应 Python: array.shape[2] == 4 (RGBA H, W, 4) 4 => ImageBuffer::, _>::from_raw(w as u32, h as u32, raw_data) .map(DynamicImage::ImageRgba8), _ => { return Err(anyhow!("不支持的通道数: {}", c)); } } .ok_or_else(|| anyhow!("转换彩色图失败")) } _ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)), } } /// 对应 Python 的 png_rgba_black_preprocess /// 将带有透明通道的图片转换为白色背景的 RGB 图片 pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage { // 1. 检查是否包含透明通道,如果没有,直接克隆并返回 if !img.color().has_alpha() { return DynamicImage::ImageRgb8(img.to_rgb8()); } let (width, height) = img.dimensions(); // 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255) let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8])); // 3. 获取原图的 RGBA 视图 let rgba_img = img.to_rgba8(); // 4. 遍历像素并手动进行 Alpha 混合 // 对应 Python 的 image.paste(img, ..., mask=img) // 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销 for (x, y, bg_pixel) in background.enumerate_pixels_mut() { // 安全性说明:x, y 源自 background 尺寸,与 rgba_img 一致,get_pixel 是安全的 let src_pixel = rgba_img.get_pixel(x, y); let alpha_u8 = src_pixel[3]; match alpha_u8 { // 情况 A:完全不透明,直接覆盖背景色 255 => { bg_pixel.0 = [src_pixel[0], src_pixel[1], src_pixel[2]]; } // 情况 B:完全透明,保持背景色(白色),无需操作 0 => { continue; } // 情况 C:半透明,进行 Alpha 混合计算 _ => { let alpha = alpha_u8 as f32 / 255.0; let inv_alpha = 1.0 - alpha; bg_pixel[0] = (src_pixel[0] as f32 * alpha + 255.0 * inv_alpha).round() as u8; bg_pixel[1] = (src_pixel[1] as f32 * alpha + 255.0 * inv_alpha).round() as u8; bg_pixel[2] = (src_pixel[2] as f32 * alpha + 255.0 * inv_alpha).round() as u8; } } } DynamicImage::ImageRgb8(background) } pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result> { // 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray // Rust image 库通过 to_rgb8, to_luma8 等方法实现转换 let (width, height) = image.dimensions(); let (channels, raw) = match mode { ColorMode::RGB => (3, image.to_rgb8().into_raw()), ColorMode::L => (1, image.to_luma8().into_raw()), ColorMode::RGBA => (4, image.to_rgba8().into_raw()), }; Array3::from_shape_vec((height as usize, width as usize, channels), raw) .map_err(|e| anyhow!("Failed to build ndarray: {}", e)) } pub fn numpy_to_image(array: ArrayViewD, mode: ColorMode) -> Result { let shape = array.shape(); // 1. 基础维度检查 (必须是 H, W, C 三维数组) if shape.len() != 3 { bail!("Expected a 3D array (H, W, C), but got {}D", shape.len()); } let height = shape[0] as u32; let width = shape[1] as u32; let channels = shape[2]; // 2. 检查通道数是否与模式匹配 let expected_channels = match mode { ColorMode::L => 1, ColorMode::RGB => 3, ColorMode::RGBA => 4, }; if channels != expected_channels { bail!( "Mode {:?} expects {} channels, but array has {}", mode, expected_channels, channels ); } // 确保数据连续性 (C-order) let standard = array.as_standard_layout(); let (raw_data, _) = standard.to_owned().into_raw_vec_and_offset(); match mode { ColorMode::L => ImageBuffer::, _>::from_raw(width, height, raw_data) .map(DynamicImage::ImageLuma8), ColorMode::RGB => ImageBuffer::, _>::from_raw(width, height, raw_data) .map(DynamicImage::ImageRgb8), ColorMode::RGBA => ImageBuffer::, _>::from_raw(width, height, raw_data) .map(DynamicImage::ImageRgba8), } .ok_or_else(|| anyhow!("Failed to construct ImageBuffer. Buffer size might be incorrect.")) } pub fn image_to_ndarray(img: &DynamicImage) -> Array3 { let (width, height) = img.dimensions(); // 1. 强制转为 RGB8 (丢弃 Alpha 通道,与 Python 的 target_mode='RGB' 对齐) let rgb_img = img.to_rgb8(); // 2. 获取原始像素数据 let raw_data = rgb_img.into_raw(); // 3. 构造数组 (通道数改为 3) Array3::from_shape_vec((height as usize, width as usize, 3), raw_data) .expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图 } #[allow(dead_code)] fn save_rust_result(result: &ImageBuffer, Vec>, filename: &str) { let (width, height) = result.dimensions(); // 1. 寻找最值进行归一化 let mut max_val = f32::MIN; let mut min_val = f32::MAX; for p in result.pixels() { if p.0[0] > max_val { max_val = p.0[0]; } if p.0[0] < min_val { min_val = p.0[0]; } } // 2. 创建 8 位灰度图 let mut out_buf = ImageBuffer::new(width, height); for y in 0..height { for x in 0..width { let val = result.get_pixel(x, y).0[0]; let normalized = if max_val > min_val { ((val - min_val) / (max_val - min_val) * 255.0) as u8 } else { 0u8 }; out_buf.put_pixel(x, y, Luma([normalized])); } } // 3. 保存 DynamicImage::ImageLuma8(out_buf).save(filename).unwrap(); println!("Rust 结果热力图已保存至: {}", filename); }