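【光子AI】A Minimal FastAPI Tutorial
FastAPI is a modern, high-performance Python web framework designed for building APIs. This tutorial covers FastAPI development from the basics to advanced topics, including: environment setup and installation; creating your first FastAPI application; handling path and query parameters; request body validation with Pydantic models; response models and status-code control; database integration and user authentication; advanced features such as dependency injection, middleware, and WebSocket support; and testing and deployment practices. The tutorial provides code examples and projects.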
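Table of Contents
- Introduction to FastAPI
- Environment Setup and Installation
- Your First FastAPI Application
- Path Parameters and Query Parameters
- Request Bodies and Pydantic Models
- Response Models and Status Codes
- Form Data and File Uploads
- Dependency Injection System
- Middleware and CORS
- Database Integration
- User Authentication and Authorization
- Error Handling
- Background Tasks
- WebSocket Support
- Testing and Deployment
- Hands-On Project: Building a Complete API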
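1. Introduction to FastAPI
1.1 What Is FastAPI
FastAPI is a modern, fast (high-performance) Python web framework for building APIs. It is built on standard Python type hints and combines the strengths of Starlette (for the web layer) and Pydantic (for data validation).
1.2 Core Features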
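| Feature | Description |
|---|---|
| High performance | On par with NodeJS and Go (per FastAPI's own benchmarks), thanks to Starlette and Pydantic |
| Fast to code | Increases development speed by roughly 200%-300% (FastAPI's own estimate) |
| Fewer bugs | Reduces human-induced errors by about 40% (FastAPI's own estimate) |
| Editor support | Full editor support and autocompletion |
| Easy to learn | Simple design and thorough documentation |
| Automatic docs | Interactive API documentation generated automatically (Swagger UI and ReDoc) |
| Standards-based | Fully compatible with OpenAPI and JSON Schema |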
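1.3 Why Choose FastAPI
A typical hand-rolled Flask/Django REST API workflow:
Define routes → validate data manually → serialize manually → write docs manually
The FastAPI workflow:
Define routes + type hints → automatic validation + automatic serialization + automatic docs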
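2. Environment Setup and Installation
2.1 Python Version Requirements
Recent FastAPI releases no longer support Python 3.7; use Python 3.8 or later (3.9+ is recommended) and check the release notes of the version you install.
```bash
# Check the Python version
python --version
# or
python3 --version
```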
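2.2 Create a Virtual Environment
```bash
# Create the project directory
mkdir fastapi-tutorial
cd fastapi-tutorial
# Create the virtual environment
python -m venv venv
# Activate the virtual environment
# Windows
venv\Scripts\activate
# macOS/Linux
source venv/bin/activate
```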
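2.3 Install FastAPI and Dependencies
```bash
# Install FastAPI
pip install fastapi
# Install an ASGI server (used to run the app)
pip install "uvicorn[standard]"
# Or install all recommended optional dependencies at once
pip install "fastapi[all]"
```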
2.4 Recommended Project Structure
```text
fastapi-tutorial/
├── app/
│   ├── __init__.py
│   ├── main.py           # application entry point
│   ├── config.py         # configuration
│   ├── models/           # database (ORM) models
│   │   ├── __init__.py
│   │   └── user.py
│   ├── schemas/          # Pydantic models
│   │   ├── __init__.py
│   │   └── user.py
│   ├── routers/          # route modules
│   │   ├── __init__.py
│   │   └── users.py
│   ├── dependencies/     # dependencies
│   │   └── __init__.py
│   └── utils/            # utility functions
│       └── __init__.py
├── tests/                # test files
├── requirements.txt
└── README.md
```
3. Your First FastAPI Application
3.1 Hello World
Create a main.py file:
```python
from fastapi import FastAPI

# Create the FastAPI instance
app = FastAPI(
    title="My First FastAPI App",
    description="A sample project for learning FastAPI",
    version="1.0.0",
)


# Root route
@app.get("/")
def read_root():
    return {"message": "Hello, FastAPI!"}


# Another route, with a path parameter
@app.get("/hello/{name}")
def say_hello(name: str):
    return {"message": f"Hello, {name}!"}
```
3.2 Run the Application
```bash
# Basic run
uvicorn main:app
# Development mode (auto-reload on code changes)
uvicorn main:app --reload
# Specify host and port
uvicorn main:app --reload --host 0.0.0.0 --port 8000
```
3.3 Access the Automatically Generated Docs
After starting the application, open the following URLs:
- Swagger UI: http://127.0.0.1:8000/docs
- ReDoc: http://127.0.0.1:8000/redoc
- OpenAPI JSON: http://127.0.0.1:8000/openapi.json
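3.4 Async Support
FastAPI supports both regular `def` and `async def` path operation functions. Regular `def` functions are run in a thread pool, so they do not block the event loop; inside `async def` functions, only await non-blocking calls such as `asyncio.sleep()`.
```python
import asyncio

from fastapi import FastAPI

app = FastAPI()


# Synchronous function (FastAPI runs it in a thread pool)
@app.get("/sync")
def sync_endpoint():
    return {"type": "sync"}


# Asynchronous function
@app.get("/async")
async def async_endpoint():
    await asyncio.sleep(1)  # simulate a non-blocking async operation
    return {"type": "async"}
```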
4. Path Parameters and Query Parameters
4.1 Path Parameters
Path parameters are part of the URL path. FastAPI converts them to the annotated type and returns a 422 error if the conversion fails:
```python
from enum import Enum

from fastapi import FastAPI

app = FastAPI()


# Basic path parameter
@app.get("/users/{user_id}")
def get_user(user_id: int):
    return {"user_id": user_id}


# Multiple path parameters
@app.get("/posts/{post_id}/comments/{comment_id}")
def get_comment(post_id: int, comment_id: int):
    return {
        "post_id": post_id,
        "comment_id": comment_id,
    }


# Restricting a path parameter to predefined values with an Enum
class ModelName(str, Enum):
    alexnet = "alexnet"
    resnet = "resnet"
    lenet = "lenet"


@app.get("/models/{model_name}")
def get_model(model_name: ModelName):
    if model_name == ModelName.alexnet:
        return {"model_name": model_name, "message": "Deep Learning FTW!"}
    return {"model_name": model_name, "message": "Have some residuals"}
```
4.2 Query Parameters
Query parameters are the key-value pairs after `?` in the URL. Function parameters that are not part of the path are treated as query parameters:
```python
from typing import Optional

from fastapi import FastAPI

app = FastAPI()


# Basic query parameters with defaults, e.g. /items/?skip=0&limit=10
@app.get("/items/")
def read_items(skip: int = 0, limit: int = 10):
    return {"skip": skip, "limit": limit}


# Optional query parameter
@app.get("/items/{item_id}")
def read_item(item_id: int, q: Optional[str] = None):
    if q:
        return {"item_id": item_id, "q": q}
    return {"item_id": item_id}


# Required query parameter (no default value)
@app.get("/search/")
def search(keyword: str):  # required
    return {"keyword": keyword}


# Boolean query parameters: values such as 1, true, on and yes are read as True
@app.get("/items/{item_id}/details")
def read_item_details(
    item_id: int,
    short: bool = False,
    full: bool = True,
):
    item = {"item_id": item_id, "name": "Example Item"}
    if short:
        return {"item_id": item_id}
    if full:
        item["description"] = "This is a detailed description"
    return item
```
4.3 Advanced Validation with Query
```python
from typing import List, Optional

from fastapi import FastAPI, Query

app = FastAPI()


@app.get("/items/")
def read_items(
    q: Optional[str] = Query(
        None,
        min_length=3,
        max_length=50,
        pattern="^[a-zA-Z]+$",  # FastAPI < 0.100 used regex= instead
        title="Query string",
        description="Query string used for searching",
    )
):
    return {"q": q}


# List-valued query parameter, e.g. /items/multi/?q=foo&q=bar
@app.get("/items/multi/")
def read_items_multi(
    q: List[str] = Query(
        default=["foo", "bar"],
        title="Multi-value query parameter",
    )
):
    return {"q": q}


# Alias: the URL uses item-query, which is not a valid Python identifier
@app.get("/items/alias/")
def read_items_alias(
    item_query: Optional[str] = Query(None, alias="item-query")
):
    return {"item_query": item_query}


# Deprecated parameter (shown as deprecated in the docs)
@app.get("/items/deprecated/")
def read_items_deprecated(
    old_param: Optional[str] = Query(None, deprecated=True)
):
    return {"old_param": old_param}
```
4.4 Validating Path Parameters with Path
```python
from typing import Optional

from fastapi import FastAPI, Path, Query

app = FastAPI()


@app.get("/items/{item_id}")
def read_item(
    item_id: int = Path(
        ...,  # ... marks the parameter as required
        title="Item ID",
        description="Unique identifier of the item to fetch",
        ge=1,  # greater than or equal to 1
        le=1000,  # less than or equal to 1000
    )
):
    return {"item_id": item_id}


# Combining path and query parameters
@app.get("/items/{item_id}/details")
def read_item_details(
    item_id: int = Path(..., ge=1),
    q: Optional[str] = Query(None, max_length=50),
    short: bool = False,
):
    item = {"item_id": item_id}
    if q:
        item["q"] = q
    if not short:
        item["description"] = "This is a long description"
    return item
```
5. Request Bodies and Pydantic Models
5.1 Pydantic Basics
Pydantic handles data validation and serialization. The examples below use the Pydantic v2 API, which current FastAPI releases install by default:
```python
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator


# Basic model
class Item(BaseModel):
    name: str
    price: float
    is_offer: Optional[bool] = None


# Model with Field constraints
class Product(BaseModel):
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
    price: float = Field(..., gt=0, description="Price must be greater than 0")
    tax: Optional[float] = Field(None, ge=0)
    tags: List[str] = Field(default_factory=list)

    # Custom validator (Pydantic v1 used @validator)
    @field_validator("name")
    @classmethod
    def name_must_not_be_blank(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("name must not be blank")
        return v.strip()


# Nested models
class Address(BaseModel):
    street: str
    city: str
    country: str
    zip_code: str


class User(BaseModel):
    # from_attributes (v1: orm_mode) lets the model read ORM objects;
    # json_schema_extra (v1: schema_extra) adds example data to the docs
    model_config = ConfigDict(
        from_attributes=True,
        json_schema_extra={
            "example": {
                "id": 1,
                "username": "johndoe",
                "email": "john@example.com",
                "full_name": "John Doe",
            }
        },
    )

    id: int
    username: str
    email: str
    full_name: Optional[str] = None
    address: Optional[Address] = None
    created_at: datetime = Field(default_factory=datetime.now)
```
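Pydantic models can also be exercised on their own, outside any endpoint: a failed validation raises `pydantic.ValidationError`, and its `errors()` list is the same information FastAPI turns into a 422 response. A sketch assuming Pydantic v2 (`model_validate`), with a cut-down `Product` and a hypothetical `validate_payload` helper:

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class Product(BaseModel):
    name: str = Field(..., min_length=1, max_length=100)
    price: float = Field(..., gt=0)
    tax: Optional[float] = Field(None, ge=0)


def validate_payload(payload: dict):
    """Hypothetical helper: return (model, None) or (None, list of error locations)."""
    try:
        return Product.model_validate(payload), None
    except ValidationError as exc:
        # Pydantic reports every failing field, not just the first one
        return None, [err["loc"] for err in exc.errors()]


ok, _ = validate_payload({"name": "Pen", "price": "1.5"})  # "1.5" is coerced to float
bad, locs = validate_payload({"name": "", "price": -1})
print(ok, locs)
```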
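5.2 Handling Request Bodies
FastAPI reads parameters annotated with a Pydantic model from the JSON request body; parameters named in the path template are path parameters, and other singular types (int, str, bool and so on) are query parameters:
```python
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    name: str
    description: Optional[str] = None
    price: float
    tax: Optional[float] = None


# Basic request body
@app.post("/items/")
def create_item(item: Item):
    item_dict = item.model_dump()  # Pydantic v1: item.dict()
    if item.tax is not None:
        item_dict["price_with_tax"] = item.price + item.tax
    return item_dict


# Request body + path parameter + query parameter
@app.put("/items/{item_id}")
def update_item(
    item_id: int,
    item: Item,
    q: Optional[str] = None,
):
    result = {"item_id": item_id, **item.model_dump()}
    if q:
        result["q"] = q
    return result
```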
5.3 Multiple Body Parameters
```python
from fastapi import Body, FastAPI
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    name: str
    price: float


class User(BaseModel):
    username: str
    full_name: str


# Multiple body models: expects {"item": {...}, "user": {...}}
@app.put("/items/{item_id}")
def update_item(
    item_id: int,
    item: Item,
    user: User,
):
    return {
        "item_id": item_id,
        "item": item,
        "user": user,
    }


# Use Body to read a singular value from the body:
# expects {"item": {...}, "importance": 5}
@app.put("/items/{item_id}/importance")
def update_item_importance(
    item_id: int,
    item: Item,
    importance: int = Body(..., ge=1, le=10),
):
    return {
        "item_id": item_id,
        "item": item,
        "importance": importance,
    }


# Embed a single body model under its parameter name
@app.put("/items/{item_id}/embedded")
def update_item_embedded(
    item_id: int,
    item: Item = Body(..., embed=True),  # expects {"item": {...}}
):
    return {"item_id": item_id, "item": item}
```
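5.4 Field Types and Validation
```python
# EmailStr requires the email-validator package: pip install "pydantic[email]"
from datetime import date, datetime
from typing import List, Optional, Set
from uuid import UUID

from pydantic import BaseModel, EmailStr, Field, HttpUrl, field_validator


class CompleteModel(BaseModel):
    # Basic types
    id: int
    name: str
    price: float
    is_active: bool

    # Constrained numbers
    age: int = Field(..., ge=0, le=150)
    score: float = Field(..., ge=0.0, le=100.0)

    # String constraints
    username: str = Field(..., min_length=3, max_length=20)
    # Pydantic v2 removed regex= (use pattern=), and its default Rust regex engine
    # does not support look-ahead, so the complexity rule is checked below
    password: str = Field(..., min_length=8)

    # Special types
    email: EmailStr
    website: Optional[HttpUrl] = None
    uuid: UUID

    # Dates and times
    created_at: datetime
    birth_date: date

    # Collection types
    tags: List[str] = []
    unique_tags: Set[str] = set()

    # Free-form dictionary
    metadata: Optional[dict] = None

    @field_validator("password")
    @classmethod
    def password_complexity(cls, v: str) -> str:
        # Require at least one lowercase letter, one uppercase letter and one digit
        if not (any(c.islower() for c in v) and any(c.isupper() for c in v) and any(c.isdigit() for c in v)):
            raise ValueError("password needs a lowercase letter, an uppercase letter and a digit")
        return v
```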
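In JSON these special types arrive as strings, and Pydantic converts them on input and back on output. A sketch assuming Pydantic v2, with a hypothetical `Event` model that leaves out `EmailStr` and `HttpUrl` so it runs without email-validator:

```python
from datetime import date, datetime
from typing import List, Set
from uuid import UUID

from pydantic import BaseModel


class Event(BaseModel):
    id: UUID
    starts_at: datetime
    day: date
    tags: List[str] = []
    unique_tags: Set[str] = set()


raw = {
    "id": "12345678-1234-5678-1234-567812345678",
    "starts_at": "2024-05-01T10:30:00",
    "day": "2024-05-01",
    "tags": ["a", "a", "b"],         # a list keeps duplicates
    "unique_tags": ["a", "a", "b"],  # a set drops them
}

event = Event.model_validate(raw)      # strings -> UUID / datetime / date
as_json = event.model_dump(mode="json")  # back to JSON-friendly values
print(event.day, event.unique_tags, as_json["day"])
```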
6. Response Models and Status Codes
6.1 Response Models
```python
# EmailStr requires the email-validator package: pip install "pydantic[email]"
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel, EmailStr

app = FastAPI()


class UserIn(BaseModel):
    username: str
    password: str
    email: EmailStr
    full_name: Optional[str] = None


class UserOut(BaseModel):
    username: str
    email: EmailStr
    full_name: Optional[str] = None


class UserInDB(BaseModel):
    username: str
    hashed_password: str
    email: EmailStr
    full_name: Optional[str] = None


# response_model filters the response data
@app.post("/users/", response_model=UserOut)
def create_user(user: UserIn):
    # The password is not included in the response
    return user


# Omit fields that were never set (here: full_name)
@app.get("/users/{user_id}", response_model=UserOut, response_model_exclude_unset=True)
def read_user(user_id: int):
    return {"username": "john", "email": "john@example.com"}


# Exclude specific fields
@app.get(
    "/users/{user_id}/full",
    response_model=UserInDB,
    response_model_exclude={"hashed_password"},
)
def read_user_full(user_id: int):
    return {
        "username": "john",
        "hashed_password": "hashed_secret",
        "email": "john@example.com",
    }
```
6.2 Multiple Response Types
```python
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    name: str
    price: float


class Message(BaseModel):
    message: str


# Document several possible responses in OpenAPI
@app.get(
    "/items/{item_id}",
    response_model=Item,
    responses={
        200: {"model": Item, "description": "Item returned successfully"},
        404: {"model": Message, "description": "Item not found"},
    },
)
def read_item(item_id: int):
    if item_id == 0:
        # Returning a JSONResponse bypasses response_model, so the status is really 404
        return JSONResponse(status_code=404, content={"message": "Item not found"})
    return Item(name="Foo", price=35.0)


# Other response classes
@app.get("/html/", response_class=HTMLResponse)
def get_html():
    return """
    <html>
        <head><title>Hello</title></head>
        <body><h1>Hello, FastAPI!</h1></body>
    </html>
    """


@app.get("/text/", response_class=PlainTextResponse)
def get_text():
    return "Hello, World!"
```
6.3 Status Codes
```python
from fastapi import FastAPI, status

app = FastAPI()


# Use the constants from the status module
@app.post("/items/", status_code=status.HTTP_201_CREATED)
def create_item(name: str):
    return {"name": name}


# A 204 response has no body
@app.delete("/items/{item_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_item(item_id: int):
    return None

# Common status codes
# HTTP_200_OK - success
# HTTP_201_CREATED - resource created
# HTTP_204_NO_CONTENT - no content
# HTTP_400_BAD_REQUEST - bad request
# HTTP_401_UNAUTHORIZED - not authenticated
# HTTP_403_FORBIDDEN - forbidden (authenticated but not allowed)
# HTTP_404_NOT_FOUND - not found
# HTTP_422_UNPROCESSABLE_ENTITY - unprocessable entity (FastAPI's validation errors)
# HTTP_500_INTERNAL_SERVER_ERROR - internal server error
```
7. Form Data and File Uploads
7.1 Form Data
Form fields and file uploads require the python-multipart package:
```bash
pip install python-multipart
```
```python
from typing import Optional

from fastapi import FastAPI, Form

app = FastAPI()


# Basic form
@app.post("/login/")
def login(
    username: str = Form(...),
    password: str = Form(...),
):
    return {"username": username}


# Optional form fields
@app.post("/profile/")
def update_profile(
    username: str = Form(...),
    bio: Optional[str] = Form(None),
    avatar_url: Optional[str] = Form(None),
):
    return {
        "username": username,
        "bio": bio,
        "avatar_url": avatar_url,
    }
```
7.2 File Uploads
```python
import shutil
from pathlib import Path
from typing import List

from fastapi import FastAPI, File, Form, UploadFile

app = FastAPI()


# bytes: the whole file is read into memory (fine for small files)
@app.post("/files/")
def upload_file_bytes(file: bytes = File(...)):
    return {"file_size": len(file)}


# UploadFile: spooled to disk above a size limit (recommended for large files);
# note that await file.read() still loads the whole content into memory
@app.post("/uploadfile/")
async def upload_file(file: UploadFile):
    contents = await file.read()
    return {
        "filename": file.filename,
        "content_type": file.content_type,
        "size": len(contents),
    }


# Multiple files
@app.post("/uploadfiles/")
async def upload_files(files: List[UploadFile]):
    return {
        "filenames": [f.filename for f in files],
        "count": len(files),
    }


# File plus form data
@app.post("/files/with-form/")
async def upload_with_form(
    file: UploadFile,
    description: str = Form(...),
    tags: List[str] = Form(default=[]),
):
    return {
        "filename": file.filename,
        "description": description,
        "tags": tags,
    }


# Save an uploaded file to disk
UPLOAD_DIR = Path("uploads")
UPLOAD_DIR.mkdir(exist_ok=True)


@app.post("/upload/save/")
def save_upload_file(file: UploadFile):
    # Keep only the final name component so "../../x" cannot escape UPLOAD_DIR
    safe_name = Path(file.filename or "upload").name
    file_path = UPLOAD_DIR / safe_name
    with open(file_path, "wb") as buffer:
        # Blocking copy; a plain def endpoint runs in a thread pool
        shutil.copyfileobj(file.file, buffer)
    return {
        "filename": safe_name,
        "saved_path": str(file_path),
    }
```
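`file.filename` comes from the client and may contain directory components written with either `/` or `\`. One way to harden the save step further, sketched with a hypothetical `safe_upload_path()` helper that uses only the standard library, is to keep just the final name component and add a random prefix so uploads with the same name do not overwrite each other:

```python
import uuid
from pathlib import Path, PureWindowsPath
from typing import Optional


def safe_upload_path(upload_dir: Path, client_filename: Optional[str]) -> Path:
    """Hypothetical helper: map an untrusted client filename to a path inside upload_dir."""
    # PureWindowsPath splits on both "/" and "\", so "../x" and "..\\x" both reduce to "x"
    name = PureWindowsPath(client_filename or "").name
    if name in ("", ".", ".."):
        name = "upload"  # fallback for empty or pure-navigation names
    # Random prefix avoids collisions between uploads with the same name
    return upload_dir / f"{uuid.uuid4().hex}_{name}"


print(safe_upload_path(Path("uploads"), "../../etc/passwd"))
```

In the endpoint above, `safe_upload_path(UPLOAD_DIR, file.filename)` would take the place of the `UPLOAD_DIR / safe_name` line.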
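8. Dependency Injection System
8.1 Dependency Injection Basics
A dependency is any callable. FastAPI resolves its parameters (query, path, headers, other dependencies) the same way it does for a path operation and passes the return value into the endpoint:
```python
from typing import Optional

from fastapi import Depends, FastAPI

app = FastAPI()


# A simple dependency: shared query parameters
def common_parameters(
    q: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
):
    return {"q": q, "skip": skip, "limit": limit}


@app.get("/items/")
def read_items(commons: dict = Depends(common_parameters)):
    return commons


@app.get("/users/")
def read_users(commons: dict = Depends(common_parameters)):
    return commons
```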
8.2 Classes as Dependencies
A class works as a dependency too: FastAPI reads the `__init__` parameters the same way as function parameters.
```python
from typing import Optional

from fastapi import Depends, FastAPI

app = FastAPI()


class CommonQueryParams:
    def __init__(
        self,
        q: Optional[str] = None,
        skip: int = 0,
        limit: int = 100,
    ):
        self.q = q
        self.skip = skip
        self.limit = limit


@app.get("/items/")
def read_items(commons: CommonQueryParams = Depends(CommonQueryParams)):
    return {
        "q": commons.q,
        "skip": commons.skip,
        "limit": commons.limit,
    }


# Shorthand: Depends() with no argument uses the type annotation
@app.get("/users/")
def read_users(commons: CommonQueryParams = Depends()):
    return commons.__dict__
```
8.3 Nested Dependencies
```python
from typing import Optional

from fastapi import Cookie, Depends, FastAPI

app = FastAPI()


def query_extractor(q: Optional[str] = None):
    return q


# This dependency itself depends on query_extractor
def query_or_cookie_extractor(
    q: Optional[str] = Depends(query_extractor),
    last_query: Optional[str] = Cookie(None),
):
    if not q:
        return last_query
    return q


@app.get("/items/")
def read_items(query_or_default: Optional[str] = Depends(query_or_cookie_extractor)):
    return {"q_or_cookie": query_or_default}
```
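8.4 Dependencies and Databases
A dependency that uses `yield` can run cleanup code after the request; this is the usual way to open and close one database session per request:
```python
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session

# SessionLocal and the User ORM model come from Section 10 (database.py and models.py);
# import them from wherever they live in your project

app = FastAPI()


# Database session dependency: yield hands the session to the endpoint,
# and the finally block closes it once the request is done
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# In practice, declare response_model with a from_attributes Pydantic schema
# (as in Section 10) so ORM objects are serialized reliably
@app.get("/users/")
def read_users(db: Session = Depends(get_db)):
    users = db.query(User).all()
    return users


@app.get("/users/{user_id}")
def read_user(user_id: int, db: Session = Depends(get_db)):
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user
```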
8.5 Global Dependencies
```python
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException


async def verify_token(x_token: str = Header(...)):
    if x_token != "fake-super-secret-token":
        raise HTTPException(status_code=400, detail="X-Token header invalid")


async def verify_key(x_key: str = Header(...)):
    if x_key != "fake-super-secret-key":
        raise HTTPException(status_code=400, detail="X-Key header invalid")
    return x_key


# Application-level dependencies: run for every route
app = FastAPI(dependencies=[Depends(verify_token), Depends(verify_key)])


@app.get("/items/")
def read_items():
    return [{"item": "Foo"}, {"item": "Bar"}]


# Router-level dependencies: run for every route in this router
router = APIRouter(
    prefix="/admin",
    dependencies=[Depends(verify_token)],
)


# Example route (hypothetical) served at /admin/stats
@router.get("/stats")
def admin_stats():
    return {"status": "ok"}


# The router only takes effect once it is included in the app
app.include_router(router)
```
9. Middleware and CORS
9.1 Middleware Basics
A middleware runs for every request before it reaches the path operation, and for every response before it is returned:
```python
import time

from fastapi import FastAPI, Request

app = FastAPI()


# Custom middleware: add the processing time as a response header
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response


# Request-logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    print(f"Request: {request.method} {request.url}")
    response = await call_next(request)
    print(f"Response: {response.status_code}")
    return response
```
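9.2 CORS Configuration
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Allowed origins
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://localhost:8080",
    "https://example.com",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # list of allowed origins
    allow_credentials=True,  # allow cookies and Authorization headers
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all request headers
)

# To allow any origin, use allow_origins=["*"]; browsers reject a wildcard origin
# on credentialed requests, so avoid combining it with allow_credentials=True
```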
9.3 Other Common Middleware
```python
from fastapi import FastAPI, Request
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

# GZip-compress responses larger than 1000 bytes (when the client accepts gzip)
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Only accept requests whose Host header matches these domains
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)


# Custom middleware class
class CustomMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Before the request is handled
        print(f"Before request: {request.url}")
        response = await call_next(request)
        # After the request is handled
        print(f"After request: {response.status_code}")
        return response


app.add_middleware(CustomMiddleware)
```
10. Database Integration
10.1 SQLAlchemy Integration
```bash
pip install sqlalchemy
```
database.py - database configuration:
```python
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker  # SQLAlchemy 1.4+

# SQLite
SQLALCHEMY_DATABASE_URL = "sqlite:///./app.db"
# PostgreSQL
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@localhost/dbname"
# MySQL
# SQLALCHEMY_DATABASE_URL = "mysql+pymysql://user:password@localhost/dbname"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},  # SQLite only: remove for other databases
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


# Dependency
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
```
models.py - database models:
```python
from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship

from .database import Base


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    email = Column(String(100), unique=True, index=True)
    username = Column(String(50), unique=True, index=True)
    hashed_password = Column(String(200))
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)  # callable: evaluated per insert

    items = relationship("Item", back_populates="owner")


class Item(Base):
    __tablename__ = "items"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(100), index=True)
    description = Column(String(500))
    owner_id = Column(Integer, ForeignKey("users.id"))

    owner = relationship("User", back_populates="items")
```
schemas.py - Pydantic models:
```python
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, EmailStr


class ItemBase(BaseModel):
    title: str
    description: Optional[str] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    # Read data from ORM objects (Pydantic v1: class Config: orm_mode = True)
    model_config = ConfigDict(from_attributes=True)

    id: int
    owner_id: int


class UserBase(BaseModel):
    email: EmailStr
    username: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    model_config = ConfigDict(from_attributes=True)

    id: int
    is_active: bool
    created_at: datetime
    items: List[Item] = []
```
crud.py - CRUD operations:
```python
from passlib.context import CryptContext  # pip install "passlib[bcrypt]"
from sqlalchemy.orm import Session

from . import models, schemas

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_user(db: Session, user_id: int):
    return db.query(models.User).filter(models.User.id == user_id).first()


def get_user_by_email(db: Session, email: str):
    return db.query(models.User).filter(models.User.email == email).first()


def get_users(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.User).offset(skip).limit(limit).all()


def create_user(db: Session, user: schemas.UserCreate):
    # Store only the bcrypt hash, never the plain password
    hashed_password = pwd_context.hash(user.password)
    db_user = models.User(
        email=user.email,
        username=user.username,
        hashed_password=hashed_password,
    )
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user


def create_user_item(db: Session, item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.model_dump(), owner_id=user_id)  # Pydantic v1: item.dict()
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item
```
main.py - main application:
```python
from typing import List

from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session

from . import crud, models, schemas
from .database import engine, get_db

# Create the tables (fine for demos; use a migration tool such as Alembic in production)
models.Base.metadata.create_all(bind=engine)

app = FastAPI()


@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = crud.get_user_by_email(db, email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(db=db, user=user)


@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users


@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post("/users/{user_id}/items/", response_model=schemas.Item)
def create_item_for_user(
    user_id: int,
    item: schemas.ItemCreate,
    db: Session = Depends(get_db),
):
    return crud.create_user_item(db=db, item=item, user_id=user_id)
```
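For tests it is common to point SQLAlchemy at a throwaway in-memory SQLite database. With `"sqlite://"` each new connection would get its own empty database, so `StaticPool` is used to share a single connection across sessions. A standalone sketch with trimmed-down copies of the models above; `create_user_with_items` and `count_items` are hypothetical helpers:

```python
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.pool import StaticPool

# In-memory database; StaticPool makes every session reuse the same connection
engine = create_engine(
    "sqlite://",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autoflush=False, bind=engine)
Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    email = Column(String(100), unique=True, index=True)
    items = relationship("Item", back_populates="owner")


class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    title = Column(String(100))
    owner_id = Column(Integer, ForeignKey("users.id"))
    owner = relationship("User", back_populates="items")


Base.metadata.create_all(bind=engine)


def create_user_with_items(email: str, titles: list) -> tuple:
    """Insert a user and their items in one transaction; return (id, sorted titles)."""
    db = TestingSessionLocal()
    try:
        user = User(email=email, items=[Item(title=t) for t in titles])
        db.add(user)  # the related Item objects are saved along with the user
        db.commit()
        db.refresh(user)
        return user.id, sorted(item.title for item in user.items)
    finally:
        db.close()


def count_items(owner_id: int) -> int:
    """A separate session still sees the committed rows thanks to StaticPool."""
    db = TestingSessionLocal()
    try:
        return db.query(Item).filter(Item.owner_id == owner_id).count()
    finally:
        db.close()


user_id, titles = create_user_with_items("a@example.com", ["Second", "First"])
print(user_id, titles, count_items(user_id))
```

In a FastAPI test, a `get_db` variant built on `TestingSessionLocal` can then be installed through `app.dependency_overrides`.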
10.2 Async Database Access (databases + SQLAlchemy)
```bash
pip install "databases[sqlite]"
# or, for PostgreSQL / MySQL
pip install "databases[postgresql]"
pip install "databases[mysql]"
```
```python
from contextlib import asynccontextmanager

from databases import Database
from fastapi import FastAPI
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

DATABASE_URL = "sqlite:///./test.db"

database = Database(DATABASE_URL)
metadata = MetaData()

users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("email", String(100)),
)

# A synchronous engine is used only to create the tables
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)


# Connect on startup and disconnect on shutdown
# (replaces the deprecated @app.on_event("startup") / @app.on_event("shutdown") handlers)
@asynccontextmanager
async def lifespan(app: FastAPI):
    await database.connect()
    yield
    await database.disconnect()


app = FastAPI(lifespan=lifespan)


@app.get("/users/")
async def read_users():
    query = users.select()
    return await database.fetch_all(query)


@app.post("/users/")
async def create_user(name: str, email: str):
    query = users.insert().values(name=name, email=email)
    last_record_id = await database.execute(query)
    return {"id": last_record_id, "name": name, "email": email}
```
11. User Authentication and Authorization
11.1 OAuth2 Password Flow (with JWT)
```bash
pip install "python-jose[cryptography]" "passlib[bcrypt]"
# OAuth2PasswordRequestForm reads form data, so python-multipart is also required
pip install python-multipart
```
```python
from datetime import datetime, timedelta
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

# Configuration: generate a real secret (e.g. openssl rand -hex 32) and keep it out of source control
SECRET_KEY = "your-secret-key-here"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2: clients get a token from /token and send it as "Authorization: Bearer <token>"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()


# Models
class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    username: Optional[str] = None


class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None


class UserInDB(User):
    hashed_password: str


# Fake database
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": pwd_context.hash("secret"),
        "disabled": False,
    }
}


# Helper functions
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    return pwd_context.hash(password)


def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)
    return None


def authenticate_user(db, username: str, password: str):
    user = get_user(db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # jwt.decode checks the signature and the "exp" claim
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: Optional[str] = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


# Routes
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}


@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user


@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
    return [{"item_id": "Foo", "owner": current_user.username}]
```
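To see what `jwt.encode` and `jwt.decode` do for HS256, here is a standard-library-only sketch of the token format `header.payload.signature`, where the signature is an HMAC-SHA256 over the first two base64url segments. It is for understanding only: the function names are hypothetical, and it skips checks a real library performs (for example validating the header's `alg`), so keep using a maintained JWT library in the application:

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def b64url_encode(data: bytes) -> str:
    # JWTs use URL-safe base64 without "=" padding
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def encode_hs256(payload: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        b64url_encode(json.dumps(header, separators=(",", ":")).encode()),
        b64url_encode(json.dumps(payload, separators=(",", ":")).encode()),
    ]
    signing_input = ".".join(segments).encode("ascii")
    signature = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    return ".".join(segments + [b64url_encode(signature)])


def decode_hs256(token: str, secret: str, now: Optional[float] = None) -> dict:
    header_b64, payload_b64, signature_b64 = token.split(".")
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    # Constant-time comparison avoids leaking how many bytes matched
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("invalid signature")
    payload = json.loads(b64url_decode(payload_b64))
    current = time.time() if now is None else now
    if "exp" in payload and current >= payload["exp"]:
        raise ValueError("token expired")
    return payload


if __name__ == "__main__":
    token = encode_hs256({"sub": "johndoe", "exp": int(time.time()) + 60}, "secret")
    print(token)
    print(decode_hs256(token, "secret"))
```

Because the payload is only base64url-encoded, anyone holding the token can read it; the signature only prevents tampering, so never put secrets in the claims.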
11.2 Role-Based Access Control
```python
# Continues the 11.1 example (app, User, get_current_active_user, Depends, HTTPException, status)
from enum import Enum
from typing import List


class Role(str, Enum):
    admin = "admin"
    user = "user"
    guest = "guest"


class UserWithRole(User):
    roles: List[Role] = [Role.user]


def check_roles(allowed_roles: List[Role]):
    async def role_checker(current_user: User = Depends(get_current_active_user)) -> UserWithRole:
        # The user model from 11.1 has no roles field, so convert it here; every user then
        # gets the default role. To grant other roles, add a roles field to UserInDB
        # and to the records in fake_users_db.
        user = UserWithRole(**current_user.model_dump())
        if any(role in allowed_roles for role in user.roles):
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not enough permissions",
        )

    return role_checker


# Usage
@app.get("/admin/")
async def admin_route(user: UserWithRole = Depends(check_roles([Role.admin]))):
    return {"message": "Welcome, admin!"}


@app.get("/user/")
async def user_route(user: UserWithRole = Depends(check_roles([Role.admin, Role.user]))):
    return {"message": f"Welcome, {user.username}!"}
```
12. Error Handling
12.1 HTTPException
```python
from fastapi import FastAPI, HTTPException, status

app = FastAPI()

items = {"foo": "The Foo Item"}


@app.get("/items/{item_id}")
def read_item(item_id: str):
    if item_id not in items:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Item not found",
            headers={"X-Error": "There goes my error"},  # optional extra response headers
        )
    return {"item": items[item_id]}
```
12.2 Custom Exception Handlers
```python
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

app = FastAPI()


# Custom exception class
class UnicornException(Exception):
    def __init__(self, name: str):
        self.name = name


# Register an exception handler
@app.exception_handler(UnicornException)
async def unicorn_exception_handler(request: Request, exc: UnicornException):
    return JSONResponse(
        status_code=418,
        content={
            "message": f"Oops! {exc.name} did something wrong.",
            "path": str(request.url),
        },
    )


@app.get("/unicorns/{name}")
def read_unicorn(name: str):
    if name == "yolo":
        raise UnicornException(name=name)
    return {"unicorn_name": name}


# Override the default handlers. Registering for Starlette's HTTPException
# also covers FastAPI's HTTPException, which subclasses it.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": True,
            "message": exc.detail,
            "status_code": exc.status_code,
        },
        headers=getattr(exc, "headers", None),  # keep headers such as WWW-Authenticate
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    return JSONResponse(
        status_code=422,
        content={
            "error": True,
            "message": "Validation Error",
            # jsonable_encoder converts values that json.dumps cannot handle
            "details": jsonable_encoder(exc.errors()),
        },
    )
```
12.3 A Unified Error Response Format
```python
from typing import Any, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel


class ErrorResponse(BaseModel):
    success: bool = False
    error_code: str
    message: str
    details: Optional[Any] = None


class AppException(Exception):
    def __init__(
        self,
        status_code: int,
        error_code: str,
        message: str,
        details: Any = None,
    ):
        super().__init__(message)
        self.status_code = status_code
        self.error_code = error_code
        self.message = message
        self.details = details


app = FastAPI()


@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            error_code=exc.error_code,
            message=exc.message,
            details=exc.details,
        ).model_dump(),  # Pydantic v1: .dict()
    )


# Usage
@app.get("/users/{user_id}")
def get_user(user_id: int):
    if user_id <= 0:
        raise AppException(
            status_code=400,
            error_code="INVALID_USER_ID",
            message="User ID must be positive",
            details={"user_id": user_id},
        )
    # ...
```
13. Background Tasks
13.1 BackgroundTasks
Tasks added to `BackgroundTasks` run in the same process after the response has been sent:
```python
import time
from typing import Optional

from fastapi import BackgroundTasks, Depends, FastAPI
from pydantic import BaseModel, EmailStr  # EmailStr requires email-validator

app = FastAPI()


# Background task functions
def write_log(message: str):
    with open("log.txt", mode="a") as log_file:
        log_file.write(f"{message}\n")


def send_email(email: str, message: str):
    # Simulate sending an email
    time.sleep(5)  # slow operation; synchronous tasks run in a thread pool
    print(f"Email sent to {email}: {message}")


# Basic usage
@app.post("/items/")
def create_item(background_tasks: BackgroundTasks, name: str):
    background_tasks.add_task(write_log, f"Item created: {name}")
    return {"message": "Item created", "name": name}


# Sending an email
class EmailSchema(BaseModel):
    email: EmailStr
    body: str


@app.post("/send-notification/")
async def send_notification(
    email: EmailSchema,
    background_tasks: BackgroundTasks,
):
    background_tasks.add_task(send_email, email.email, email.body)
    return {"message": "Notification sent in the background"}


# Using BackgroundTasks inside a dependency
def get_query(background_tasks: BackgroundTasks, q: Optional[str] = None):
    if q:
        background_tasks.add_task(write_log, f"Query: {q}")
    return q


@app.get("/search/")
def search(q: Optional[str] = Depends(get_query)):
    return {"q": q}
```
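13.2 Celery Integration (Heavyweight Tasks)
For long-running jobs that should run in a separate worker process, use a task queue such as Celery (here with Redis as broker and result backend):
```bash
pip install celery redis
```
celery_worker.py:
```python
import time

from celery import Celery

celery_app = Celery(
    "worker",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0",
)

# send_email is routed to "main-queue"; start a worker that consumes it, e.g.
#   celery -A celery_worker worker -Q main-queue --loglevel=info
# otherwise those tasks stay in the queue and never run
celery_app.conf.task_routes = {
    "celery_worker.send_email": "main-queue",
}


@celery_app.task
def send_email(email: str, message: str):
    time.sleep(10)  # simulate sending an email
    return f"Email sent to {email}"


@celery_app.task
def process_data(data: dict):
    # Process the data
    return {"status": "processed", "data": data}
```
main.py:
```python
from fastapi import FastAPI

from celery_worker import celery_app, process_data, send_email

app = FastAPI()


@app.post("/send-email/")
def trigger_send_email(email: str, message: str):
    # .delay() enqueues the task and returns immediately
    task = send_email.delay(email, message)
    return {"task_id": task.id}


@app.get("/task/{task_id}")
def get_task_status(task_id: str):
    task = celery_app.AsyncResult(task_id)
    return {
        "task_id": task_id,
        "status": task.status,
        "result": task.result if task.ready() else None,
    }
```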
14. WebSocket 支持
14.1 基本 WebSocket
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List
app = FastAPI()
# Basic WebSocket: echo every message back to the sender
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
# Connection manager: tracks open connections so messages can be broadcast.
# The list lives in this process only; with several workers each has its own.
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)
    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)
    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)
manager = ConnectionManager()
@app.websocket("/ws/{client_id}")
async def chat_endpoint(websocket: WebSocket, client_id: int):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.send_personal_message(f"You wrote: {data}", websocket)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")
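ConnectionManager.broadcast stops at the first failing send: if a client has gone away but is still in active_connections, send_text can raise and the remaining clients miss the message. The sketch below shows one way to make broadcasting tolerant by collecting failed connections and pruning them afterwards. FakeWebSocket and SafeConnectionManager are hypothetical names used so the example runs without a server; it assumes only send_text() is needed, which matches the original manager.

```python
import asyncio
from typing import List


class FakeWebSocket:
    """Hypothetical stand-in for fastapi.WebSocket exposing only send_text()."""

    def __init__(self, name: str, alive: bool = True) -> None:
        self.name = name
        self.alive = alive
        self.received: List[str] = []

    async def send_text(self, message: str) -> None:
        if not self.alive:
            raise RuntimeError(f"{self.name} is closed")
        self.received.append(message)


class SafeConnectionManager:
    def __init__(self) -> None:
        self.active_connections: List[FakeWebSocket] = []

    def add(self, websocket: FakeWebSocket) -> None:
        self.active_connections.append(websocket)

    async def broadcast(self, message: str) -> None:
        dead = []
        # Iterate over a copy so the list can be pruned safely afterwards
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                # Broad on purpose: the exact error type depends on the server
                dead.append(connection)
        for connection in dead:
            self.active_connections.remove(connection)
```

A healthy client still receives the message even when another one has silently disconnected.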
14.2 WebSocket with Authentication
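from typing import Optional
from fastapi import Depends, Query, WebSocket, WebSocketDisconnect, status
async def get_token_from_query(
    websocket: WebSocket,
    token: str = Query(...)
):
    # Validate the token (hard-coded value, for demonstration only)
    if token != "valid_token":
        # Closing before accept() rejects the WebSocket handshake
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None
    return token
# Use a path outside /ws/...: "/ws/{client_id}" from 14.1 is registered first
# and would otherwise capture "/ws/secure"
@app.websocket("/secure/ws")
async def secure_websocket(
    websocket: WebSocket,
    token: Optional[str] = Depends(get_token_from_query)
):
    if token is None:
        return  # the dependency has already closed the connection
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Secure message: {data}")
    except WebSocketDisconnect:
        pass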
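The dependency above compares the token with `!=`. For secrets, a constant-time comparison is the usual recommendation, since it does not reveal through timing how many leading characters matched. The sketch isolates that decision in a pure function; EXPECTED_TOKEN and check_ws_token are hypothetical names, and 1008 is the value of status.WS_1008_POLICY_VIOLATION.

```python
import hmac
from typing import Optional

# Hypothetical: a real app would load the expected token from configuration
EXPECTED_TOKEN = "valid_token"
WS_1008_POLICY_VIOLATION = 1008  # same value as fastapi.status.WS_1008_POLICY_VIOLATION


def check_ws_token(token: Optional[str], expected: str = EXPECTED_TOKEN) -> Optional[int]:
    """Return None if the token is accepted, otherwise the close code to send."""
    if token is None:
        return WS_1008_POLICY_VIOLATION
    # compare_digest takes time independent of where the first mismatch occurs
    if not hmac.compare_digest(token.encode(), expected.encode()):
        return WS_1008_POLICY_VIOLATION
    return None
```

Inside get_token_from_query, a non-None result would be passed to websocket.close(code=...).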
15. Testing and Deployment
15.1 Testing
pip install pytest httpx
test_main.py:
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
def test_read_root():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Hello, FastAPI!"}
def test_read_item():
    response = client.get("/items/1")
    assert response.status_code == 200
    assert "item_id" in response.json()
def test_read_item_not_found():
    response = client.get("/items/999")
    assert response.status_code == 404
def test_create_item():
    response = client.post(
        "/items/",
        json={"name": "Foo", "price": 45.0}
    )
    assert response.status_code == 201
    assert response.json()["name"] == "Foo"
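# Async test (uses the anyio pytest plugin; anyio is installed with FastAPI)
import pytest
from httpx import ASGITransport, AsyncClient
@pytest.mark.anyio
async def test_async_root():
    # Recent httpx releases removed AsyncClient(app=...); pass the app via ASGITransport
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/")
    assert response.status_code == 200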
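# Test database: replace the real get_db dependency with one backed by a test SQLite file
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from main import get_db  # assumption: get_db is defined in (or importable from) main.py
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# check_same_thread=False lets TestClient's worker thread use the SQLite connection
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(bind=engine)
def override_get_db():
    db = TestingSessionLocal()  # create the session outside try so finally can always close it
    try:
        yield db
    finally:
        db.close()
app.dependency_overrides[get_db] = override_get_db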
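dependency_overrides works because FastAPI looks dependencies up by the original function object, and the code after `yield` in a dependency runs once the endpoint is done. The stdlib-only model below imitates both ideas; get_db, override_get_db and call_with_dependency here are hypothetical stand-ins, not FastAPI internals.

```python
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, List

events: List[str] = []


def get_db():
    # Hypothetical stand-in for the production dependency
    events.append("open prod db")
    try:
        yield "prod-session"
    finally:
        events.append("close prod db")


def override_get_db():
    events.append("open test db")
    try:
        yield "test-session"
    finally:
        events.append("close test db")


def call_with_dependency(
    endpoint: Callable[[Any], Any],
    dependency: Callable[..., Any],
    overrides: Dict[Callable[..., Any], Callable[..., Any]],
) -> Any:
    # Overrides are keyed by the original function object, like app.dependency_overrides
    provider = overrides.get(dependency, dependency)
    if inspect.isgeneratorfunction(provider):
        # Code before `yield` runs before the endpoint; the finally block runs afterwards
        with contextmanager(provider)() as value:
            return endpoint(value)
    return endpoint(provider())
```

Because app.dependency_overrides is a plain dict, calling app.dependency_overrides.clear() after the tests restores the real dependencies.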
15.2 Deployment
Run Uvicorn in production mode (no --reload, several worker processes):
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4
Use Gunicorn as the process manager with Uvicorn workers:
pip install gunicorn
gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000
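The `-w 4` above is a fixed example. Gunicorn's documentation suggests (2 x number of CPU cores) + 1 workers as a starting point, to be tuned under real load. A small helper that derives the value; suggested_workers is a hypothetical name.

```python
import os
from typing import Optional


def suggested_workers(cpu_cores: Optional[int] = None) -> int:
    """Gunicorn's documented starting point: (2 x number of cores) + 1."""
    cores = cpu_cores if cpu_cores is not None else (os.cpu_count() or 1)
    return 2 * cores + 1


if __name__ == "__main__":
    print(f"gunicorn main:app -w {suggested_workers()} -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000")
```

On a 4-core machine this gives 9 workers; treat it as an initial value rather than a rule.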
Docker deployment:
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'  # optional; current Docker Compose ignores this key
services:
  web:
    build: .
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql://user:pass@db:5432/dbname
    depends_on:  # controls start order only; it does not wait until Postgres accepts connections
      - db
  db:
    image: postgres:15
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=pass
      - POSTGRES_DB=dbname
    volumes:
      - postgres_data:/var/lib/postgresql/data
volumes:
  postgres_data:
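Inside the Compose network, containers reach each other by service name, so the host in DATABASE_URL is `db`, not localhost. The stdlib sketch below shows how the app side might read and inspect that variable; describe_database_url is a hypothetical helper, and it deliberately leaves the password out so the result is safe to log.

```python
import os
from urllib.parse import urlsplit


def describe_database_url(url: str) -> dict:
    """Split a database URL into loggable parts (the password is omitted on purpose)."""
    parts = urlsplit(url)
    return {
        "driver": parts.scheme,
        "host": parts.hostname,
        "port": parts.port,
        "user": parts.username,
        "database": parts.path.lstrip("/"),
    }


if __name__ == "__main__":
    url = os.environ.get("DATABASE_URL", "postgresql://user:pass@db:5432/dbname")
    print(describe_database_url(url))
```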
16. Hands-On Project: Building a Complete API
16.1 Project Structure
blog_api/
├── app/
│ ├── __init__.py
│ ├── main.py
│ ├── config.py
│ ├── database.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── user.py
│ │ └── post.py
│ ├── schemas/
│ │ ├── __init__.py
│ │ ├── user.py
│ │ └── post.py
│ ├── routers/
│ │ ├── __init__.py
│ │ ├── auth.py
│ │ ├── users.py
│ │ └── posts.py
│ ├── services/
│ │ ├── __init__.py
│ │ ├── user_service.py
│ │ └── post_service.py
│ └── utils/
│ ├── __init__.py
│ ├── security.py
│ └── deps.py
├── tests/
├── requirements.txt
└── README.md
16.2 Core Code
config.py:
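# Pydantic v2 moved BaseSettings to the pydantic-settings package (pip install pydantic-settings).
# On Pydantic v1, use `from pydantic import BaseSettings` with an inner `class Config: env_file = ".env"`.
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    app_name: str = "Blog API"
    database_url: str = "sqlite:///./blog.db"
    secret_key: str = "your-secret-key"  # override via SECRET_KEY in the environment or .env
    access_token_expire_minutes: int = 30
    model_config = SettingsConfigDict(env_file=".env")
settings = Settings()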
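Settings resolves each field in a fixed order: real environment variables override entries in .env, which override the class defaults. The stdlib-only sketch below models that order for the fields above. It is a simplification under stated assumptions: values stay strings and names are matched upper-case only, whereas pydantic-settings also converts types and matches names case-insensitively by default. DEFAULTS, parse_env_file and load_settings are hypothetical names.

```python
import os
from typing import Dict, Mapping, Optional

DEFAULTS = {
    "app_name": "Blog API",
    "database_url": "sqlite:///./blog.db",
    "secret_key": "your-secret-key",
    "access_token_expire_minutes": "30",
}


def parse_env_file(text: str) -> Dict[str, str]:
    """Parse simple KEY=value lines, skipping blanks and comments."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip().upper()] = value.strip().strip('"').strip("'")
    return values


def load_settings(env_file_text: str = "", environ: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    environ = os.environ if environ is None else environ
    dotenv = parse_env_file(env_file_text)
    resolved = {}
    for field, default in DEFAULTS.items():
        key = field.upper()
        # Environment variables win over .env entries, which win over defaults
        if key in environ:
            resolved[field] = environ[key]
        elif key in dotenv:
            resolved[field] = dotenv[key]
        else:
            resolved[field] = default
    return resolved
```

This ordering is what lets a deployment replace the placeholder secret_key without editing config.py.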
main.py:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routers import auth, users, posts
from app.database import engine, Base
from app.config import settings
# Create the database tables
Base.metadata.create_all(bind=engine)
app = FastAPI(
    title=settings.app_name,
    description="A complete blog API",
    version="1.0.0"
)
# CORS: "*" is convenient during development; with allow_credentials=True,
# list the real front-end origins explicitly in production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Routers
app.include_router(auth.router, prefix="/auth", tags=["Auth"])
app.include_router(users.router, prefix="/users", tags=["Users"])
app.include_router(posts.router, prefix="/posts", tags=["Posts"])
@app.get("/")
def root():
    return {"message": "Welcome to Blog API"}
@app.get("/health")
def health_check():
    return {"status": "healthy"}
routers/posts.py:
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from app.schemas.post import Post, PostCreate, PostUpdate
from app.services import post_service
from app.utils.deps import get_db, get_current_user
router = APIRouter()
@router.get("/", response_model=List[Post])
def get_posts(
    skip: int = 0,
    limit: int = 10,
    db: Session = Depends(get_db)
):
    return post_service.get_posts(db, skip=skip, limit=limit)
@router.get("/{post_id}", response_model=Post)
def get_post(post_id: int, db: Session = Depends(get_db)):
    post = post_service.get_post(db, post_id)
    if not post:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Post not found"
        )
    return post
@router.post("/", response_model=Post, status_code=status.HTTP_201_CREATED)
def create_post(
    post: PostCreate,
    db: Session = Depends(get_db),
    current_user = Depends(get_current_user)
):
    return post_service.create_post(db, post, current_user.id)
@router.put("/{post_id}", response_model=Post)
def update_post(
    post_id: int,
    post: PostUpdate,
    db: Session = Depends(get_db),
    current_user = Depends(get_current_user)
):
    db_post = post_service.get_post(db, post_id)
    if not db_post:
        raise HTTPException(status_code=404, detail="Post not found")
    # Only the author may edit the post
    if db_post.author_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")
    return post_service.update_post(db, post_id, post)
@router.delete("/{post_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_post(
    post_id: int,
    db: Session = Depends(get_db),
    current_user = Depends(get_current_user)
):
    db_post = post_service.get_post(db, post_id)
    if not db_post:
        raise HTTPException(status_code=404, detail="Post not found")
    # Only the author may delete the post
    if db_post.author_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")
    post_service.delete_post(db, post_id)
    return None  # a 204 response carries no body
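update_post and delete_post repeat the same check: 404 when the post does not exist, then 403 when it belongs to someone else. The order matters, since the ownership test needs an existing post. The sketch below pulls that decision into a pure function that is easy to unit-test; PostRecord and check_post_access are hypothetical names. In the router, the same logic could become a shared dependency that raises HTTPException with the returned code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PostRecord:
    """Hypothetical minimal view of a stored post."""
    id: int
    author_id: int


def check_post_access(post: Optional[PostRecord], current_user_id: int) -> Optional[int]:
    """Return None if the user may modify the post, otherwise the HTTP status to raise."""
    if post is None:
        return 404  # checked first: there is nothing to authorize against
    if post.author_id != current_user_id:
        return 403  # the post exists but belongs to another user
    return None
```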
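Summary
FastAPI is a powerful, easy-to-use, modern Python web framework. This tutorial covered:
- Basics: routing, parameters, request bodies
- Data validation: Pydantic models, Field validation
- Response handling: response models, status codes
- Dependency injection: code reuse, database connections
- Security and authentication: OAuth2, JWT
- Middleware: CORS, custom middleware
- Databases: SQLAlchemy integration
- Advanced features: WebSocket, background tasks
- Testing and deployment: pytest, Docker
Learning Resources
- Official documentation: https://fastapi.tiangolo.com
- GitHub: https://github.com/tiangolo/fastapi
- Pydantic documentation: https://pydantic-docs.helpmanual.io
Best Practices
- Always use type hints
- Organize the project structure sensibly
- Manage database connections with dependency injection
- Write unit tests
- Manage configuration with environment variables
- Use Gunicorn + Uvicorn in production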
Happy Coding! 🚀