use axum::{ body::Body, extract::{Query, State}, http::{HeaderMap, Method, Request, StatusCode}, response::{IntoResponse, Response}, }; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use tokio_util::io::ReaderStream; use crate::models::Payload; use crate::router::MockRouter; /// 共享的应用状态,router 现在由 RwLock 保护以支持热重载 pub struct AppState { pub router: RwLock, } /// 提取请求的 Content-Type(去掉参数部分,如 boundary) fn extract_content_type(headers: &HeaderMap) -> Option { headers .get(axum::http::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .map(|s| s.split(';').next().unwrap_or(s).trim().to_lowercase()) } /// 根据 Content-Type 解析 Body(始终以请求的 Content-Type 为准) fn parse_body(content_type: Option<&str>, bytes: &[u8]) -> Payload { if bytes.is_empty() { return Payload::None; } match content_type { Some(ct) if ct.contains("application/json") => { serde_json::from_slice(bytes) .map(Payload::Json) .unwrap_or_else(|_| { // JSON 解析失败,降级为文本 Payload::Text(String::from_utf8_lossy(bytes).to_string()) }) } Some(ct) if ct.contains("xml") => { Payload::Xml(String::from_utf8_lossy(bytes).to_string()) } Some(ct) if ct.contains("form-urlencoded") => { Payload::Form(parse_urlencoded(bytes)) } Some(ct) if ct.contains("multipart/form-data") => { Payload::Multipart(extract_multipart_data(bytes)) } _ => { Payload::Text(String::from_utf8_lossy(bytes).to_string()) } } } /// 解析 urlencoded 格式 fn parse_urlencoded(bytes: &[u8]) -> HashMap { let body = String::from_utf8_lossy(bytes); let mut map = HashMap::new(); for pair in body.split('&') { if let Some((key, value)) = pair.split_once('=') { let decoded_key = urlencoding_decode(key); let decoded_value = urlencoding_decode(value); map.insert(decoded_key, decoded_value); } } map } /// URL 解码(简单实现) fn urlencoding_decode(s: &str) -> String { let mut result = String::new(); let mut chars = s.chars().peekable(); while let Some(c) = chars.next() { if c == '+' { result.push(' '); } else if c == '%' { let hex: String = chars.by_ref().take(2).collect(); if let Ok(byte) = u8::from_str_radix(&hex, 16) { result.push(byte as char); } else { result.push('%'); result.push_str(&hex); } } else { result.push(c); } } result } /// 从 multipart body 中提取键值对 fn extract_multipart_data(bytes: &[u8]) -> HashMap { let body = String::from_utf8_lossy(bytes); let mut map = HashMap::new(); // 分割 boundary let lines: Vec<&str> = body.lines().collect(); let mut current_name: Option = None; let mut current_value = String::new(); let mut in_value = false; for line in &lines { // 检测 Content-Disposition 行,提取 name if line.contains("Content-Disposition") && line.contains("name=") { // 保存上一个字段的值 if let Some(name) = current_name.take() { map.insert(name, current_value.trim().to_string()); current_value.clear(); } // 提取 name 属性 if let Some(start) = line.find("name=\"") { let start = start + 6; if let Some(end) = line[start..].find('"') { current_name = Some(line[start..start + end].to_string()); in_value = false; } } } else if line.starts_with("Content-Type") { // 跳过 Content-Type 行 continue; } else if line.is_empty() { // 空行后面是值 in_value = true; } else if in_value { // 收集值内容 if !current_value.is_empty() { current_value.push('\n'); } current_value.push_str(line); } } // 保存最后一个字段 if let Some(name) = current_name { map.insert(name, current_value.trim().to_string()); } map } /// 全局统一请求处理函数 pub async fn mock_handler( State(state): State>, // State 必须是第一个或靠前的参数 method: Method, headers: HeaderMap, Query(params): Query>, req: Request, // Request 必须是最后一个参数 ) -> impl IntoResponse { // 1. 提取 path 和 method let path = req.uri().path().to_string(); let method_str = method.as_str().to_string(); // 2. 提取请求的 Content-Type let req_content_type = extract_content_type(&headers); // 3. 读取请求 body let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await { Ok(bytes) => bytes, Err(_) => { return Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from("Read body error")) .unwrap(); } }; // 4. 根据【请求的 Content-Type】解析 body let parsed_body = parse_body(req_content_type.as_deref(), &body_bytes); // 5. 将 Axum HeaderMap 转换为简单的 HashMap let mut req_headers = HashMap::new(); for (name, value) in headers.iter() { if let Ok(v) = value.to_str() { req_headers.insert(name.as_str().to_string(), v.to_string()); } } // 6. 执行匹配逻辑:先获取读锁 (Read Lock) let maybe_rule = { let router = state.router.read().expect("Failed to acquire read lock"); router.match_rule(&method_str, &path, ¶ms, &req_headers, &parsed_body).cloned() // 此处使用 .cloned() 以便尽早释放读锁,避免阻塞热重载写锁 }; if let Some(rule) = maybe_rule { // 7. 处理模拟延迟 if let Some(ref settings) = rule.settings { if let Some(delay) = settings.delay_ms { tokio::time::sleep(std::time::Duration::from_millis(delay)).await; } } // 8. 构建响应 let status = StatusCode::from_u16(rule.response.status).unwrap_or(StatusCode::OK); let mut response_builder = Response::builder().status(status); if let Some(ref h) = rule.response.headers { for (k, v) in h { response_builder = response_builder.header(k, v); } } // 9. Smart Body 逻辑 if let Some(file_path) = rule.response.get_file_path() { match tokio::fs::File::open(file_path).await { Ok(file) => { let stream = ReaderStream::new(file); let body = Body::from_stream(stream); response_builder.body(body).unwrap() } Err(_) => Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(format!( "Mock Error: File not found at {}", file_path ))) .unwrap(), } } else { // 内联模式:直接返回字符串内容 response_builder .body(Body::from(rule.response.body.clone())) .unwrap() } } else { // 匹配失败返回 404 Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("No mock rule matched this request")) .unwrap() } }