265 lines
9.9 KiB
Rust
265 lines
9.9 KiB
Rust
use anyhow::{Context, Result, anyhow, bail};
|
||
use base64::{Engine as _, engine::general_purpose};
|
||
use image::{DynamicImage, GenericImageView, ImageBuffer, ImageFormat, Luma, Rgb, RgbImage, Rgba};
|
||
use std::fs;
|
||
use std::path::{Path, PathBuf};
|
||
use tract_onnx::prelude::tract_ndarray::{Array3, ArrayD, ArrayViewD};
|
||
/// Pixel layout used when converting between images and `u8` arrays.
///
/// Channel counts: [`ColorMode::L`] → 1, [`ColorMode::RGB`] → 3, [`ColorMode::RGBA`] → 4.
///
/// The type is `Copy`, so one mode value can be passed to several conversions.
// Variant names mirror PIL's mode strings ("RGB", "RGBA", "L"); renaming them would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorMode {
    /// 3-channel, 8-bit RGB.
    RGB,
    /// 4-channel, 8-bit RGBA.
    RGBA,
    /// 1-channel, 8-bit luminance (grayscale).
    L,
}
|
||
/// 定义支持的输入类型枚举
|
||
pub enum ImageInput {
|
||
Bytes(Vec<u8>),
|
||
Array(ArrayD<u8>), // 对应 numpy 数组
|
||
Path(PathBuf),
|
||
Base64(String),
|
||
DynamicImage(DynamicImage),
|
||
}
|
||
/// 模拟 Python 的 load_image_from_input
|
||
#[allow(dead_code)]
|
||
pub fn load_image_from_input(img_input: ImageInput) -> Result<DynamicImage> {
|
||
match img_input {
|
||
// 2. 处理字节流 (Bytes)
|
||
ImageInput::Bytes(bytes) => {
|
||
image::load_from_memory(&bytes).context("Failed to load image from bytes")
|
||
}
|
||
// 1. 已经是 DynamicImage
|
||
ImageInput::DynamicImage(i) => Ok(i),
|
||
// 5. 处理 ndarray (Numpy-like)
|
||
// 假设输入是 HWC 格式的 Array3<u8>
|
||
ImageInput::Array(a) => numpy_to_pil_image(a.view()),
|
||
// 4. 处理 Base64 字符串
|
||
ImageInput::Base64(b) => base64_to_image(&b),
|
||
// 3. 处理文件路径 (Path)
|
||
ImageInput::Path(p) => image::open(p).context("Failed to open image from path"),
|
||
}
|
||
}
|
||
fn base64_to_image(b64_str: &str) -> Result<DynamicImage> {
|
||
// 过滤掉可能存在的 base64 前缀,例如 "data:image/png;base64,"
|
||
let clean_b64 = if let Some(pos) = b64_str.find(",") {
|
||
&b64_str[pos + 1..]
|
||
} else {
|
||
&b64_str
|
||
};
|
||
|
||
let bytes = general_purpose::STANDARD
|
||
.decode(clean_b64.trim())
|
||
.map_err(|e| anyhow!("Base64 decode error: {}", e))?;
|
||
|
||
image::load_from_memory(&bytes).context("Failed to load image from decoded base64")
|
||
}
|
||
|
||
/// 读取图片文件并转换为 base64 编码字符串
|
||
/// 对应 Python 版 get_img_base64
|
||
pub fn get_img_base64<P: AsRef<Path>>(image_path: P) -> Result<String> {
|
||
// 1. 读取文件原始字节流
|
||
// 使用 AsRef<Path> 泛型可以让函数同时支持 String, &str, PathBuf 等类型
|
||
let image_data = fs::read(&image_path)
|
||
.with_context(|| format!("Failed to read image file: {:?}", image_path.as_ref()))?;
|
||
|
||
// 2. 进行 Base64 编码
|
||
// 使用 STANDARD 引擎对齐 Python 的 base64.b64encode
|
||
let b64_string = general_purpose::STANDARD.encode(image_data);
|
||
|
||
Ok(b64_string)
|
||
}
|
||
|
||
/// 封装数组转图像的逻辑,对齐 Python 版 _numpy_to_pil_image
|
||
fn numpy_to_pil_image(array: ArrayViewD<u8>) -> Result<DynamicImage> {
|
||
let shape = array.shape();
|
||
let dim = shape.len();
|
||
|
||
// 1. 确保数据在内存中是连续的 (C order / Standard Layout)
|
||
// 如果 arr 是经过切片或转置的,这一步会进行必要的内存拷贝
|
||
let standard = array.as_standard_layout();
|
||
let (raw_data, _offset) = standard.to_owned().into_raw_vec_and_offset();
|
||
|
||
match dim {
|
||
// 对应 Python: len(array.shape) == 2 (灰度图 H, W)
|
||
2 => {
|
||
let (h, w) = (shape[0], shape[1]);
|
||
ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||
.map(DynamicImage::ImageLuma8)
|
||
.ok_or_else(|| anyhow!("Failed to create Luma image from 2D array"))
|
||
}
|
||
|
||
// 对应 Python: len(array.shape) == 3 (H, W, C)
|
||
3 => {
|
||
let (h, w, c) = (shape[0], shape[1], shape[2]);
|
||
match c {
|
||
// 对应 Python: array.shape[2] == 1 (单通道 H, W, 1)
|
||
1 => ImageBuffer::<Luma<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||
.map(DynamicImage::ImageLuma8),
|
||
|
||
// 对应 Python: array.shape[2] == 3 (RGB H, W, 3)
|
||
3 => ImageBuffer::<Rgb<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||
.map(DynamicImage::ImageRgb8),
|
||
|
||
// 对应 Python: array.shape[2] == 4 (RGBA H, W, 4)
|
||
4 => ImageBuffer::<Rgba<u8>, _>::from_raw(w as u32, h as u32, raw_data)
|
||
.map(DynamicImage::ImageRgba8),
|
||
|
||
_ => {
|
||
return Err(anyhow!("不支持的通道数: {}", c));
|
||
}
|
||
}
|
||
.ok_or_else(|| anyhow!("转换彩色图失败"))
|
||
}
|
||
|
||
_ => Err(anyhow!("不支持的数组维度: {},仅支持 2D 或 3D", dim)),
|
||
}
|
||
}
|
||
|
||
/// 对应 Python 的 png_rgba_black_preprocess
|
||
/// 将带有透明通道的图片转换为白色背景的 RGB 图片
|
||
|
||
pub fn png_rgba_white_preprocess(img: &DynamicImage) -> DynamicImage {
|
||
// 1. 检查是否包含透明通道,如果没有,直接克隆并返回
|
||
if !img.color().has_alpha() {
|
||
return DynamicImage::ImageRgb8(img.to_rgb8());
|
||
}
|
||
|
||
let (width, height) = img.dimensions();
|
||
|
||
// 2. 创建一个新的 RGB 图像缓冲,默认填充为白色 (255, 255, 255)
|
||
let mut background = ImageBuffer::from_pixel(width, height, Rgb([255u8, 255u8, 255u8]));
|
||
|
||
// 3. 获取原图的 RGBA 视图
|
||
let rgba_img = img.to_rgba8();
|
||
|
||
// 4. 遍历像素并手动进行 Alpha 混合
|
||
// 对应 Python 的 image.paste(img, ..., mask=img)
|
||
// 使用 enumerate_pixels_mut 同时获取坐标和背景像素的可变引用,减少查找开销
|
||
for (x, y, bg_pixel) in background.enumerate_pixels_mut() {
|
||
// 安全性说明:x, y 源自 background 尺寸,与 rgba_img 一致,get_pixel 是安全的
|
||
let src_pixel = rgba_img.get_pixel(x, y);
|
||
let alpha_u8 = src_pixel[3];
|
||
|
||
match alpha_u8 {
|
||
// 情况 A:完全不透明,直接覆盖背景色
|
||
255 => {
|
||
bg_pixel.0 = [src_pixel[0], src_pixel[1], src_pixel[2]];
|
||
}
|
||
// 情况 B:完全透明,保持背景色(白色),无需操作
|
||
0 => {
|
||
continue;
|
||
}
|
||
// 情况 C:半透明,进行 Alpha 混合计算
|
||
_ => {
|
||
let alpha = alpha_u8 as f32 / 255.0;
|
||
let inv_alpha = 1.0 - alpha;
|
||
|
||
bg_pixel[0] = (src_pixel[0] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||
bg_pixel[1] = (src_pixel[1] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||
bg_pixel[2] = (src_pixel[2] as f32 * alpha + 255.0 * inv_alpha).round() as u8;
|
||
}
|
||
}
|
||
}
|
||
|
||
DynamicImage::ImageRgb8(background)
|
||
}
|
||
pub fn image_to_numpy(image: &DynamicImage, mode: ColorMode) -> Result<Array3<u8>> {
|
||
// 1. 模式转换 (对应 image.convert(target_mode)),此函数在时保留看后续优化是否需要替代image_to_ndarray
|
||
// Rust image 库通过 to_rgb8, to_luma8 等方法实现转换
|
||
let (width, height) = image.dimensions();
|
||
|
||
let (channels, raw) = match mode {
|
||
ColorMode::RGB => (3, image.to_rgb8().into_raw()),
|
||
ColorMode::L => (1, image.to_luma8().into_raw()),
|
||
ColorMode::RGBA => (4, image.to_rgba8().into_raw()),
|
||
};
|
||
|
||
Array3::from_shape_vec((height as usize, width as usize, channels), raw)
|
||
.map_err(|e| anyhow!("Failed to build ndarray: {}", e))
|
||
}
|
||
|
||
pub fn numpy_to_image(array: ArrayViewD<u8>, mode: ColorMode) -> Result<DynamicImage> {
|
||
let shape = array.shape();
|
||
// 1. 基础维度检查 (必须是 H, W, C 三维数组)
|
||
if shape.len() != 3 {
|
||
bail!("Expected a 3D array (H, W, C), but got {}D", shape.len());
|
||
}
|
||
|
||
let height = shape[0] as u32;
|
||
let width = shape[1] as u32;
|
||
let channels = shape[2];
|
||
// 2. 检查通道数是否与模式匹配
|
||
let expected_channels = match mode {
|
||
ColorMode::L => 1,
|
||
ColorMode::RGB => 3,
|
||
ColorMode::RGBA => 4,
|
||
};
|
||
if channels != expected_channels {
|
||
bail!(
|
||
"Mode {:?} expects {} channels, but array has {}",
|
||
mode,
|
||
expected_channels,
|
||
channels
|
||
);
|
||
}
|
||
// 确保数据连续性 (C-order)
|
||
let standard = array.as_standard_layout();
|
||
let (raw_data, _) = standard.to_owned().into_raw_vec_and_offset();
|
||
|
||
match mode {
|
||
ColorMode::L => ImageBuffer::<Luma<u8>, _>::from_raw(width, height, raw_data)
|
||
.map(DynamicImage::ImageLuma8),
|
||
ColorMode::RGB => ImageBuffer::<Rgb<u8>, _>::from_raw(width, height, raw_data)
|
||
.map(DynamicImage::ImageRgb8),
|
||
ColorMode::RGBA => ImageBuffer::<Rgba<u8>, _>::from_raw(width, height, raw_data)
|
||
.map(DynamicImage::ImageRgba8),
|
||
}
|
||
.ok_or_else(|| anyhow!("Failed to construct ImageBuffer. Buffer size might be incorrect."))
|
||
}
|
||
pub fn image_to_ndarray(img: &DynamicImage) -> Array3<u8> {
|
||
let (width, height) = img.dimensions();
|
||
|
||
// 1. 强制转为 RGB8 (丢弃 Alpha 通道,与 Python 的 target_mode='RGB' 对齐)
|
||
let rgb_img = img.to_rgb8();
|
||
|
||
// 2. 获取原始像素数据
|
||
let raw_data = rgb_img.into_raw();
|
||
|
||
// 3. 构造数组 (通道数改为 3)
|
||
Array3::from_shape_vec((height as usize, width as usize, 3), raw_data)
|
||
.expect("Failed to construct ndarray from image") // 建议显式报错,而不是返回全黑图
|
||
}
|
||
|
||
#[allow(dead_code)]
|
||
fn save_rust_result(result: &ImageBuffer<Luma<f32>, Vec<f32>>, filename: &str) {
|
||
let (width, height) = result.dimensions();
|
||
|
||
// 1. 寻找最值进行归一化
|
||
let mut max_val = f32::MIN;
|
||
let mut min_val = f32::MAX;
|
||
for p in result.pixels() {
|
||
if p.0[0] > max_val {
|
||
max_val = p.0[0];
|
||
}
|
||
if p.0[0] < min_val {
|
||
min_val = p.0[0];
|
||
}
|
||
}
|
||
|
||
// 2. 创建 8 位灰度图
|
||
let mut out_buf = ImageBuffer::new(width, height);
|
||
for y in 0..height {
|
||
for x in 0..width {
|
||
let val = result.get_pixel(x, y).0[0];
|
||
let normalized = if max_val > min_val {
|
||
((val - min_val) / (max_val - min_val) * 255.0) as u8
|
||
} else {
|
||
0u8
|
||
};
|
||
out_buf.put_pixel(x, y, Luma([normalized]));
|
||
}
|
||
}
|
||
|
||
// 3. 保存
|
||
DynamicImage::ImageLuma8(out_buf).save(filename).unwrap();
|
||
println!("Rust 结果热力图已保存至: {}", filename);
|
||
}
|