Axum 穿透多层代理:一行中间件锁定真实 IP,伪造头秒拒

Photos provided by Unsplash OR Pexels

Axum 多层代理处理与安全增强实战

在现实世界的生产环境中,Web 请求很少只经过单一代理。典型的架构可能包含:客户端 → CDN → 负载均衡器 → 应用网关 → 你的 Axum 应用。本章将深入探讨如何安全、高效地处理这种多层代理场景。

一、多层代理的挑战与安全模型

1.1 多层代理链的典型结构

客户端 (203.0.113.195)

CDN/边缘节点 (198.51.100.1)

负载均衡器 (10.0.1.100)

应用网关 (10.0.2.50)

你的Axum应用 (10.0.3.10)

对应的 X-Forwarded-For 头部:

X-Forwarded-For: 203.0.113.195, 198.51.100.1, 10.0.1.100

1.2 安全威胁模型

在多代理环境中,我们需要防范以下威胁:

  1. 头部注入:攻击者通过第一个代理注入恶意头部
  2. 代理链欺骗:不可信代理篡改整个代理链
  3. IP 伪造:即使经过多层代理,最终代理仍可能被欺骗

二、完整的多层代理处理器实现

2.1 基础结构定义

创建 src/advanced.rs

use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::collections::HashSet;
use ipnetwork::IpNetwork;
use crate::config::TrustedProxiesConfig;
use thiserror::Error;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn, trace};

/// 代理处理错误类型
#[derive(Error, Debug)]
pub enum ProxyError {
    #[error("无效的 IP 地址格式:{0}")]
    InvalidIpFormat(String),

    #[error("代理链验证失败:{0}")]
    ChainValidationFailed(String),

    #[error("头部解析错误:{0}")]
    HeaderParseError(String),
}

/// 代理链验证模式
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum ValidationMode {
    /// 宽松模式:只要最后一个代理可信,就接受整个链
    Lenient,
    /// 严格模式:要求链中所有代理都可信
    Strict,
    /// 跳数验证:从右向左找到第一个不可信代理
    HopByHop,
}

/// RFC 7239 Forwarded 头部解析器
/// 格式:Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43
#[derive(Debug, Clone)]
pub struct ForwardedHeader {
    /// 客户端地址
    pub for_client: Option<IpAddr>,
    /// 代理标识
    pub by_proxy: Option<IpAddr>,
    /// 协议
    pub proto: Option<String>,
    /// 主机
    pub host: Option<String>,
}

/// 多层代理处理器配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiProxyConfig {
    /// 代理链验证模式
    pub validation_mode: ValidationMode,
    /// 是否启用 RFC 7239 Forwarded 头部支持
    pub enable_rfc7239: bool,
    /// 最大代理跳数限制
    pub max_proxy_hops: usize,
    /// 允许的私有网络 CIDR(用于内部代理验证)
    pub allowed_private_nets: Vec<IpNetwork>,
    /// 是否启用代理链连续性检查
    pub enable_chain_continuity_check: bool,
}

impl Default for MultiProxyConfig {
    fn default() -> Self {
        Self {
            validation_mode: ValidationMode::HopByHop,
            enable_rfc7239: true,
            max_proxy_hops: 10,
            allowed_private_nets: vec![
                IpNetwork::from_str("10.0.0.0/8").unwrap(),
                IpNetwork::from_str("172.16.0.0/12").unwrap(),
                IpNetwork::from_str("192.168.0.0/16").unwrap(),
                IpNetwork::from_str("fd00::/8").unwrap(),
            ],
            enable_chain_continuity_check: true,
        }
    }
}

/// 代理链分析结果
#[derive(Debug, Clone)]
pub struct ProxyChainAnalysis {
    /// 可信客户端 IP(经过验证)
    pub trusted_client_ip: IpAddr,
    /// 完整的代理链(从客户端到当前代理)
    pub full_proxy_chain: Vec<IpAddr>,
    /// 可信代理链部分
    pub trusted_proxy_chain: Vec<IpAddr>,
    /// 验证通过的跳数
    pub validated_hops: usize,
    /// 是否通过完整性检查
    pub chain_integrity: bool,
    /// 使用的验证模式
    pub validation_mode_used: ValidationMode,
    /// 可能的安全警告
    pub security_warnings: Vec<String>,
}

/// 多层代理处理器核心实现
pub struct MultiProxyProcessor {
    /// 基础可信代理配置
    base_config: TrustedProxiesConfig,
    /// 高级配置
    advanced_config: MultiProxyConfig,
    /// 已知的可信代理集合(缓存,加速查找)
    trusted_ips_cache: HashSet<IpAddr>,
    /// 已知的私有网络集合
    private_nets_cache: Vec<IpNetwork>,
}

impl MultiProxyProcessor {
    /// 创建新的处理器
    pub fn new(
        base_config: TrustedProxiesConfig,
        advanced_config: Option<MultiProxyConfig>,
    ) -> Self {
        let config = advanced_config.unwrap_or_default();

        // 构建 IP 缓存以提高性能
        let mut trusted_ips_cache = HashSet::new();
        for proxy in &base_config.proxies {
            match proxy {
                TrustedProxy::Single(ip) => {
                    trusted_ips_cache.insert(*ip);
                }
                TrustedProxy::Cidr(network) => {
                    // 对于大型网络,我们只缓存网络本身,运行时检查
                    // 这里我们缓存网络的前缀
                    if network.prefix() >= 24 {
                        // 对于小网络(/24 及以上),我们可以缓存所有 IP
                        // 但为了简单,这里只缓存网络表示
                    }
                }
            }
        }

        Self {
            base_config,
            advanced_config: config,
            trusted_ips_cache,
            private_nets_cache: config.allowed_private_nets.clone(),
        }
    }

    /// 主处理函数:从请求中提取并验证客户端 IP
    pub fn process_request(
        &self,
        peer_addr: &SocketAddr,
        headers: &axum::http::HeaderMap,
    ) -> Result<ProxyChainAnalysis, ProxyError> {
        debug!("开始处理代理请求,对端地址:{}", peer_addr);

        // 1. 检查是否来自可信代理
        if !self.base_config.is_trusted(peer_addr) {
            warn!("请求来自不可信代理:{}", peer_addr);
            return self.handle_untrusted_source(peer_addr);
        }

        // 2. 尝试多种头部解析策略
        let analysis = self.analyze_proxy_chain(peer_addr.ip(), headers)?;

        // 3. 执行安全性检查
        self.perform_security_checks(&analysis)?;

        Ok(analysis)
    }

    /// 分析代理链的核心方法
    fn analyze_proxy_chain(
        &self,
        current_proxy_ip: IpAddr,
        headers: &axum::http::HeaderMap,
    ) -> Result<ProxyChainAnalysis, ProxyError> {
        // 收集所有可用的代理链信息
        let mut proxy_chains = Vec::new();
        let mut security_warnings = Vec::new();

        // 策略 1: 优先使用 RFC 7239 Forwarded 头部
        if self.advanced_config.enable_rfc7239 {
            if let Some(forwarded) = Self::parse_forwarded_header(headers) {
                if let Some(rfc_chain) = Self::extract_chain_from_forwarded(&forwarded) {
                    proxy_chains.push(("rfc7239", rfc_chain));
                }
            }
        }

        // 策略 2: 使用传统的 X-Forwarded-For 头部
        if let Some(xff_chain) = Self::parse_x_forwarded_for(headers) {
            proxy_chains.push(("xff", xff_chain));
        }

        // 策略 3: 使用 X-Real-IP 作为备用
        if let Some(real_ip) = Self::parse_x_real_ip(headers) {
            proxy_chains.push(("x-real-ip", vec![real_ip]));
        }

        // 如果没有找到任何代理信息
        if proxy_chains.is_empty() {
            debug!("未找到代理头部,使用直接连接 IP");
            return Ok(ProxyChainAnalysis {
                trusted_client_ip: current_proxy_ip,
                full_proxy_chain: vec![current_proxy_ip],
                trusted_proxy_chain: vec![current_proxy_ip],
                validated_hops: 0,
                chain_integrity: true,
                validation_mode_used: ValidationMode::Lenient,
                security_warnings: vec!["未使用代理头部,可能是直接连接".to_string()],
            });
        }

        // 选择最可靠的代理链(优先使用 RFC 7239)
        let (source, mut full_chain) = proxy_chains
            .into_iter()
            .max_by_key(|(source, _)| match *source {
                "rfc7239" => 3,
                "xff" => 2,
                "x-real-ip" => 1,
                _ => 0,
            })
            .unwrap();

        debug!("使用代理链来源:{},链:{:?}", source, full_chain);

        // 将当前代理 IP 添加到链的末尾
        full_chain.push(current_proxy_ip);

        // 根据配置的验证模式处理代理链
        let (trusted_client_ip, trusted_proxy_chain, validated_hops) = match self.advanced_config.validation_mode {
            ValidationMode::Lenient => self.validate_lenient(&full_chain),
            ValidationMode::Strict => self.validate_strict(&full_chain)?,
            ValidationMode::HopByHop => self.validate_hop_by_hop(&full_chain)?,
        };

        // 检查链连续性
        let chain_integrity = if self.advanced_config.enable_chain_continuity_check {
            self.check_chain_continuity(&full_chain, &trusted_proxy_chain)
        } else {
            true
        };

        // 收集安全警告
        if !chain_integrity {
            security_warnings.push("代理链连续性检查失败".to_string());
        }

        if full_chain.len() > self.advanced_config.max_proxy_hops {
            security_warnings.push(format!(
                "代理链长度超过限制:{}/{}",
                full_chain.len(),
                self.advanced_config.max_proxy_hops
            ));
        }

        Ok(ProxyChainAnalysis {
            trusted_client_ip,
            full_proxy_chain: full_chain,
            trusted_proxy_chain,
            validated_hops,
            chain_integrity,
            validation_mode_used: self.advanced_config.validation_mode,
            security_warnings,
        })
    }

    /// 宽松验证模式:只要最后一个代理可信,就接受整个链
    fn validate_lenient(&self, chain: &[IpAddr]) -> (IpAddr, Vec<IpAddr>, usize) {
        if chain.is_empty() {
            return (IpAddr::from([0, 0, 0, 0]), vec![], 0);
        }

        // 取链中的第一个 IP 作为客户端 IP
        let client_ip = chain[0];

        // 验证最后一个代理是否可信
        let last_proxy = chain.last().unwrap();
        let is_last_trusted = self.is_ip_trusted(last_proxy);

        if is_last_trusted {
            (
                client_ip,
                chain.iter().copied().collect(),
                chain.len(),
            )
        } else {
            // 如果最后一个代理不可信,使用最后一个代理的 IP
            (
                *last_proxy,
                vec![*last_proxy],
                0,
            )
        }
    }

    /// 严格验证模式:要求链中所有代理都可信
    fn validate_strict(&self, chain: &[IpAddr]) -> Result<(IpAddr, Vec<IpAddr>, usize), ProxyError> {
        if chain.is_empty() {
            return Ok((IpAddr::from([0, 0, 0, 0]), vec![], 0));
        }

        // 检查每个代理是否都可信
        for (i, ip) in chain.iter().enumerate() {
            if !self.is_ip_trusted(ip) {
                return Err(ProxyError::ChainValidationFailed(format!(
                    "链中第{}个代理 ({}) 不可信",
                    i + 1, ip
                )));
            }
        }

        Ok((
            chain[0],  // 第一个 IP 是客户端
            chain.to_vec(),
            chain.len(),
        ))
    }

    /// 跳数验证模式:从右向左找到第一个不可信代理
    fn validate_hop_by_hop(&self, chain: &[IpAddr]) -> Result<(IpAddr, Vec<IpAddr>, usize), ProxyError> {
        if chain.is_empty() {
            return Ok((IpAddr::from([0, 0, 0, 0]), vec![], 0));
        }

        let mut trusted_chain = Vec::new();
        let mut validated_hops = 0;

        // 从右向左遍历(从离我们最近的代理开始)
        for ip in chain.iter().rev() {
            if self.is_ip_trusted(ip) {
                trusted_chain.insert(0, *ip);
                validated_hops += 1;
            } else {
                // 找到第一个不可信代理,停止遍历
                break;
            }
        }

        if trusted_chain.is_empty() {
            // 没有可信代理,使用链的最后一个 IP
            let last_ip = *chain.last().unwrap();
            Ok((last_ip, vec![last_ip], 0))
        } else {
            // 客户端 IP 是可信链的第一个 IP 之前的那个 IP
            // 或者如果整个链都可信,就是原始链的第一个 IP
            let client_ip_index = chain.len().saturating_sub(trusted_chain.len());
            let client_ip = if client_ip_index > 0 {
                chain[client_ip_index - 1]
            } else {
                chain[0]
            };

            Ok((client_ip, trusted_chain, validated_hops))
        }
    }

    /// 检查代理链的连续性
    /// 验证规则:从客户端到服务器的路径应该是连续的
    fn check_chain_continuity(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr]) -> bool {
        if full_chain.len() <= 1 || trusted_chain.is_empty() {
            return true;
        }

        // 验证可信链是否确实是完整链的尾部连续部分
        let expected_tail = &full_chain[full_chain.len() - trusted_chain.len()..];
        expected_tail == trusted_chain
    }

    /// 处理不可信源的请求
    fn handle_untrusted_source(&self, peer_addr: &SocketAddr) -> Result<ProxyChainAnalysis, ProxyError> {
        let ip = peer_addr.ip();

        // 检查是否是私有地址(可能是内部服务调用)
        let is_private = self.private_nets_cache
            .iter()
            .any(|network| network.contains(ip));

        let warnings = if is_private {
            vec!["来自内部网络但未配置为可信代理".to_string()]
        } else {
            vec!["来自不可信的公网地址".to_string()]
        };

        Ok(ProxyChainAnalysis {
            trusted_client_ip: ip,
            full_proxy_chain: vec![ip],
            trusted_proxy_chain: vec![ip],
            validated_hops: 0,
            chain_integrity: true,
            validation_mode_used: ValidationMode::Lenient,
            security_warnings: warnings,
        })
    }

    /// 执行安全性检查
    fn perform_security_checks(&self, analysis: &ProxyChainAnalysis) -> Result<(), ProxyError> {
        // 检查代理链长度
        if analysis.full_proxy_chain.len() > self.advanced_config.max_proxy_hops {
            return Err(ProxyError::ChainValidationFailed(format!(
                "代理链过长:{} > {}",
                analysis.full_proxy_chain.len(),
                self.advanced_config.max_proxy_hops
            )));
        }

        // 检查客户端 IP 是否有效
        if analysis.trusted_client_ip.is_unspecified() {
            return Err(ProxyError::ChainValidationFailed(
                "客户端 IP 是未指定地址 (0.0.0.0 或::)".to_string()
            ));
        }

        // 检查是否回环地址(可能是配置错误或攻击)
        if analysis.trusted_client_ip.is_loopback() && analysis.validated_hops > 1 {
            warn!("客户端 IP 是回环地址,但经过了多层代理:{}", analysis.trusted_client_ip);
        }

        // 检查多播地址
        if analysis.trusted_client_ip.is_multicast() {
            return Err(ProxyError::ChainValidationFailed(
                "客户端 IP 是多播地址".to_string()
            ));
        }

        Ok(())
    }

    /// 检查单个 IP 是否可信
    fn is_ip_trusted(&self, ip: &IpAddr) -> bool {
        // 首先检查缓存
        if self.trusted_ips_cache.contains(ip) {
            return true;
        }

        // 然后检查 CIDR 范围
        let dummy_port = 0;
        let addr = SocketAddr::new(*ip, dummy_port);
        self.base_config.is_trusted(&addr)
    }

    /// 解析 RFC 7239 Forwarded 头部
    fn parse_forwarded_header(headers: &axum::http::HeaderMap) -> Option<ForwardedHeader> {
        let forwarded_value = headers.get("forwarded")?.to_str().ok()?;

        // Forwarded 头部可以有多个值,用逗号分隔
        // 我们取第一个
        let first_part = forwarded_value.split(',').next()?.trim();

        let mut result = ForwardedHeader {
            for_client: None,
            by_proxy: None,
            proto: None,
            host: None,
        };

        // 解析键值对,如:for=192.0.2.60;proto=http;by=203.0.113.43
        for part in first_part.split(';') {
            let part = part.trim();
            if let Some((key, value)) = part.split_once('=') {
                match key.trim().to_lowercase().as_str() {
                    "for" => {
                        // 去掉可能的引号和端口号
                        let clean_value = value.trim_matches('"');
                        if let Ok(ip) = clean_value.split(':').next().unwrap_or(clean_value).parse() {
                            result.for_client = Some(ip);
                        }
                    }
                    "by" => {
                        let clean_value = value.trim_matches('"');
                        if let Ok(ip) = clean_value.split(':').next().unwrap_or(clean_value).parse() {
                            result.by_proxy = Some(ip);
                        }
                    }
                    "proto" => {
                        result.proto = Some(value.trim_matches('"').to_string());
                    }
                    "host" => {
                        result.host = Some(value.trim_matches('"').to_string());
                    }
                    _ => {}
                }
            }
        }

        Some(result)
    }

    /// 从 Forwarded 头部提取代理链
    fn extract_chain_from_forwarded(forwarded: &ForwardedHeader) -> Option<Vec<IpAddr>> {
        let mut chain = Vec::new();

        // 如果有 for 字段,它是客户端 IP
        if let Some(client_ip) = &forwarded.for_client {
            chain.push(*client_ip);
        }

        // 如果有 by 字段,它可能是代理 IP
        // 注意:在 RFC 7239 中,多个代理会有多个 Forwarded 头部值
        // 这里简化处理,只取一个

        if chain.is_empty() {
            None
        } else {
            Some(chain)
        }
    }

    /// 解析 X-Forwarded-For 头部
    fn parse_x_forwarded_for(headers: &axum::http::HeaderMap) -> Option<Vec<IpAddr>> {
        let xff_value = headers.get("x-forwarded-for")?.to_str().ok()?;

        let ips: Vec<IpAddr> = xff_value
            .split(',')
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .filter_map(|s| {
                // 可能包含端口号,只取 IP 部分
                let ip_part = s.split(':').next().unwrap_or(s);
                ip_part.parse().ok()
            })
            .collect();

        if ips.is_empty() {
            None
        } else {
            Some(ips)
        }
    }

    /// 解析 X-Real-IP 头部
    fn parse_x_real_ip(headers: &axum::http::HeaderMap) -> Option<IpAddr> {
        let value = headers.get("x-real-ip")?.to_str().ok()?;
        value.parse().ok()
    }
}

// 单元测试
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    fn create_test_processor() -> MultiProxyProcessor {
        let base_config = TrustedProxiesConfig::from_strs(&[
            "10.0.0.0/8",
            "172.16.0.0/12",
            "192.168.0.0/16",
            "127.0.0.1",
        ]).unwrap();

        MultiProxyProcessor::new(base_config, None)
    }

    #[test]
    fn test_parse_x_forwarded_for() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "203.0.113.195, 198.51.100.1, 10.0.1.100".parse().unwrap());

        let result = MultiProxyProcessor::parse_x_forwarded_for(&headers).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap());
    }

    #[test]
    fn test_parse_x_forwarded_for_with_ports() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "203.0.113.195:1234, 198.51.100.1:80".parse().unwrap());

        let result = MultiProxyProcessor::parse_x_forwarded_for(&headers).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap());
    }

    #[test]
    fn test_parse_forwarded_header() {
        let mut headers = HeaderMap::new();
        headers.insert("Forwarded", r#"for=192.0.2.60;proto=http;by=203.0.113.43"#.parse().unwrap());

        let result = MultiProxyProcessor::parse_forwarded_header(&headers).unwrap();
        assert_eq!(result.for_client, Some(IpAddr::from_str("192.0.2.60").unwrap()));
        assert_eq!(result.by_proxy, Some(IpAddr::from_str("203.0.113.43").unwrap()));
        assert_eq!(result.proto, Some("http".to_string()));
    }

    #[test]
    fn test_validate_hop_by_hop() {
        let processor = create_test_processor();

        // 测试链:客户端 (公网) -> 代理 1(私有) -> 代理 2(私有)
        let chain = vec![
            IpAddr::from_str("8.8.8.8").unwrap(),      // 客户端(不可信)
            IpAddr::from_str("10.0.1.100").unwrap(),   // 代理 1(可信)
            IpAddr::from_str("192.168.1.50").unwrap(), // 代理 2(可信)
        ];

        let result = processor.validate_hop_by_hop(&chain).unwrap();
        assert_eq!(result.0, IpAddr::from_str("8.8.8.8").unwrap()); // 客户端 IP
        assert_eq!(result.2, 2); // 验证了 2 跳
    }

    #[test]
    fn test_chain_continuity_check() {
        let processor = create_test_processor();

        // 完整链
        let full_chain = vec![
            IpAddr::from_str("8.8.8.8").unwrap(),
            IpAddr::from_str("10.0.1.100").unwrap(),
            IpAddr::from_str("192.168.1.50").unwrap(),
        ];

        // 可信链应该是完整链的尾部连续部分
        let trusted_chain = vec![
            IpAddr::from_str("10.0.1.100").unwrap(),
            IpAddr::from_str("192.168.1.50").unwrap(),
        ];

        assert!(processor.check_chain_continuity(&full_chain, &trusted_chain));

        // 不连续的情况应该失败
        let bad_trusted_chain = vec![
            IpAddr::from_str("192.168.1.50").unwrap(),
        ];
        assert!(!processor.check_chain_continuity(&full_chain, &bad_trusted_chain));
    }

    #[test]
    fn test_security_checks() {
        let processor = create_test_processor();

        // 创建测试分析结果
        let analysis = ProxyChainAnalysis {
            trusted_client_ip: IpAddr::from([0, 0, 0, 0]), // 无效地址
            full_proxy_chain: Vec::new(),
            trusted_proxy_chain: Vec::new(),
            validated_hops: 0,
            chain_integrity: true,
            validation_mode_used: ValidationMode::Lenient,
            security_warnings: Vec::new(),
        };

        // 应该失败,因为客户端 IP 是未指定地址
        assert!(processor.perform_security_checks(&analysis).is_err());
    }
}

2.2 与中间件集成

更新 src/middleware.rs 以使用新的高级处理器:

// 在 middleware.rs 中添加
use crate::advanced::{MultiProxyProcessor, MultiProxyConfig};

// 更新 TrustedProxiesMiddleware 的 call 方法
fn call(&mut self, mut req: Request) -> Self::Future {
    let span = info_span!(
        "trusted_proxy_check",
        peer_addr = ?req.extensions().get::<SocketAddr>().map(|a| a.to_string())
    );

    let _guard = span.enter();

    // 使用高级处理器
    let processor = MultiProxyProcessor::new(
        self.config.clone(),
        Some(MultiProxyConfig::default()),
    );

    let peer_addr = req.extensions().get::<SocketAddr>().copied();
    let headers = req.headers();

    match peer_addr {
        Some(addr) => {
            match processor.process_request(&addr, headers) {
                Ok(analysis) => {
                    tracing::debug!(
                        "代理链分析完成:client_ip={}, hops={}, integrity={}",
                        analysis.trusted_client_ip,
                        analysis.validated_hops,
                        analysis.chain_integrity
                    );

                    // 创建增强的客户端信息
                    let client_info = EnhancedClientInfo {
                        real_ip: analysis.trusted_client_ip,
                        forwarded_host: headers
                            .get("x-forwarded-host")
                            .and_then(|h| h.to_str().ok())
                            .map(String::from),
                        forwarded_proto: headers
                            .get("x-forwarded-proto")
                            .and_then(|h| h.to_str().ok())
                            .map(String::from),
                        is_from_trusted_proxy: analysis.validated_hops > 0,
                        proxy_ip: Some(addr.ip()),
                        proxy_chain_analysis: Some(analysis),
                    };

                    req.extensions_mut().insert(client_info);

                    // 记录安全警告
                    if !client_info.proxy_chain_analysis.as_ref().unwrap().security_warnings.is_empty() {
                        tracing::warn!(
                            "代理链安全警告:{:?}",
                            client_info.proxy_chain_analysis.as_ref().unwrap().security_warnings
                        );
                    }
                }
                Err(err) => {
                    tracing::warn!("代理链验证失败:{}", err);

                    // 验证失败时回退到基本处理
                    let client_info = if self.config.is_trusted(&addr) {
                        extract_client_info_from_headers(&req)
                    } else {
                        ClientInfo::direct(addr)
                    };

                    req.extensions_mut().insert(client_info);
                }
            }
        }
        None => {
            tracing::warn!("无法获取对端地址");
            let client_info = ClientInfo {
                real_ip: std::net::Ipv4Addr::UNSPECIFIED.into(),
                forwarded_host: None,
                forwarded_proto: None,
                is_from_trusted_proxy: false,
                proxy_ip: None,
            };
            req.extensions_mut().insert(client_info);
        }
    }

    self.inner.call(req)
}

// 增强的客户端信息结构
#[derive(Debug, Clone)]
pub struct EnhancedClientInfo {
    pub real_ip: std::net::IpAddr,
    pub forwarded_host: Option<String>,
    pub forwarded_proto: Option<String>,
    pub is_from_trusted_proxy: bool,
    pub proxy_ip: Option<std::net::IpAddr>,
    pub proxy_chain_analysis: Option<crate::advanced::ProxyChainAnalysis>,
}

三、性能优化与缓存策略

3.1 IP 地址缓存优化

// 在 advanced.rs 中添加
use std::sync::Arc;
use lru_cache::LruCache;
use std::time::{Duration, Instant};

/// IP 验证结果缓存项
#[derive(Debug, Clone)]
struct IpValidationCacheEntry {
    /// 验证结果
    is_trusted: bool,
    /// 缓存时间戳
    cached_at: Instant,
    /// 过期时间
    expires_in: Duration,
}

/// 高性能的 IP 验证缓存
pub struct IpValidationCache {
    /// LRU 缓存
    cache: LruCache<IpAddr, IpValidationCacheEntry>,
    /// 缓存命中统计
    stats: CacheStats,
    /// 默认缓存过期时间
    default_ttl: Duration,
}

#[derive(Debug, Default)]
struct CacheStats {
    hits: u64,
    misses: u64,
    evictions: u64,
}

impl IpValidationCache {
    pub fn new(capacity: usize, default_ttl: Duration) -> Self {
        Self {
            cache: LruCache::new(capacity),
            stats: CacheStats::default(),
            default_ttl,
        }
    }

    /// 检查 IP 是否可信(带缓存)
    pub fn is_trusted(
        &mut self,
        ip: &IpAddr,
        validator: impl FnOnce(&IpAddr) -> bool,
    ) -> bool {
        let now = Instant::now();

        // 检查缓存
        if let Some(entry) = self.cache.get_mut(ip) {
            if now.duration_since(entry.cached_at) < entry.expires_in {
                self.stats.hits += 1;
                return entry.is_trusted;
            } else {
                // 缓存过期
                self.cache.remove(ip);
            }
        }

        self.stats.misses += 1;

        // 调用验证函数
        let is_trusted = validator(ip);

        // 更新缓存
        let entry = IpValidationCacheEntry {
            is_trusted,
            cached_at: now,
            expires_in: self.default_ttl,
        };

        if self.cache.len() >= self.cache.capacity() {
            self.stats.evictions += 1;
        }

        self.cache.insert(*ip, entry);
        is_trusted
    }

    /// 获取缓存统计
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// 清除过期缓存项
    pub fn cleanup_expired(&mut self) -> usize {
        let now = Instant::now();
        let mut expired = 0;

        let keys: Vec<IpAddr> = self.cache.iter()
            .filter(|(_, entry)| now.duration_since(entry.cached_at) >= entry.expires_in)
            .map(|(ip, _)| *ip)
            .collect();

        for ip in keys {
            self.cache.remove(&ip);
            expired += 1;
        }

        expired
    }
}

// 更新 MultiProxyProcessor 以使用缓存
pub struct OptimizedMultiProxyProcessor {
    base_config: Arc<TrustedProxiesConfig>,
    advanced_config: MultiProxyConfig,
    ip_cache: std::sync::Mutex<IpValidationCache>,
    cidr_matcher: CidrMatcher,
}

impl OptimizedMultiProxyProcessor {
    pub fn new(
        base_config: TrustedProxiesConfig,
        advanced_config: Option<MultiProxyConfig>,
    ) -> Self {
        let config = advanced_config.unwrap_or_default();

        // 预编译 CIDR 匹配器
        let cidr_matcher = CidrMatcher::new(&base_config);

        Self {
            base_config: Arc::new(base_config),
            advanced_config: config,
            ip_cache: std::sync::Mutex::new(
                IpValidationCache::new(10000, Duration::from_secs(300))
            ),
            cidr_matcher,
        }
    }

    /// 优化的 IP 可信检查
    fn is_ip_trusted_optimized(&self, ip: &IpAddr) -> bool {
        let mut cache = self.ip_cache.lock().unwrap();

        cache.is_trusted(ip, |ip| {
            // 快速路径:检查单个 IP 缓存
            // 慢速路径:检查 CIDR 范围
            self.cidr_matcher.contains(ip)
        })
    }
}

/// 优化的 CIDR 匹配器
struct CidrMatcher {
    ipv4_networks: Vec<IpNetwork>,
    ipv6_networks: Vec<IpNetwork>,
    single_ips: HashSet<IpAddr>,
}

impl CidrMatcher {
    fn new(config: &TrustedProxiesConfig) -> Self {
        let mut ipv4_networks = Vec::new();
        let mut ipv6_networks = Vec::new();
        let mut single_ips = HashSet::new();

        for proxy in &config.proxies {
            match proxy {
                TrustedProxy::Single(ip) => {
                    single_ips.insert(*ip);
                }
                TrustedProxy::Cidr(network) => {
                    match network {
                        IpNetwork::V4(_) => ipv4_networks.push(*network),
                        IpNetwork::V6(_) => ipv6_networks.push(*network),
                    }
                }
            }
        }

        // 按前缀长度排序,更具体的网络优先
        ipv4_networks.sort_by(|a, b| b.prefix().cmp(&a.prefix()));
        ipv6_networks.sort_by(|a, b| b.prefix().cmp(&a.prefix()));

        Self {
            ipv4_networks,
            ipv6_networks,
            single_ips,
        }
    }

    fn contains(&self, ip: &IpAddr) -> bool {
        // 首先检查单个 IP
        if self.single_ips.contains(ip) {
            return true;
        }

        // 然后检查 CIDR 范围
        match ip {
            IpAddr::V4(ipv4) => {
                for network in &self.ipv4_networks {
                    if let IpNetwork::V4(v4_net) = network {
                        if v4_net.contains(*ipv4) {
                            return true;
                        }
                    }
                }
            }
            IpAddr::V6(ipv6) => {
                for network in &self.ipv6_networks {
                    if let IpNetwork::V6(v6_net) = network {
                        if v6_net.contains(*ipv6) {
                            return true;
                        }
                    }
                }
            }
        }

        false
    }
}

四、监控与可观测性

4.1 代理链监控指标

// 在 advanced.rs 中添加
use metrics::{counter, histogram, gauge};

/// 代理处理监控指标
pub struct ProxyMetrics {
    /// 处理的请求总数
    pub total_requests: counter::Counter,
    /// 来自可信代理的请求数
    pub trusted_proxy_requests: counter::Counter,
    /// 来自不可信源的请求数
    pub untrusted_requests: counter::Counter,
    /// 验证失败的请求数
    pub validation_failed: counter::Counter,
    /// 代理链长度分布
    pub chain_length: histogram::Histogram,
    /// 验证耗时
    pub validation_duration: histogram::Histogram,
    /// 缓存命中率
    pub cache_hit_ratio: gauge::Gauge,
}

impl ProxyMetrics {
    pub fn new() -> Self {
        Self {
            total_requests: counter!("proxy.total_requests", "处理的请求总数"),
            trusted_proxy_requests: counter!("proxy.trusted_requests", "来自可信代理的请求"),
            untrusted_requests: counter!("proxy.untrusted_requests", "来自不可信源的请求"),
            validation_failed: counter!("proxy.validation_failed", "验证失败的请求"),
            chain_length: histogram!("proxy.chain_length", "代理链长度分布"),
            validation_duration: histogram!("proxy.validation_duration_ms", "验证耗时 (ms)"),
            cache_hit_ratio: gauge!("proxy.cache_hit_ratio", "缓存命中率"),
        }
    }

    pub fn record_request(&self, analysis: &ProxyChainAnalysis, duration_ms: f64) {
        self.total_requests.increment(1);
        self.chain_length.record(analysis.full_proxy_chain.len() as f64);
        self.validation_duration.record(duration_ms);

        if analysis.validated_hops > 0 {
            self.trusted_proxy_requests.increment(1);
        } else {
            self.untrusted_requests.increment(1);
        }

        if !analysis.security_warnings.is_empty() {
            self.validation_failed.increment(1);
        }
    }
}

// 在 MultiProxyProcessor 中添加监控
pub struct MonitoredMultiProxyProcessor {
    processor: MultiProxyProcessor,
    metrics: ProxyMetrics,
}

impl MonitoredMultiProxyProcessor {
    pub fn process_request_with_metrics(
        &self,
        peer_addr: &SocketAddr,
        headers: &axum::http::HeaderMap,
    ) -> Result<ProxyChainAnalysis, ProxyError> {
        let start = std::time::Instant::now();

        let result = self.processor.process_request(peer_addr, headers);

        let duration_ms = start.elapsed().as_secs_f64() * 1000.0;

        match &result {
            Ok(analysis) => {
                self.metrics.record_request(analysis, duration_ms);
            }
            Err(_) => {
                self.metrics.validation_failed.increment(1);
                self.metrics.total_requests.increment(1);
            }
        }

        result
    }
}

五、配置文件示例

创建 config/proxy.toml

[proxy]
# 验证模式:lenient, strict, hop_by_hop
validation_mode = "hop_by_hop"

# 是否启用 RFC 7239 Forwarded 头部
enable_rfc7239 = true

# 最大代理跳数
max_proxy_hops = 10

# 是否启用链连续性检查
enable_chain_continuity_check = true

# 缓存配置
[cache]
capacity = 10000
ttl_seconds = 300
cleanup_interval_seconds = 60

# 监控配置
[monitoring]
enable_metrics = true
log_level = "info"
log_failed_validations = true

# 可信代理列表
trusted_proxies = [
    "127.0.0.1",
    "::1",
    "10.0.0.0/8",
    "172.16.0.0/12",
    "192.168.0.0/16",
    "fd00::/8",
]

# 生产环境特定配置(可通过环境变量覆盖)
[production]
# Cloudflare IP 范围
additional_trusted_proxies = [
    "103.21.244.0/22",
    "103.22.200.0/22",
    "103.31.4.0/22",
    "104.16.0.0/13",
    "104.24.0.0/14",
    "108.162.192.0/18",
    "131.0.72.0/22",
    "141.101.64.0/18",
    "162.158.0.0/15",
    "172.64.0.0/13",
    "173.245.48.0/20",
    "188.114.96.0/20",
    "190.93.240.0/20",
    "197.234.240.0/22",
    "198.41.128.0/17",
]

六、集成测试示例

创建 tests/advanced_integration.rs

#[tokio::test]
async fn test_complex_proxy_chain() {
    use axum::Router;
    use axum_trusted_proxies_tutorial::advanced::{
        MultiProxyProcessor, MultiProxyConfig, ValidationMode
    };
    use axum_trusted_proxies_tutorial::config::TrustedProxiesConfig;
    use std::net::SocketAddr;
    use axum::http::HeaderMap;

    // 创建测试配置
    let base_config = TrustedProxiesConfig::from_strs(&[
        "10.0.0.0/8",
        "172.16.0.0/12",
        "192.168.0.0/16",
    ]).unwrap();

    let advanced_config = MultiProxyConfig {
        validation_mode: ValidationMode::HopByHop,
        enable_rfc7239: true,
        max_proxy_hops: 5,
        ..Default::default()
    };

    let processor = MultiProxyProcessor::new(base_config, Some(advanced_config));

    // 模拟复杂代理链
    let peer_addr = SocketAddr::from(([192, 168, 1, 100], 8080));

    let mut headers = HeaderMap::new();
    headers.insert(
        "X-Forwarded-For",
        "203.0.113.195, 10.0.1.50, 172.16.0.10, 192.168.1.1".parse().unwrap()
    );
    headers.insert("X-Forwarded-Proto", "https".parse().unwrap());
    headers.insert("X-Forwarded-Host", "api.example.com".parse().unwrap());

    // 测试处理
    let result = processor.process_request(&peer_addr, &headers).unwrap();

    assert_eq!(result.trusted_client_ip.to_string(), "203.0.113.195");
    assert_eq!(result.validated_hops, 3); // 192.168.1.1, 172.16.0.10, 10.0.1.50
    assert!(result.chain_integrity);

    // 测试 RFC 7239 头部
    let mut rfc_headers = HeaderMap::new();
    rfc_headers.insert(
        "Forwarded",
        r#"for=192.0.2.60;proto=https;by=203.0.113.43,for=198.51.100.17"#.parse().unwrap()
    );

    let rfc_result = processor.process_request(&peer_addr, &rfc_headers).unwrap();
    assert_eq!(rfc_result.trusted_client_ip.to_string(), "192.0.2.60");
}

#[tokio::test]
async fn test_proxy_chain_attack_scenarios() {
    use axum_trusted_proxies_tutorial::advanced::{MultiProxyProcessor, ProxyError};

    let base_config = TrustedProxiesConfig::from_strs(&["10.0.0.0/8"]).unwrap();
    let processor = MultiProxyProcessor::new(base_config, None);

    // 场景 1: 过长代理链攻击
    let peer_addr = SocketAddr::from(([10, 0, 0, 1], 80));
    let mut headers = HeaderMap::new();

    // 创建超长链(超过默认限制 10 跳)
    let long_chain = (0..15)
        .map(|i| format!("10.0.{}.1", i))
        .collect::<Vec<_>>()
        .join(", ");

    headers.insert("X-Forwarded-For", long_chain.parse().unwrap());

    let result = processor.process_request(&peer_addr, &headers);
    assert!(result.is_err());

    // 场景 2: IP 伪造攻击
    let mut attack_headers = HeaderMap::new();
    attack_headers.insert("X-Forwarded-For", "8.8.8.8".parse().unwrap());

    let attack_result = processor.process_request(&peer_addr, &attack_headers).unwrap();
    // 应该正确识别 8.8.8.8 为客户端 IP(因为代理可信)
    assert_eq!(attack_result.trusted_client_ip.to_string(), "8.8.8.8");

    // 场景 3: 不可信代理尝试欺骗
    let untrusted_peer = SocketAddr::from(([8, 8, 8, 8], 80));
    let fake_headers = HeaderMap::new();
    fake_headers.insert("X-Forwarded-For", "10.0.0.100".parse().unwrap());

    let fake_result = processor.process_request(&untrusted_peer, &fake_headers).unwrap();
    // 应该忽略 X-Forwarded-For,使用 8.8.8.8 作为客户端 IP
    assert_eq!(fake_result.trusted_client_ip.to_string(), "8.8.8.8");
}

七、生产环境部署建议

7.1 配置管理

// 环境特定的配置加载
pub fn load_production_config() -> MultiProxyConfig {
    let mut config = MultiProxyConfig::default();

    // 从环境变量读取
    if let Ok(mode_str) = std::env::var("PROXY_VALIDATION_MODE") {
        config.validation_mode = match mode_str.as_str() {
            "strict" => ValidationMode::Strict,
            "lenient" => ValidationMode::Lenient,
            _ => ValidationMode::HopByHop,
        };
    }

    // 从云服务商元数据添加可信代理
    if let Ok(cloud_ips) = fetch_cloud_provider_ips() {
        for ip_range in cloud_ips {
            if let Ok(network) = ip_range.parse() {
                config.allowed_private_nets.push(network);
            }
        }
    }

    config
}

// 动态配置更新
pub struct DynamicConfigManager {
    current_config: Arc<std::sync::RwLock<MultiProxyConfig>>,
    config_watcher: tokio::sync::watch::Sender<MultiProxyConfig>,
}

impl DynamicConfigManager {
    pub fn new(initial_config: MultiProxyConfig) -> Self {
        let (sender, _) = tokio::sync::watch::channel(initial_config.clone());

        Self {
            current_config: Arc::new(std::sync::RwLock::new(initial_config)),
            config_watcher: sender,
        }
    }

    pub fn update_config(&self, new_config: MultiProxyConfig) {
        let mut config = self.current_config.write().unwrap();
        *config = new_config.clone();

        // 通知所有监听者
        let _ = self.config_watcher.send(new_config);
    }

    pub fn get_config(&self) -> MultiProxyConfig {
        self.current_config.read().unwrap().clone()
    }
}

7.2 健康检查端点

// 在路由中添加健康检查
async fn proxy_health_handler(
    State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
    let metrics = state.proxy_processor.get_metrics();
    let cache_stats = state.proxy_processor.get_cache_stats();

    Json(json!({
        "status": "healthy",
        "metrics": {
            "total_requests": metrics.total_requests,
            "trusted_requests": metrics.trusted_proxy_requests,
            "cache_hit_rate": cache_stats.hit_rate(),
            "avg_validation_time_ms": metrics.avg_validation_time,
        },
        "config": {
            "validation_mode": format!("{:?}", state.proxy_config.validation_mode),
            "max_proxy_hops": state.proxy_config.max_proxy_hops,
            "trusted_networks_count": state.proxy_config.allowed_private_nets.len(),
        }
    }))
}

总结

通过本章的优化,我们的可信代理处理器现在具备:

  1. 多层代理链安全验证:支持三种验证模式,可应对不同安全需求
  2. RFC 7239 标准支持:兼容最新的 Forwarded 头部标准
  3. 高性能缓存:LRU 缓存和 CIDR 预编译优化,提升验证速度
  4. 全面的监控:详细的指标收集和日志记录
  5. 防御加固:针对各种代理链攻击的防护措施
  6. 动态配置:支持运行时配置更新

在实际部署时,建议根据具体环境调整配置:

  • 内部服务间调用:使用 Strict 模式
  • 面向公网的服务:使用 HopByHop 模式
  • 需要最大兼容性:使用 Lenient 模式

通过这套完整的解决方案,你的 Axum 应用将能够安全、高效地处理任何复杂的代理部署场景。

版权声明:自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)