本教程将指导你如何使用 Rust 和 Axum 框架实现基于 JWT(JSON Web Token)的认证系统。我们将逐步构建一个完整的认证流程,包括令牌生成、验证、HttpOnly-Cookie 存储和用户状态管理。


项目设置与依赖配置

首先,创建一个新的 Rust 项目并添加必要的依赖:

cargo new jwt_demo
cd jwt_demo

编辑 Cargo.toml 文件,添加以下依赖:

[package]
name = "jwt_demo"
version = "0.1.0"
edition = "2024"

[dependencies]
axum = "0.8.4"
axum-extra = { version = "0.10.1", features = ["typed-header","cookie"] }

jsonwebtoken = "9.3.1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.143"
tokio = { version = "1.47.1", features = ["full"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
time = "0.3.43"

这些依赖提供了:

  • axum: Web 框架

  • axum-extra: 额外的 Axum 功能,包括 Cookie 支持

  • jsonwebtoken: JWT 处理

  • serde: 序列化和反序列化

  • tokio: 异步运行时

  • tracing: 日志记录

  • time: 时间处理


基础数据结构定义

在 src/main.rs 中,我们首先定义所需的数据结构:

use serde::{Deserialize, Serialize};

/// JWT 声明结构
/// 包含需要在令牌中传递的信息,会被序列化到 JWT 中
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,        // 标准字段:主题(subject)
    company: String,    // 自定义字段:公司名称
    exp: usize,         // 标准字段:过期时间(expiration time),Unix 时间戳
}

/// 认证请求负载结构体
/// 用于接收客户端发送的认证凭据
#[derive(Debug, Deserialize)]
struct AuthPayload {
    client_id: String,      // 客户端 ID
    client_secret: String,  // 客户端密钥
}

/// 认证响应结构体
/// 包含生成的 JWT 令牌和令牌类型
#[derive(Debug, Serialize)]
struct AuthBody {
    access_token: String,   // JWT 访问令牌
    token_type: String,     // 令牌类型(通常为 "Bearer")
}

/// 登出响应结构体
#[derive(Debug, Serialize)]
struct LogoutResponse {
    message: String,        // 登出成功消息
}

/// 认证相关错误枚举
/// 定义了各种可能的认证错误类型
#[derive(Debug)]
enum AuthError {
    WrongCredentials,      // 凭据错误
    MissingCredentials,    // 缺少凭据
    TokenCreation,         // 令牌创建失败
    InvalidToken,          // 无效的令牌
}

这些数据结构定义了:

  • Claims: JWT 令牌中包含的用户信息

  • AuthPayload: 客户端发送的认证请求

  • AuthBody: 认证成功后的响应

  • LogoutResponse: 登出操作的响应

  • AuthError: 认证过程中可能出现的错误


JWT 密钥管理

JWT 需要使用密钥进行签名和验证。我们创建一个专门的结构来管理这些密钥:

use jsonwebtoken::{DecodingKey, EncodingKey};
use std::sync::LazyLock;

/// 存储 JWT 编码和解码所需的密钥
struct Keys {
    encoding: EncodingKey,  // 用于生成 JWT 的编码密钥
    decoding: DecodingKey,  // 用于验证 JWT 的解码密钥
}

impl Keys {
    /// 创建新的 Keys 实例
    /// 从字节数组初始化编码和解码密钥
    fn new(secret: &[u8]) -> Self {
        Self {
            encoding: EncodingKey::from_secret(secret),
            decoding: DecodingKey::from_secret(secret),
        }
    }
}

/// 全局 JWT 密钥存储
/// 使用 LazyLock 实现线程安全的延迟初始化,程序启动时从环境变量加载密钥
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    Keys::new(secret.as_bytes())
});

这里使用了 LazyLock 来确保密钥只被初始化一次,并且是线程安全的。密钥从环境变量 JWT_SECRET 中获取,这增加了安全性,因为密钥不会硬编码在代码中。


认证路由实现

接下来,我们实现认证相关的路由处理函数:

use axum::{
    extract::State,
    response::{IntoResponse, Response},
    Json,
};
use axum_extra::extract::{CookieJar, cookie::Cookie};
use time::Duration;

/// 处理授权请求,验证客户端凭据并生成 JWT 令牌
async fn authorize(
    jar: CookieJar,
    Json(payload): Json<AuthPayload>,
) -> Result<(CookieJar, Json<AuthBody>), AuthError> {
    // 检查客户端是否提供了凭据
    if payload.client_id.is_empty() || payload.client_secret.is_empty() {
        return Err(AuthError::MissingCredentials);
    }
    
    // 验证客户端凭据(实际应用中应从数据库查询验证)
    // 这里使用硬编码的 "foo" 和 "bar" 作为有效凭据
    if payload.client_id != "foo" || payload.client_secret != "bar" {
        return Err(AuthError::WrongCredentials);
    }
    
    // 创建 JWT 声明(Claims)
    let claims = Claims {
        sub: "b@b.com".to_owned(),  // 主题(通常是用户标识)
        company: "ACME".to_owned(),  // 自定义字段:公司名称
        exp: 2000000000,  // 过期时间(Unix 时间戳),这里设置到 2033 年
    };
    
    // 生成 JWT 令牌
    let token = encode(&Header::default(), &claims, &KEYS.encoding)
        .map_err(|_| AuthError::TokenCreation)?;  // 令牌生成失败时返回错误

    // 创建 HttpOnly cookie
    let cookie = Cookie::build(("jwt_token", token.clone()))
        .path("/") // 设置 cookie 路径
        .http_only(true) // 设置为 HttpOnly,防止 XSS 攻击
        .secure(false) // 开发环境使用 false,生产环境应设置为 true
        .same_site(axum_extra::extract::cookie::SameSite::Strict) // 设置 SameSite 策略
        .max_age(Duration::days(1)) // 设置 cookie 有效期
        .build();

    // 将 cookie 添加到 jar 中并返回
    Ok((jar.add(cookie), Json(AuthBody::new(token))))
}

/// 处理登出请求
async fn logout(jar: CookieJar) -> Result<(CookieJar, Json<LogoutResponse>), AuthError> {
    // 创建一个同名的空 cookie 并设置立即过期来清除 cookie
    let cookie = Cookie::build(("jwt_token", ""))
        .path("/")
        .http_only(true)
        .secure(false)
        .max_age(Duration::seconds(0)) // 设置为 0 使 cookie 立即过期
        .build();

    Ok((
        jar.add(cookie),
        Json(LogoutResponse {
            message: "Successfully logged out".to_string(),
        }),
    ))
}

/// 处理受保护路由的请求
/// 需要验证 JWT 令牌,成功后返回受保护内容
async fn protected(claims: Claims) -> Result<String, AuthError> {
    // 向用户发送受保护的数据
    Ok(format!(
        "Welcome to the protected area :)\nYour data:\n{}",
        claims
    ))
}

/// 检查认证状态
async fn check_auth(claims: Claims) -> Result<Json<serde_json::Value>, AuthError> {
    Ok(Json(serde_json::json!({
        "authenticated": true,
        "user": {
            "email": claims.sub,
            "company": claims.company
        }
    })))
}

/// AuthBody 结构体的辅助方法
impl AuthBody {
    /// 创建新的 AuthBody 实例
    fn new(access_token: String) -> Self {
        Self {
            access_token,  // JWT 令牌
            token_type: "Bearer".to_string(),  // 令牌类型,固定为 "Bearer"
        }
    }
}

/// 为 Claims 实现 Display trait,用于格式化输出
impl std::fmt::Display for Claims {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
    }
}

这些函数实现了:

  • authorize: 验证用户凭据并生成 JWT 令牌,存储在 HttpOnly Cookie 中

  • logout: 清除 JWT Cookie,实现用户登出

  • protected: 受保护的路由,需要有效的 JWT 令牌才能访问

  • check_auth: 检查当前认证状态


JWT 令牌提取与验证

为了实现从请求中自动提取和验证 JWT 令牌,我们需要为 Claims 实现 FromRequestParts trait:

use axum::{
    extract::FromRequestParts,
    http::request::Parts,
    RequestPartsExt,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, Validation};

/// 为 Claims 实现 FromRequestParts trait,用于从请求中提取并验证 JWT 令牌
/// 实现后,Claims 可以作为处理函数的参数,自动从请求中提取
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,  // 状态类型需要满足 Send + Sync 约束
{
    type Rejection = AuthError;  // 提取失败时的错误类型

    /// 从请求部分提取并验证 JWT 令牌,返回 Claims
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // 首先尝试从 Cookie 中获取令牌
        let jar = CookieJar::from_request_parts(parts, _state)
            .await
            .map_err(|_| AuthError::InvalidToken)?;
            
        if let Some(cookie) = jar.get("jwt_token") {
            let token = cookie.value();
            let token_data = decode::<Claims>(token, &KEYS.decoding, &Validation::default())
                .map_err(|_| AuthError::InvalidToken)?;  // 解码失败时返回无效令牌错误
            return Ok(token_data.claims);
        }
        
        // 如果 Cookie 中没有令牌,尝试从 Authorization 头中获取(向后兼容)
        if let Ok(TypedHeader(Authorization(bearer))) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>().await
        {
            let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
                .map_err(|_| AuthError::InvalidToken)?;  // 解码失败时返回无效令牌错误
            return Ok(token_data.claims);
        }
        
        Err(AuthError::InvalidToken)
    }
}

这个实现允许我们:

  1. 首先尝试从 Cookie 中提取 JWT 令牌

  2. 如果 Cookie 中没有令牌,则尝试从 Authorization 头中提取(向后兼容)

  3. 验证令牌的有效性并提取其中的声明信息


错误处理

为了提供友好的错误响应,我们需要为 AuthError 实现 IntoResponse trait:

use axum::http::StatusCode;
use serde_json::json;

/// 为 AuthError 实现 IntoResponse trait,将错误转换为 HTTP 响应
impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        // 根据错误类型确定 HTTP 状态码和错误消息
        let (status, error_message) = match self {
            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
            AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
        };
        
        // 构建 JSON 响应体
        let body = Json(json!({
            "error": error_message,
        }));
        
        // 组合状态码和响应体,转换为 Response
        (status, body).into_response()
    }
}

这个实现确保了每种错误类型都有对应的 HTTP 状态码和友好的错误消息。


完整代码与测试

现在,让我们将所有部分组合起来,创建完整的应用程序:

use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, RequestPartsExt, Router,
};
use axum_extra::{
    extract::{CookieJar, cookie::Cookie},
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::sync::LazyLock;
use time::Duration;
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// 此处插入之前定义的所有结构体和实现...

/// 程序入口点
#[tokio::main]
async fn main() {
    // 初始化日志追踪系统
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // 创建路由
    let app = Router::new()
        .route("/protected", get(protected))    // 受保护路由
        .route("/authorize", post(authorize))   // 认证路由
        .route("/logout", post(logout))         // 登出路由
        .route("/check-auth", get(check_auth)); // 检查登录状态

    // 绑定到本地地址 127.0.0.1:3000
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

测试应用程序

启动应用程序前,设置 JWT 密钥环境变量:

set JWT_SECRET=your_super_secret_key
cargo run

然后使用 curl 命令测试各个端点:

1、获取 JWT 令牌

curl -X POST http://localhost:3000/authorize -H "Content-Type: application/json" -d "{\"client_id\":\"foo\",\"client_secret\":\"bar\"}" -c cookies.txt

运行结果:

{"access_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjIwMDAwMDAwMDB9.10Lx0Dy0UEW5hhh75kz04Cb0XuG1iUOClKHM1u0dlP8","token_type":"Bearer"}

2、访问受保护的路由(使用 Cookie):

curl -b "jwt_token=<YOUR_TOKEN>" http://localhost:3000/protected

运行结果:

Welcome to the protected area :)
Your data:
Email: b@b.com
Company: ACME

3、检查认证状态

curl -b cookies.txt http://localhost:3000/check-auth

运行结果:

{"authenticated":true,"user":{"company":"ACME","email":"b@b.com"}}

4、登出

curl -X POST -b cookies.txt http://localhost:3000/logout

运行结果:

{"message":"Successfully logged out"}

安全最佳实践

在实现 JWT 认证时,请遵循以下安全最佳实践:

  1. 使用强密钥:JWT 密钥应该足够长且随机,建议至少 32 字节

  2. 设置合理的过期时间:JWT 令牌应该有适当的过期时间,减少被盗用的风险

  3. 使用 HTTPS:在生产环境中,始终使用 HTTPS 来传输令牌

  4. HttpOnly Cookie:将 JWT 存储在 HttpOnly Cookie 中可以防止 XSS 攻击

  5. SameSite 属性:设置适当的 SameSite 属性可以防止 CSRF 攻击

  6. 验证声明:始终验证 JWT 中的声明,特别是过期时间(exp)

  7. 不要存储敏感信息:JWT 的内容是 Base64 编码的,可以被解码,不要存储敏感信息

通过本教程,你已经学会了如何使用 Rust 和 Axum 框架实现一个完整的 JWT 认证系统。这个系统包含了令牌生成、验证、Cookie 存储和用户状态管理等关键功能,可以作为一个安全可靠的认证基础。


完整代码

//! JWT 授权(authorization) / 认证示例(authentication)。
//! 该示例使用 axum 框架实现了基于 JWT 的身份验证机制,使用 HttpOnly Cookie 存储令牌
//!
//! 运行方式:
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run
//! ```

use axum::{
    extract::{FromRequestParts, State}, // 移除了未使用的 Query
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, RequestPartsExt, Router,
};
use axum_extra::{
    extract::{CookieJar, cookie::Cookie}, // 正确导入 Cookie
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::sync::LazyLock;
use time::Duration; // 移除了未使用的 OffsetDateTime
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// 添加 AppState 用于共享状态
#[derive(Clone)]
struct AppState {
    keys: Keys,
}

// 全局 JWT 密钥存储
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
    let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    Keys::new(secret.as_bytes())
});

/// 程序入口点
#[tokio::main]
async fn main() {
    // 初始化日志追踪系统
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // 初始化应用状态
    let state = AppState {
        keys: Keys::new(
            std::env::var("JWT_SECRET")
                .expect("JWT_SECRET must be set")
                .as_bytes(),
        ),
    };

    // 创建路由
    let app = Router::new()
        .route("/protected", get(protected))
        .route("/authorize", post(authorize))
        .route("/logout", post(logout))
        .route("/check-auth", get(check_auth))
        .with_state(state); // 添加状态管理

    // 绑定到本地地址 127.0.0.1:3000
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

/// 处理受保护路由的请求
async fn protected(claims: Claims) -> Result<String, AuthError> {
    Ok(format!(
        "Welcome to the protected area :)\nYour data:\n{claims}",
    ))
}

/// 检查认证状态
async fn check_auth(claims: Claims) -> Result<Json<serde_json::Value>, AuthError> {
    Ok(Json(json!({
        "authenticated": true,
        "user": {
            "email": claims.sub,
            "company": claims.company
        }
    })))
}

/// 处理授权请求
async fn authorize(
    State(state): State<AppState>,
    jar: CookieJar,
    Json(payload): Json<AuthPayload>,
) -> Result<(CookieJar, Json<AuthBody>), AuthError> {
    if payload.client_id.is_empty() || payload.client_secret.is_empty() {
        return Err(AuthError::MissingCredentials);
    }
    
    if payload.client_id != "foo" || payload.client_secret != "bar" {
        return Err(AuthError::WrongCredentials);
    }
    
    let claims = Claims {
        sub: "b@b.com".to_owned(),
        company: "ACME".to_owned(),
        exp: 2000000000,
    };
    
    let token = encode(&Header::default(), &claims, &state.keys.encoding)
        .map_err(|_| AuthError::TokenCreation)?;

    // 创建 HttpOnly cookie
    let cookie = Cookie::build(("jwt_token", token.clone()))
        .path("/") // 设置 cookie 路径
        .http_only(true) // 设置为 HttpOnly
        .secure(false) // 开发环境使用 false,生产环境应设置为 true
        .same_site(axum_extra::extract::cookie::SameSite::Strict) // 设置 SameSite 策略
        .max_age(Duration::days(1)) // 设置 cookie 有效期
        .build();

    // 将 cookie 添加到 jar 中并返回
    Ok((jar.add(cookie), Json(AuthBody::new(token))))
}

/// 处理登出请求
async fn logout(jar: CookieJar) -> Result<(CookieJar, Json<LogoutResponse>), AuthError> {
    // 创建一个同名的空 cookie 并设置立即过期来清除 cookie
    let cookie = Cookie::build(("jwt_token", ""))
        .path("/")
        .http_only(true)
        .secure(false)
        .max_age(Duration::seconds(0)) // 设置为 0 使 cookie 立即过期
        .build();

    Ok((
        jar.add(cookie),
        Json(LogoutResponse {
            message: "Successfully logged out".to_string(),
        }),
    ))
}

impl Display for Claims {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
    }
}

impl AuthBody {
    fn new(access_token: String) -> Self {
        Self {
            access_token,
            token_type: "Bearer".to_string(),
        }
    }
}

/// 修改 FromRequestParts 实现,支持从 Cookie 或 Authorization 头提取令牌
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // 首先尝试从 Cookie 中获取令牌
        let jar = CookieJar::from_request_parts(parts, state)
            .await
            .map_err(|_| AuthError::InvalidToken)?;
            
        if let Some(cookie) = jar.get("jwt_token") {
            let token = cookie.value();
            let token_data = decode::<Claims>(token, &KEYS.decoding, &Validation::default())
                .map_err(|_| AuthError::InvalidToken)?;
            return Ok(token_data.claims);
        }
        
        // 如果 Cookie 中没有令牌,尝试从 Authorization 头中获取(向后兼容)
        if let Ok(TypedHeader(Authorization(bearer))) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>().await
        {
            let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
                .map_err(|_| AuthError::InvalidToken)?;
            return Ok(token_data.claims);
        }
        
        Err(AuthError::InvalidToken)
    }
}

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
            AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"), // 修改为 UNAUTHORIZED
        };
        
        let body = Json(json!({
            "error": error_message,
        }));
        
        (status, body).into_response()
    }
}

/// JWT 密钥结构
#[derive(Clone)] // 为 Keys 添加 Clone trait
struct Keys {
    encoding: EncodingKey,
    decoding: DecodingKey,
}

impl Keys {
    fn new(secret: &[u8]) -> Self {
        Self {
            encoding: EncodingKey::from_secret(secret),
            decoding: DecodingKey::from_secret(secret),
        }
    }
}

/// JWT 声明结构
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,
    company: String,
    exp: usize,
}

/// 认证响应结构
#[derive(Debug, Serialize)]
struct AuthBody {
    access_token: String,
    token_type: String,
}

/// 登出响应结构
#[derive(Debug, Serialize)]
struct LogoutResponse {
    message: String,
}

/// 认证请求负载结构
#[derive(Debug, Deserialize)]
struct AuthPayload {
    client_id: String,
    client_secret: String,
}

/// 认证错误枚举
#[derive(Debug)]
enum AuthError {
    WrongCredentials,
    MissingCredentials,
    TokenCreation,
    InvalidToken,
}

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐