Skip to content

第六步：错误处理

src/error.rs

rust
use axum::{
    http::StatusCode,
    response::{Response, IntoResponse},
    Json,
};
use serde_json::json;
use thiserror::Error;

/// Application-wide error type returned by the HTTP handlers.
///
/// The `#[error(...)]` strings define the `Display` output (useful for logs);
/// the text actually sent to clients is chosen separately in the
/// `IntoResponse` impl for this type.
#[derive(Debug, Error)]
pub enum AppError {
    /// A failure reported by sqlx; `#[from]` lets handlers propagate it with `?`.
    #[error("数据库错误: {0}")]
    Database(#[from] sqlx::Error),
    
    /// Missing/invalid credentials at login, or a missing/invalid bearer token.
    #[error("认证失败")]
    Unauthorized,
    
    /// The requested resource does not exist.
    #[error("资源未找到")]
    NotFound,
    
    /// Invalid client input; the contained message is returned to the client.
    #[error("请求参数错误: {0}")]
    BadRequest(String),
    
    /// Any other unexpected failure; `#[from]` converts from `anyhow::Error`.
    #[error("内部错误: {0}")]
    Internal(#[from] anyhow::Error),
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, message) = match self {
            AppError::Database(e) => {
                if let sqlx::Error::RowNotFound = e {
                    (StatusCode::NOT_FOUND, "资源未找到")
                } else {
                    (StatusCode::INTERNAL_SERVER_ERROR, "数据库错误")
                }
            }
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "认证失败"),
            AppError::NotFound => (StatusCode::NOT_FOUND, "资源未找到"),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg.as_str()),
            AppError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "内部错误"),
        };
        
        let body = Json(json!({
            "error": message,
        }));
        
        (status, body).into_response()
    }
}

/// Shorthand used by every handler: success value `T` or an [`AppError`].
pub type Result<T> = core::result::Result<T, AppError>;
▶ Run

第七步：用户处理

src/handlers/user.rs

rust
use axum::{
    extract::{Path, State},
    Json,
};
use bcrypt::{hash, DEFAULT_COST};
use sqlx::PgPool;
use uuid::Uuid;

use crate::error::Result;
use crate::models::user::{CreateUser, UpdateUser, UserResponse};

/// 创建用户
pub async fn create_user(
    State(pool): State<PgPool>,
    Json(payload): Json<CreateUser>,
) -> Result<Json<UserResponse>> {
    // 验证输入
    if payload.email.is_empty() || payload.password.is_empty() {
        return Err(crate::error::AppError::BadRequest("邮箱和密码不能为空".to_string()));
    }
    
    // 密码加密
    let password_hash = hash(payload.password, DEFAULT_COST)
        .map_err(|e| crate::error::AppError::Internal(e.into()))?;
    
    // 插入数据库
    let user = sqlx::query_as::<_, crate::models::user::User>(
        r#"
        INSERT INTO users (email, password_hash, username)
        VALUES ($1, $2, $3)
        RETURNING *
        "#
    )
    .bind(&payload.email)
    .bind(&password_hash)
    .bind(&payload.username)
    .fetch_one(&pool)
    .await?;
    
    Ok(Json(UserResponse::from(user)))
}

/// List every user, most recently created first.
pub async fn list_users(
    State(pool): State<PgPool>,
) -> Result<Json<Vec<UserResponse>>> {
    const LIST_SQL: &str = "SELECT * FROM users ORDER BY created_at DESC";

    let rows = sqlx::query_as::<_, crate::models::user::User>(LIST_SQL)
        .fetch_all(&pool)
        .await?;

    let responses = rows
        .into_iter()
        .map(UserResponse::from)
        .collect::<Vec<_>>();
    Ok(Json(responses))
}

/// Fetch a single user by id.
///
/// A missing row surfaces from `fetch_one` as `sqlx::Error::RowNotFound`, which
/// the `AppError` response conversion reports as 404.
pub async fn get_user(
    State(pool): State<PgPool>,
    Path(id): Path<Uuid>,
) -> Result<Json<UserResponse>> {
    let query =
        sqlx::query_as::<_, crate::models::user::User>("SELECT * FROM users WHERE id = $1")
            .bind(id);
    let found = query.fetch_one(&pool).await?;
    Ok(Json(found.into()))
}

/// 更新用户
pub async fn update_user(
    State(pool): State<PgPool>,
    Path(id): Path<Uuid>,
    Json(payload): Json<UpdateUser>,
) -> Result<Json<UserResponse>> {
    let user = sqlx::query_as::<_, crate::models::user::User>(
        r#"
        UPDATE users
        SET username = COALESCE($1, username),
            email = COALESCE($2, email),
            updated_at = NOW()
        WHERE id = $3
        RETURNING *
        "#
    )
    .bind(&payload.username)
    .bind(&payload.email)
    .bind(id)
    .fetch_one(&pool)
    .await?;
    
    Ok(Json(UserResponse::from(user)))
}

/// 删除用户
pub async fn delete_user(
    State(pool): State<PgPool>,
    Path(id): Path<Uuid>,
) -> Result<()> {
    sqlx::query("DELETE FROM users WHERE id = $1")
        .bind(id)
        .execute(&pool)
        .await?;
    
    Ok(())
}
▶ Run

第八步：认证处理

src/handlers/auth.rs

rust
use axum::{extract::State, Json};
use bcrypt::verify;
use jsonwebtoken::{encode, EncodingKey, Header};
use sqlx::PgPool;
use std::time::{SystemTime, UNIX_EPOCH};

use crate::config::Config;
use crate::error::{AppError, Result};
use crate::models::auth::{Claims, LoginRequest, LoginResponse};
use crate::models::user::UserResponse;

/// 用户登录
pub async fn login(
    State(pool): State<PgPool>,
    State(config): State<Config>,
    Json(payload): Json<LoginRequest>,
) -> Result<Json<LoginResponse>> {
    // 查找用户
    let user = sqlx::query_as::<_, crate::models::user::User>(
        "SELECT * FROM users WHERE email = $1"
    )
    .bind(&payload.email)
    .fetch_optional(&pool)
    .await?;
    
    let user = user.ok_or(AppError::Unauthorized)?;
    
    // 验证密码
    let valid = verify(&payload.password, &user.password_hash)
        .map_err(|e| AppError::Internal(e.into()))?;
    
    if !valid {
        return Err(AppError::Unauthorized);
    }
    
    // 生成 JWT
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs() as usize;
    
    let claims = Claims {
        sub: user.id.to_string(),
        email: user.email.clone(),
        exp: now + config.jwt_expiration as usize,
        iat: now,
    };
    
    let token = encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(config.jwt_secret.as_bytes()),
    )
    .map_err(|e| AppError::Internal(e.into()))?;
    
    Ok(Json(LoginResponse {
        token,
        user: UserResponse::from(user),
    }))
}
▶ Run

第九步：认证中间件

src/middleware/auth.rs

rust
use axum::{
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::models::auth::Claims;
use crate::error::AppError;

/// 认证中间件
pub async fn auth_middleware(
    request: Request,
    next: Next,
) -> Result<Response, AppError> {
    let auth_header = request
        .headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok())
        .ok_or(AppError::Unauthorized)?;
    
    if !auth_header.starts_with("Bearer ") {
        return Err(AppError::Unauthorized);
    }
    
    let token = &auth_header[7..];
    
    // 从环境变量获取密钥
    let secret = std::env::var("JWT_SECRET")
        .unwrap_or("default-secret".to_string());
    
    let claims = decode::<Claims>(
        token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(),
    )
    .map_err(|_| AppError::Unauthorized)?
    .claims;
    
    // 将用户信息添加到请求扩展中
    let request = request.extensions().insert(claims);
    
    Ok(next.run(request).await)
}
▶ Run

第十步：主入口

src/main.rs

rust
use axum::{
    routing::{get, post, put, delete},
    Router,
    Extension,
};
use tower_http::cors::{CorsLayer, Any};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod config;
mod db;
mod models;
mod handlers;
mod middleware;
mod error;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // 初始化日志
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new("info"))
        .with(tracing_subscriber::fmt::layer())
        .init();
    
    // 加载配置
    let config = config::Config::from_env()?;
    
    // 连接数据库
    let pool = db::create_pool(&config.database_url).await?;
    
    // 构建路由
    let app = Router::new()
        // 公开路由
        .route("/api/auth/login", post(handlers::auth::login))
        .route("/api/users", post(handlers::user::create_user))
        
        // 需要认证的路由
        .route("/api/users", get(handlers::user::list_users))
        .route("/api/users/:id", get(handlers::user::get_user))
        .route("/api/users/:id", put(handlers::user::update_user))
        .route("/api/users/:id", delete(handlers::user::delete_user))
        .layer(middleware::auth_middleware)
        
        // CORS
        .layer(CorsLayer::new().allow_origin(Any).allow_methods(Any))
        
        // 状态
        .with_state(pool.clone())
        .with_state(config.clone());
    
    // 启动服务器
    let addr = format!("{}:{}", config.server_host, config.server_port);
    tracing::info!("服务器启动: http://{}", addr);
    
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;
    
    Ok(())
}
▶ Run

API 测试

bash
# Placeholders: replace TOKEN with the JWT returned by the login call and USER_ID
# with a user id returned by the create/list calls. Keeping them in variables and
# quoting every URL prevents the shell from parsing a literal `<id>` as a
# redirection.
BASE_URL="http://localhost:3000"
TOKEN="<token>"
USER_ID="<id>"

# Create a user (public endpoint)
curl -X POST "$BASE_URL/api/users" \
  -H "Content-Type: application/json" \
  -d '{"email":"test@example.com","password":"password123","username":"testuser"}'

# Log in (public endpoint; the response contains the JWT)
curl -X POST "$BASE_URL/api/auth/login" \
  -H "Content-Type: application/json" \
  -d '{"email":"test@example.com","password":"password123"}'

# List users (requires authentication)
curl "$BASE_URL/api/users" \
  -H "Authorization: Bearer $TOKEN"

# Get a single user
curl "$BASE_URL/api/users/$USER_ID" \
  -H "Authorization: Bearer $TOKEN"

# Update a user
curl -X PUT "$BASE_URL/api/users/$USER_ID" \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"username":"newname"}'

# Delete a user
curl -X DELETE "$BASE_URL/api/users/$USER_ID" \
  -H "Authorization: Bearer $TOKEN"