Skip to content

中间件

日志中间件

rust
use axum::{
    middleware::{self, Next},
    http::{Request, StatusCode},
    response::Response,
    routing::get,
    Router,
};
use std::time::Instant;

async fn log_request<B>(
    req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    let method = req.method().clone();
    let uri = req.uri().clone();
    let start = Instant::now();

    println!("--> {} {} ", method, uri);

    let response = next.run(req).await;

    let duration = start.elapsed();
    let status = response.status();
    println!("<-- {} {} {:?}", uri, status, duration);

    Ok(response)
}

/// Root route handler: replies with a fixed plain-text greeting.
async fn handler() -> &'static str {
    const GREETING: &str = "Hello";
    GREETING
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(handler))
        .layer(middleware::from_fn(log_request));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}
▶ Run

CORS 中间件（使用 `CorsLayer` 需要在 tower-http 中启用 `cors` feature）

rust
use axum::{
    middleware,
    http::{Request, Response, StatusCode},
    Router,
};
use tower_http::cors::{CorsLayer, Any};

async fn cors_middleware<B>(
    req: Request<B>,
    next: middleware::Next<B>,
) -> Result<Response, StatusCode> {
    let mut response = next.run(req).await;

    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        "*".parse().unwrap(),
    );

    Ok(response)
}

#[tokio::main]
async fn main() {
    // 使用 tower-http 的 CORS
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(cors);

    // 或者使用自定义中间件
    let app_with_custom = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(middleware::from_fn(cors_middleware));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}
▶ Run

静态文件服务

使用 tower-http

toml
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
tower-http = { version = "0.5", features = ["fs", "trace"] }
rust
use axum::{Router, routing::get_service};
use tower_http::services::{ServeDir, ServeFile};

#[tokio::main]
async fn main() {
    let app = Router::new()
        // 服务静态目录
        .nest_service("/static", ServeDir::new("public"))
        // 服务单个文件
        .nest_service(
            "/favicon.ico",
            ServeFile::new("assets/favicon.ico"),
        );

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}
▶ Run

Actix-web 示例

项目设置

toml
[package]
name = "actix-demo"
version = "0.1.0"
edition = "2021"

[dependencies]
actix-web = "4"
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"

Hello World

rust
use actix_web::{web, App, HttpServer, HttpResponse, Responder};
use serde::{Deserialize, Serialize};

/// JSON response body returned by `greet` and `get_user`.
#[derive(Serialize)]
struct User {
    /// User id (fixed to 1 by `greet`, taken from the URL path by `get_user`).
    id: u32,
    /// Display name.
    name: String,
}

/// JSON request body accepted by `greet` (`POST /greet`).
#[derive(Deserialize)]
struct Info {
    /// Name echoed back in the response.
    name: String,
    // NOTE(review): `age` is not optional, so requests that omit it are rejected
    // during deserialization. However, `greet` never reads it, which likely also
    // triggers a dead_code "never read" warning.
    age: u32,
}

/// `GET /`: responds `200 OK` with a fixed plain-text body.
async fn index() -> impl Responder {
    let body = "Hello, World!";
    HttpResponse::Ok().body(body)
}

async fn greet(info: web::Json<Info>) -> impl Responder {
    HttpResponse::Ok().json(User {
        id: 1,
        name: info.name.clone(),
    })
}

/// `GET /users/{id}`: returns a synthetic `User` whose name is `"User <id>"`.
async fn get_user(path: web::Path<u32>) -> impl Responder {
    let id = path.into_inner();
    let user = User {
        id,
        name: format!("User {id}"),
    };
    HttpResponse::Ok().json(user)
}

/// Starts the actix-web demo server on 127.0.0.1:8080.
///
/// # Errors
///
/// Returns the I/O error if binding the address fails or the server stops with
/// an error.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // The factory closure builds the `App` for each worker.
    let server = HttpServer::new(|| {
        App::new()
            .route("/", web::get().to(index))
            .route("/greet", web::post().to(greet))
            .route("/users/{id}", web::get().to(get_user))
    })
    .bind("127.0.0.1:8080")?;

    server.run().await
}
▶ Run

数据库集成

使用 SQLx

toml
[dependencies]
sqlx = { version = "0.7", features = ["postgres", "runtime-tokio-native-tls"] }
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
rust
use axum::{
    extract::{Path, State},
    Json, Router,
    routing::{get, post},
};
use sqlx::PgPool;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct User {
    id: i32,
    name: String,
    email: String,
}

/// JSON request body for `POST /users`.
///
/// There is no `id`: the INSERT in `create_user` leaves it to the database
/// (presumably a serial/identity column; confirm against the schema).
#[derive(Deserialize)]
struct NewUser {
    name: String,
    email: String,
}

/// Shared application state, installed with `Router::with_state` and read by
/// handlers through the `State` extractor. axum requires state types to be `Clone`.
#[derive(Clone)]
struct AppState {
    /// Postgres connection pool used by `list_users` and `create_user`.
    db: PgPool,
}

async fn list_users(
    State(state): State<AppState>,
) -> Result<Json<Vec<User>>, sqlx::Error> {
    let users = sqlx::query_as::<_, User>("SELECT * FROM users")
        .fetch_all(&state.db)
        .await?;
    Ok(Json(users))
}

async fn create_user(
    State(state): State<AppState>,
    Json(payload): Json<NewUser>,
) -> Result<Json<User>, sqlx::Error> {
    let user = sqlx::query_as::<_, User>(
        "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *"
    )
    .bind(&payload.name)
    .bind(&payload.email)
    .fetch_one(&state.db)
    .await?;
    Ok(Json(user))
}

#[tokio::main]
async fn main() {
    let database_url = std::env::var("DATABASE_URL")
        .expect("DATABASE_URL must be set");

    let db = PgPool::connect(&database_url)
        .await
        .expect("Failed to connect to database");

    let state = AppState { db };

    let app = Router::new()
        .route("/users", get(list_users))
        .route("/users", post(create_user))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}
▶ Run