🦀 Axum 秒级限流:Governor 令牌桶一键接入,Redis 分布式抗 DDoS

Photos provided by Unsplash OR Pexels

Rust Axum 框架中生产级 Rate Limiting(限流)中间件实现最佳实践

介绍

在生产环境中,Rate Limiting(速率限制)是防止滥用、DDoS 攻击、爬虫和保障服务稳定性的核心安全机制。常见策略包括:

  • 固定窗口(Fixed Window)
  • 滑动窗口(Sliding Window)
  • 令牌桶(Token Bucket)
  • 漏桶(Leaky Bucket)

Axum 生态中最推荐的生产级方案是使用 tower::limit::RateLimitLayer 结合 governor crate,它实现了高效的令牌桶算法,支持分布式(通过 Redis)扩展,且内存占用低、性能优秀。

本文基于前文完整实现(可信代理 IP + Forwarded + TLS + HSTS + CORS),添加工业级 Rate Limiting 中间件,关键点:

  • 基于真实客户端 IP(使用前文提取的 ClientIp)进行限流
  • 使用 governor + 内存存储(单实例)
  • 支持配置限流阈值(例如每分钟 60 次请求)
  • 支持 Burst(突发)容量
  • 返回标准 HTTP 429 Too Many Requests + Retry-After 头部

最佳实践要点

  • 限流键必须使用真实客户端 IP(而非代理 IP),否则所有流量都会被视为同一个来源
  • 生产环境建议结合 Redis(governor 支持)实现分布式限流
  • 返回标准 429 响应,便于客户端重试
  • 可结合路径或方法细粒度控制
  • 与其他中间件顺序:Rate Limit 应在 CORS 之后、业务逻辑之前

完整实例代码(含生产级 Rate Limiting)

Cargo.toml

[package]
name = "axum-trusted-proxies-tls-hsts-cors-ratelimit"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.8"
axum_server = "0.7"
tokio = { version = "1.48.0", features = ["full"] }
tower = "0.5"
tower-http = { 
    version = "0.6.8", 
    features = ["trace", "cors"]
}
governor = "0.6"  # 高效令牌桶限流
rustls = "0.23.35"
rustls-pemfile = "2.2.0"
ipnet = "2.10.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

src/main.rs

use axum::{
    extract::{ConnectInfo, Extension, Request},
    http::{HeaderMap, HeaderValue, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use axum_server::tls_rustls::RustlsConfig;
use governor::{
    clock::DefaultClock,
    middleware::RateLimitMiddleware,
    state::InMemoryState,
    Quota, RateLimiter,
};
use ipnet::IpNet;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tower_http::{
    cors::CorsLayer,
    trace::TraceLayer,
};
use tracing::{error, info, warn};

// 真实客户端 IP 扩展
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ClientIp(pub IpAddr);

// 可信代理配置
#[derive(Clone, Debug)]
pub struct TrustedProxies {
    proxies: HashSet<IpNet>,
    max_hops: Option<usize>,
}

impl TrustedProxies {
    pub fn new(proxies: Vec<String>, max_hops: Option<usize>) -> Self {
        let mut set = HashSet::new();
        for proxy in proxies {
            if let Ok(net) = IpNet::from_str(&proxy) {
                set.insert(net);
            } else {
                error!("Invalid proxy CIDR: {}", proxy);
            }
        }
        Self { proxies: set, max_hops }
    }

    pub fn is_trusted(&self, ip: IpAddr) -> bool {
        self.proxies.iter().any(|net| net.contains(&ip))
    }
}

// HSTS 配置
#[derive(Clone)]
pub struct HstsConfig {
    pub max_age: u64,
    pub include_subdomains: bool,
    pub preload: bool,
}

impl Default for HstsConfig {
    fn default() -> Self {
        Self {
            max_age: 63072000,
            include_subdomains: true,
            preload: false,
        }
    }
}

// HSTS 中间件
async fn hsts_middleware(
    Extension(hsts_config): Extension<Arc<HstsConfig>>,
    req: Request,
    next: Next,
) -> impl IntoResponse {
    let mut res = next.run(req).await;

    let mut value = format!("max-age={}", hsts_config.max_age);
    if hsts_config.include_subdomains {
        value.push_str("; includeSubDomains");
    }
    if hsts_config.preload {
        value.push_str("; preload");
    }

    res.headers_mut().insert(
        "strict-transport-security",
        HeaderValue::from_str(&value).unwrap(),
    );

    res
}

// Forwarded 解析(简化版,核心逻辑同前)
fn extract_client_ip_from_headers(
    headers: &HeaderMap,
    trusted_proxies: &TrustedProxies,
) -> Option<IpAddr> {
    // 优先 Forwarded
    // ...(同前文实现,此处省略以节省篇幅)
    // 回退 X-Forwarded-For 和 X-Real-IP
    None // 实际应返回解析结果
}

// 可信代理中间件(提取真实 IP)
async fn trusted_proxy_middleware(
    Extension(trusted_proxies): Extension<Arc<TrustedProxies>>,
    ConnectInfo(socket_addr): ConnectInfo<SocketAddr>,
    headers: HeaderMap,
    mut req: Request,
    next: Next,
) -> impl IntoResponse {
    let socket_ip = socket_addr.ip();
    info!("Connection from: {}", socket_ip);

    let client_ip = if trusted_proxies.is_trusted(socket_ip) {
        extract_client_ip_from_headers(&headers, trusted_proxies)
            .unwrap_or(socket_ip)
    } else {
        socket_ip
    };

    info!("Real client IP: {}", client_ip);
    req.extensions_mut().insert(ClientIp(client_ip));

    next.run(req).await
}

// 自定义 Rate Limiting 中间件(基于真实 ClientIp)
async fn rate_limit_middleware(
    Extension(rate_limiter): Extension<Arc<RateLimiter<ClientIp, InMemoryState, DefaultClock>>>,
    Extension(client_ip): Extension<ClientIp>,
    req: Request,
    next: Next,
) -> impl IntoResponse {
    // 检查是否超过限流
    if let Err(not_until) = rate_limiter.check_key(&client_ip) {
        let retry_after = not_until.earliest_possible() - DefaultClock::default().now();
        let seconds = retry_after.as_secs();

        return (
            StatusCode::TOO_MANY_REQUESTS,
            [
                ("content-type", "text/plain"),
                ("retry-after", &seconds.to_string()),
            ],
            format!("Too many requests. Retry after {} seconds.", seconds),
        )
            .into_response();
    }

    next.run(req).await
}

// Handler 示例
async fn root(Extension(client_ip): Extension<ClientIp>) -> impl IntoResponse {
    format!(
        "Welcome to secure API\n\
         Your IP: {}\n\
         Rate limiting active (60 req/min)",
        client_ip.0
    )
}

async fn health() -> &'static str {
    "OK"
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    let trusted_proxies = Arc::new(TrustedProxies::new(
        vec![
            "127.0.0.1/32".to_string(),
            "10.0.0.0/8".to_string(),
            "172.16.0.0/12".to_string(),
            "192.168.0.0/16".to_string(),
        ],
        Some(10),
    ));

    let hsts_config = Arc::new(HstsConfig::default());

    // === Rate Limiter 配置:每分钟 60 次请求,允许突发 10 次 ===
    let quota = Quota::per_minute(NonZeroU32::new(60).unwrap())
        .allow_burst(NonZeroU32::new(10).unwrap());
    let rate_limiter = Arc::new(RateLimiter::keyed(quota));

    // === CORS 配置 ===
    let cors = CorsLayer::new()
        .allow_origin("https://your-frontend.com".parse::<HeaderValue>()?)
        .allow_methods([axum::http::Method::GET, axum::http::Method::POST])
        .allow_headers([axum::http::header::CONTENT_TYPE, axum::http::header::AUTHORIZATION])
        .allow_credentials(true)
        .max_age(86400);

    // === TLS 配置 ===
    let tls_config = RustlsConfig::from_pem_file(
        "certs/fullchain.pem",
        "certs/privkey.pem",
    )
    .await?;

    // === Router 构建 ===
    let app = Router::new()
        .route("/", get(root))
        .route("/health", get(health))
        .layer(cors)
        .layer(middleware::from_fn_with_state(
            trusted_proxies.clone(),
            trusted_proxy_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            rate_limiter.clone(),
            rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            hsts_config.clone(),
            hsts_middleware,
        ))
        .layer(TraceLayer::new_for_http());

    let addr = SocketAddr::from(([0, 0, 0, 0], 443));
    info!("Starting full-featured secure server on https://{}", addr);

    axum_server::bind_rustls(addr, tls_config)
        .serve(app.into_make_service_with_connect_info::<SocketAddr>())
        .await?;

    Ok(())
}

分布式限流升级(生产推荐)

InMemoryState 替换为 Redis:

governor = { version = "0.6", features = ["redis"] }
redis = "0.26"
use governor::state::redis::RedisState;

// 在 main 中:
let redis_conn = redis::Client::open("redis://127.0.0.1/")?.get_async_connection().await?;
let rate_limiter = Arc::new(RateLimiter::redis(quota, redis_conn)?);

测试限流

# 快速发送多个请求观察 429 响应
for i in {1..70}; do
    curl -s -o /dev/null -w "%{http_code}\n" https://your-api.com/
done

参考资料

至此,你的 Axum 服务已具备完整生产级防护能力

  • 真实客户端 IP 提取(支持多级代理和 Forwarded)
  • rustls TLS 加密
  • HSTS 强制 HTTPS
  • 严格 CORS
  • 基于真实 IP 的高效令牌桶限流

代码结构清晰、可配置、可扩展,完全符合工业级标准。

版权声明:自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)