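LangChain Tool Inputs: A Guide to Pydantic BaseModel
Summary: Pass a Pydantic BaseModel to the @tool decorator through the args_schema parameter. Add a description to every field, use Literal to restrict a parameter to a fixed set of values, validate input with Field constraints such as ge, le, min_length, and max_length, provide defaults so optional parameters are easier to use, and express complex types with type hints such as Optional, List, and Dict.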
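Why Use Pydantic BaseModel
Limitations of the Traditional Approach
from langchain.tools import tool

@tool
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    # Problems:
    # 1. units cannot be restricted to "celsius" or "fahrenheit"
    # 2. There is no structured place for a detailed description of each parameter
    # 3. No complex data validation is possible
    pass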
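Advantages of Pydantic BaseModel
✅ Type validation - input data types are validated automatically
✅ Field descriptions - a detailed description on each field helps the model understand it
✅ Constraints - restrict the allowed values (enumerations, minimum/maximum values)
✅ Complex types - nested objects, lists, optional fields, and more
✅ Default values - set a default for any field
✅ Schema generation - a clean JSON Schema is generated automatically
Basic Usage
Step 1: Define the Input Schema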
from pydantic import BaseModel, Field
from typing import Literal

class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )
Step 2: Pass the Schema via the args_schema Parameter
from langchain.tools import tool

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get current weather information."""
    # Placeholder temperature: 22 for celsius, 72 for fahrenheit
    temp = 22 if units == "celsius" else 72
    return f"Weather in {location}: {temp}°{units[0].upper()}"
Complete Code
from langchain.tools import tool
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import Literal

# 1. Define the input schema
class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )

# 2. Create the tool
@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get current weather information."""
    temp = 22 if units == "celsius" else 72
    return f"Weather in {location}: {temp}°{units[0].upper()}"

# 3. Use the tool
model = ChatOpenAI(model="gpt-4")
agent = create_agent(model, tools=[get_weather])
result = agent.invoke({
    "messages": [{"role": "user", "content": "What's the weather in Paris in fahrenheit?"}]
})
print(result["messages"][-1].content)
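Field in Detail
Field Parameters
from pydantic import BaseModel, Field

class ExampleInput(BaseModel):
    # String field: length and pattern constraints
    text_field: str = Field(
        default=...,          # default value (... marks the field as required)
        description="...",    # field description (important!)
        examples=["Abc"],     # example values
        min_length=1,         # minimum length (strings)
        max_length=100,       # maximum length (strings)
        pattern="^[A-Z]",     # regular expression (strings)
        alias="textField",    # alias
    )
    # Numeric fields: inclusive (ge/le) and exclusive (gt/lt) bounds
    score: int = Field(default=0, ge=0, le=100)      # 0 <= score <= 100
    ratio: float = Field(default=1.0, gt=0, lt=100)  # 0 < ratio < 100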
Common Field Examples
1. Required Fields
class UserInput(BaseModel):
    user_id: str = Field(description="User unique identifier")
    # Equivalent: mark the field as required explicitly with ...
    # user_id: str = Field(..., description="User unique identifier")
2. Optional Fields (with Defaults)
class SearchInput(BaseModel):
    query: str = Field(description="Search query")
    limit: int = Field(default=10, description="Maximum results")
    offset: int = Field(default=0, description="Skip first N results")
3. String Constraints
class UserProfileInput(BaseModel):
    username: str = Field(
        description="Username (3-20 characters)",
        min_length=3,
        max_length=20,
        pattern="^[a-zA-Z0-9_]+$"  # letters, digits, and underscores only
    )
    email: str = Field(
        description="User email address",
        pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$"
    )
4. Numeric Constraints
class ProductFilterInput(BaseModel):
    min_price: float = Field(
        default=0.0,
        ge=0.0,  # greater than or equal to 0
        description="Minimum price in USD"
    )
    max_price: float = Field(
        default=1000.0,
        le=10000.0,  # less than or equal to 10000
        description="Maximum price in USD"
    )
    quantity: int = Field(
        default=1,
        gt=0,    # greater than 0
        lt=100,  # less than 100
        description="Quantity to order"
    )
5. Example Values
class QueryInput(BaseModel):
    search_term: str = Field(
        description="Search keywords",
        examples=["machine learning", "python tutorial", "AI news"]
    )
Advanced Type Definitions
1. Enumerated Values (Literal)
from typing import Literal

class WeatherInput(BaseModel):
    location: str = Field(description="City name")
    units: Literal["celsius", "fahrenheit", "kelvin"] = Field(
        default="celsius",
        description="Temperature unit"
    )
    forecast_type: Literal["hourly", "daily", "weekly"] = Field(
        default="daily",
        description="Forecast time range"
    )
Use case: the parameter may only take one of a few fixed values.
2. Optional Fields (Optional)
from typing import Optional

class SearchInput(BaseModel):
    query: str = Field(description="Search query (required)")
    category: Optional[str] = Field(
        default=None,
        description="Filter by category (optional)"
    )
    min_rating: Optional[float] = Field(
        default=None,
        description="Minimum rating filter"
    )
Use case: the field may be None or omitted entirely.
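
To make the "None or omitted" distinction concrete, here is a minimal sketch that assumes Pydantic v2 (`model_fields_set` and `model_dump` are v2 names). It uses a trimmed copy of `SearchInput`; the `"laptops"` query is only illustrative.

```python
from typing import Optional

from pydantic import BaseModel, Field


class SearchInput(BaseModel):
    query: str = Field(description="Search query (required)")
    category: Optional[str] = Field(
        default=None,
        description="Filter by category (optional)",
    )


# The caller leaves the optional field out entirely.
omitted = SearchInput(query="laptops")
# The caller passes None explicitly.
explicit = SearchInput(query="laptops", category=None)

# Both end up with category set to None ...
print(omitted.category, explicit.category)

# ... but model_fields_set records which fields were actually supplied.
print(omitted.model_fields_set)
print(explicit.model_fields_set)

# exclude_none drops unset filters, which is handy before forwarding them.
print(omitted.model_dump(exclude_none=True))
```

Inside a tool function, both cases look the same (`category is None`), so a single `if category is not None:` check is usually enough.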
3. List Types (List)
from typing import List

class BatchQueryInput(BaseModel):
    queries: List[str] = Field(
        description="List of search queries to process"
    )
    tags: List[str] = Field(
        default=[],
        description="Filter by tags"
    )
    user_ids: List[int] = Field(
        description="List of user IDs to query",
        min_length=1,  # at least one element
        max_length=10  # at most 10 elements
    )
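
The length constraints on `user_ids` can be exercised directly on the model. This is a small sketch that assumes Pydantic v2, where `min_length`/`max_length` apply to lists and the error types are reported as `too_short`/`too_long`. The `check` helper is a hypothetical name introduced only for this illustration.

```python
from typing import List

from pydantic import BaseModel, Field, ValidationError


class BatchQueryInput(BaseModel):
    user_ids: List[int] = Field(
        description="List of user IDs to query",
        min_length=1,   # at least one element
        max_length=10,  # at most 10 elements
    )


def check(user_ids: list) -> str:
    """Return 'ok' or the type of the first validation error."""
    try:
        BatchQueryInput(user_ids=user_ids)
        return "ok"
    except ValidationError as exc:
        return exc.errors()[0]["type"]


print(check([1, 2, 3]))        # within bounds
print(check([]))               # fewer than 1 element
print(check(list(range(11))))  # more than 10 elements
```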
4. Nested Objects
class Address(BaseModel):
    street: str = Field(description="Street address")
    city: str = Field(description="City name")
    country: str = Field(description="Country name")
    postal_code: str = Field(description="Postal code")

class CreateUserInput(BaseModel):
    name: str = Field(description="User's full name")
    email: str = Field(description="User's email address")
    address: Address = Field(description="User's address")
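
Tool-call arguments typically arrive as plain dictionaries, so it helps to see how a nested dictionary is turned into nested models. The sketch below assumes Pydantic v2 (`model_validate`); the payload values are made up for illustration.

```python
from pydantic import BaseModel, Field, ValidationError


class Address(BaseModel):
    street: str = Field(description="Street address")
    city: str = Field(description="City name")
    country: str = Field(description="Country name")
    postal_code: str = Field(description="Postal code")


class CreateUserInput(BaseModel):
    name: str = Field(description="User's full name")
    email: str = Field(description="User's email address")
    address: Address = Field(description="User's address")


payload = {
    "name": "Ada Lovelace",
    "email": "ada@example.com",
    "address": {
        "street": "1 Example Street",
        "city": "London",
        "country": "UK",
        "postal_code": "N1 1AA",
    },
}

# The nested dict is validated and converted into an Address instance.
user = CreateUserInput.model_validate(payload)
print(type(user.address).__name__, user.address.city)

# Errors inside the nested object report the full path to the field.
broken = {**payload, "address": {"street": "1 Example Street",
                                 "country": "UK", "postal_code": "N1 1AA"}}
try:
    CreateUserInput.model_validate(broken)
    error_loc = None
except ValidationError as exc:
    error_loc = exc.errors()[0]["loc"]
print(error_loc)
```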
5. Union Types (Union)
from typing import Union

class SearchInput(BaseModel):
    identifier: Union[int, str] = Field(
        description="User ID (int) or username (str)"
    )
    date_filter: Union[str, int] = Field(
        description="Date as string (YYYY-MM-DD) or unix timestamp"
    )
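
Which branch of a `Union` is chosen depends on the Pydantic version. The sketch below assumes Pydantic v2, whose default "smart" union mode prefers the member that matches the input's type exactly; Pydantic v1 instead tried members left to right and could coerce `"42"` to an int.

```python
from typing import Union

from pydantic import BaseModel, Field


class SearchInput(BaseModel):
    identifier: Union[int, str] = Field(
        description="User ID (int) or username (str)"
    )


by_id = SearchInput(identifier=42)
by_name = SearchInput(identifier="alice")
numeric_string = SearchInput(identifier="42")

# Under Pydantic v2, each value keeps the type it was supplied with.
for item in (by_id, by_name, numeric_string):
    print(repr(item.identifier), type(item.identifier).__name__)
```

If the tool must treat `"42"` as a numeric ID, it is safer to convert explicitly inside the function than to rely on union resolution.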
6. Dictionary Types (Dict)
from typing import Dict

class MetadataInput(BaseModel):
    filters: Dict[str, str] = Field(
        default={},
        description="Key-value pairs for filtering"
    )
    settings: Dict[str, int] = Field(
        default={"timeout": 30, "retries": 3},
        description="Configuration settings"
    )
7. Combining Types
from typing import List, Optional, Literal, Dict

class AdvancedSearchInput(BaseModel):
    """Advanced search with multiple filters."""
    # Required field
    query: str = Field(
        description="Main search query",
        min_length=1,
        max_length=200
    )
    # Enumerated values
    sort_by: Literal["relevance", "date", "popularity"] = Field(
        default="relevance",
        description="Sort results by"
    )
    # Optional field
    category: Optional[str] = Field(
        default=None,
        description="Filter by category"
    )
    # List
    tags: List[str] = Field(
        default=[],
        description="Filter by tags"
    )
    # Numeric constraints
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum results"
    )
    # Boolean
    include_archived: bool = Field(
        default=False,
        description="Include archived items"
    )
    # Dictionary
    metadata: Dict[str, str] = Field(
        default={},
        description="Additional metadata filters"
    )
Complete Examples
Example 1: Weather Query Tool
from langchain.tools import tool
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import Literal, Optional

# Define the input schema
class WeatherInput(BaseModel):
    """Input schema for weather queries."""
    location: str = Field(
        description="City name or coordinates (e.g., 'Paris' or '48.8566,2.3522')",
        examples=["Paris", "New York", "Tokyo"]
    )
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )
    include_forecast: bool = Field(
        default=False,
        description="Include a multi-day forecast"
    )
    days: Optional[int] = Field(
        default=None,
        ge=1,
        le=7,
        description="Number of forecast days (1-7)"
    )

# Create the tool
@tool(args_schema=WeatherInput)
def get_weather(
    location: str,
    units: str = "celsius",
    include_forecast: bool = False,
    days: Optional[int] = None
) -> str:
    """Get current weather and optional forecast for a location.

    This tool provides current weather information and can optionally
    include a multi-day forecast.
    """
    # Simulated weather data
    temp = 22 if units == "celsius" else 72
    result = f"Current weather in {location}: {temp}°{units[0].upper()}, Sunny"
    # The forecast is added only when it is requested and days is given
    if include_forecast and days:
        result += f"\n\n{days}-day forecast:"
        for i in range(1, days + 1):
            temp_day = temp + i
            result += f"\n Day {i}: {temp_day}°{units[0].upper()}, Partly cloudy"
    return result

# Use the tool
if __name__ == "__main__":
    model = ChatOpenAI(model="gpt-4")
    agent = create_agent(model, tools=[get_weather])

    # Test 1: simple query
    result = agent.invoke({
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}]
    })
    print(result["messages"][-1].content)

    # Test 2: with forecast
    result = agent.invoke({
        "messages": [{"role": "user", "content": "Give me the weather in Tokyo with a 3-day forecast in fahrenheit"}]
    })
    print(result["messages"][-1].content)
Example 2: Database Query Tool
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import List, Literal, Optional

class DatabaseQueryInput(BaseModel):
    """Input for database query operations."""
    table_name: str = Field(
        description="Name of the table to query",
        examples=["users", "products", "orders"]
    )
    columns: List[str] = Field(
        default=["*"],
        description="Columns to select (default: all columns)",
        examples=[["id", "name", "email"], ["*"]]
    )
    filters: Optional[dict] = Field(
        default=None,
        description="Filter conditions as key-value pairs",
        examples=[{"status": "active"}, {"age": 25}]
    )
    order_by: Optional[str] = Field(
        default=None,
        description="Column to sort by",
        examples=["created_at", "name"]
    )
    order_direction: Literal["ASC", "DESC"] = Field(
        default="ASC",
        description="Sort direction"
    )
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum number of results"
    )

@tool(args_schema=DatabaseQueryInput)
def query_database(
    table_name: str,
    columns: List[str] = ["*"],
    filters: Optional[dict] = None,
    order_by: Optional[str] = None,
    order_direction: str = "ASC",
    limit: int = 10
) -> str:
    """Query the database with specified filters and options.

    This tool allows you to retrieve data from database tables
    with flexible filtering and sorting options.
    """
    # Build the SQL query (illustration only: interpolating values into SQL
    # is open to injection; real code should use parameterized queries)
    cols = ", ".join(columns)
    query = f"SELECT {cols} FROM {table_name}"
    if filters:
        conditions = [f"{k} = '{v}'" for k, v in filters.items()]
        query += " WHERE " + " AND ".join(conditions)
    if order_by:
        query += f" ORDER BY {order_by} {order_direction}"
    query += f" LIMIT {limit}"
    # Simulated query result
    return f"Executed query: {query}\n\nResults: [Sample data would appear here]"
Example 3: File Operation Tool
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Literal, Optional

class FileOperationInput(BaseModel):
    """Input for file operations."""
    operation: Literal["read", "write", "delete", "list"] = Field(
        description="Type of file operation to perform"
    )
    file_path: str = Field(
        description="Path to the file",
        examples=["/path/to/file.txt", "documents/report.pdf"]
    )
    content: Optional[str] = Field(
        default=None,
        description="Content to write (only for 'write' operation)"
    )
    encoding: str = Field(
        default="utf-8",
        description="File encoding"
    )
    mode: Literal["text", "binary"] = Field(
        default="text",
        description="Read/write mode"
    )

@tool(args_schema=FileOperationInput)
def file_operation(
    operation: str,
    file_path: str,
    content: Optional[str] = None,
    encoding: str = "utf-8",
    mode: str = "text"
) -> str:
    """Perform file operations (read, write, delete, list).

    This tool allows safe file system operations with
    proper encoding and mode handling.
    """
    # All branches only simulate the operation; nothing touches the disk
    if operation == "read":
        return f"Reading file: {file_path}\n[File content would appear here]"
    elif operation == "write":
        # content may be None when the caller omits it, so guard the slice
        preview = (content or "")[:50]
        return f"Writing to file: {file_path}\nContent: {preview}..."
    elif operation == "delete":
        return f"Deleting file: {file_path}"
    elif operation == "list":
        return f"Listing directory: {file_path}\n- file1.txt\n- file2.pdf"
    else:
        return f"Unknown operation: {operation}"
Example 4: API Call Tool
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Dict, Optional, Literal, List

class APICallInput(BaseModel):
    """Input for making API requests."""
    url: str = Field(
        description="API endpoint URL",
        pattern="^https?://",  # must start with http:// or https://
        examples=["https://api.example.com/users"]
    )
    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = Field(
        default="GET",
        description="HTTP method"
    )
    headers: Optional[Dict[str, str]] = Field(
        default=None,
        description="HTTP headers",
        examples=[{"Authorization": "Bearer token123"}]
    )
    params: Optional[Dict[str, str]] = Field(
        default=None,
        description="Query parameters",
        examples=[{"page": "1", "limit": "10"}]
    )
    body: Optional[Dict] = Field(
        default=None,
        description="Request body (for POST, PUT, PATCH)"
    )
    timeout: int = Field(
        default=30,
        ge=1,
        le=300,
        description="Request timeout in seconds"
    )

@tool(args_schema=APICallInput)
def api_call(
    url: str,
    method: str = "GET",
    headers: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
    body: Optional[Dict] = None,
    timeout: int = 30
) -> str:
    """Make HTTP API requests with specified parameters.

    This tool allows you to interact with external APIs
    using various HTTP methods and parameters.
    """
    result = "API Call:\n"
    result += f" Method: {method}\n"
    result += f" URL: {url}\n"
    if params:
        result += f" Params: {params}\n"
    if headers:
        result += f" Headers: {headers}\n"
    if body:
        result += f" Body: {body}\n"
    # Simulated API response
    result += "\nResponse: [API response would appear here]"
    return result
Best Practices
1. Always Add Descriptions
# ❌ Bad - no descriptions
class BadInput(BaseModel):
    location: str
    units: str

# ✅ Good - clear descriptions
class GoodInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: str = Field(description="Temperature unit (celsius or fahrenheit)")
Why it matters: descriptions help the LLM understand when and how to use each parameter.
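
One way to see what the descriptions add is to compare the JSON Schema generated for the two classes, since that schema is what gets described to the model. The sketch below assumes Pydantic v2 (`model_json_schema`) and reuses the two classes from above.

```python
from pydantic import BaseModel, Field


class BadInput(BaseModel):
    location: str
    units: str


class GoodInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: str = Field(description="Temperature unit (celsius or fahrenheit)")


bad_schema = BadInput.model_json_schema()
good_schema = GoodInput.model_json_schema()

# Without Field(description=...), a property carries only a title and a type.
print(bad_schema["properties"]["location"])

# With descriptions, each property explains itself, and the class docstring
# becomes the schema-level description.
print(good_schema["description"])
for name, prop in good_schema["properties"].items():
    print(f"{name}: {prop['description']}")
```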
2. Use Literal to Restrict Options
# ❌ Bad - the string can be any value
units: str = Field(description="Temperature unit")

# ✅ Good - restricted to specific values
units: Literal["celsius", "fahrenheit"] = Field(
    description="Temperature unit"
)
3. Provide Default Values
class SearchInput(BaseModel):
    query: str = Field(description="Search query")
    limit: int = Field(default=10, description="Results limit")    # ✅ has a default
    offset: int = Field(default=0, description="Skip N results")   # ✅ has a default
4. Add Examples
class UserInput(BaseModel):
    email: str = Field(
        description="User email address",
        examples=["user@example.com", "john.doe@company.org"]
    )
5. Validate with Constraints
class ProductInput(BaseModel):
    name: str = Field(
        description="Product name",
        min_length=3,
        max_length=100
    )
    price: float = Field(
        description="Product price",
        gt=0,      # must be greater than 0
        le=10000   # at most 10000
    )
    quantity: int = Field(
        description="Quantity",
        ge=1,   # at least 1
        le=100  # at most 100
    )
6. Document Both the Class and Its Fields
class WeatherInput(BaseModel):
    """Input schema for weather-related queries.

    This schema defines the parameters needed to query
    weather information for a specific location.
    """
    location: str = Field(
        description="City name or geographic coordinates"
    )
7. Use Optional Judiciously
from typing import Optional

class SearchInput(BaseModel):
    query: str = Field(description="Required search query")
    # Optional - may be omitted
    category: Optional[str] = Field(
        default=None,
        description="Optional category filter"
    )
    # Has a default value - also optional for the caller
    limit: int = Field(
        default=10,
        description="Maximum results"
    )
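Comparison with the Plain Decorator
Approach 1: Decorator Only (Simple Cases)
@tool
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information.

    Args:
        location: City name or coordinates
        units: Temperature unit (celsius or fahrenheit)
    """
    return f"Weather in {location}: 22°{units[0].upper()}"
Pros:
- Simple and quick
- Well suited to tools with few parameters
Cons:
- Cannot restrict parameter values (such as enumerations)
- Cannot perform complex validation
- Descriptions live in the docstring, which is less structured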
Approach 2: Pydantic BaseModel (Recommended)
class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit"
    )

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    return f"Weather in {location}: 22°{units[0].upper()}"
Pros:
- Type validation and constraints
- Clear field descriptions
- Support for complex types
- A well-formed JSON Schema is generated automatically
Cons:
- Requires defining an extra BaseModel class
- Slightly more code
When to Use Which Approach
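| Scenario | Recommended approach |
|---|---|
| Few parameters (1-2) with simple types | Decorator only |
| Parameters restricted to enumerated values | BaseModel |
| Numeric range validation needed | BaseModel |
| Complex types needed (lists, nested objects) | BaseModel |
| More than 3 parameters | BaseModel |
| Detailed field descriptions needed | BaseModel |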
Debugging Tips
Inspecting the Generated Schema
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Literal

class WeatherInput(BaseModel):
    location: str = Field(description="City name")
    units: Literal["celsius", "fahrenheit"] = Field(default="celsius")

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    return f"Weather in {location}"

# Inspect the tool's schema. model_json_schema() is the Pydantic v2 name;
# the older .schema() is deprecated there.
print(get_weather.args_schema.model_json_schema())
Output (shown as formatted JSON; key order may vary by Pydantic version):
{
  "title": "WeatherInput",
  "type": "object",
  "properties": {
    "location": {
      "title": "Location",
      "description": "City name",
      "type": "string"
    },
    "units": {
      "title": "Units",
      "default": "celsius",
      "enum": ["celsius", "fahrenheit"],
      "type": "string"
    }
  },
  "required": ["location"]
}
Testing Validation
from pydantic import ValidationError

# Test schema validation
try:
    # Valid input
    valid_input = WeatherInput(location="Paris", units="celsius")
    print(f"✅ Valid: {valid_input}")
    # Invalid input - units is not one of the allowed values
    invalid_input = WeatherInput(location="Paris", units="kelvin")
except ValidationError as e:
    print(f"❌ Validation Error: {e}")
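
When a test needs to assert on exactly which field failed, the structured error list is easier to work with than the printed message. A minimal sketch, assuming Pydantic v2, where a rejected `Literal` value is reported with the error type `literal_error`:

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class WeatherInput(BaseModel):
    location: str = Field(description="City name")
    units: Literal["celsius", "fahrenheit"] = Field(default="celsius")


errors = []
try:
    WeatherInput(location="Paris", units="kelvin")
except ValidationError as exc:
    # Each entry is a dict with keys such as "loc", "type", "msg", "input".
    errors = exc.errors()

for err in errors:
    print(err["loc"], err["type"], err["msg"])
```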
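FAQ
Q1: Do the parameter names have to match?
Yes. The field names in the BaseModel must match the function's parameter names, because the validated fields are passed to the function as keyword arguments.
# ✅ Correct - names match
class WeatherInput(BaseModel):
    location: str
    units: str

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    pass

# ❌ Wrong - names do not match
class WeatherInput(BaseModel):
    city: str       # should be location
    temp_unit: str  # should be units

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    pass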
Q2: How do I handle a variable number of arguments?
Use List or Dict:
from typing import List, Dict

class BatchInput(BaseModel):
    items: List[str] = Field(description="List of items to process")
    metadata: Dict[str, str] = Field(default={}, description="Additional metadata")
Q3: Can BaseModels be nested?
Yes:
class Address(BaseModel):
    street: str
    city: str
    country: str

class UserInput(BaseModel):
    name: str
    address: Address  # nested BaseModel
Q4: How do I mark a field as required?
A field without a default value is required:
class MyInput(BaseModel):
    required_field: str = Field(description="This is required")
    optional_field: str = Field(default="default", description="This is optional")
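
The required/optional split can be confirmed both at runtime and in the generated schema. The following sketch assumes Pydantic v2, where a missing required field is reported with the error type `missing`.

```python
from pydantic import BaseModel, Field, ValidationError


class MyInput(BaseModel):
    required_field: str = Field(description="This is required")
    optional_field: str = Field(default="default", description="This is optional")


# Supplying only the required field works; the default fills in the rest.
ok = MyInput(required_field="value")
print(ok.optional_field)

# Omitting the required field raises a ValidationError of type "missing".
missing = []
try:
    MyInput()
except ValidationError as exc:
    missing = [err["loc"] for err in exc.errors() if err["type"] == "missing"]
print(missing)

# The JSON Schema lists only the field without a default under "required".
required = MyInput.model_json_schema()["required"]
print(required)
```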
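Summary
Key Takeaways
- Use the args_schema parameter to pass a Pydantic BaseModel to the @tool decorator
- Always add descriptions: use Field(description="...") to describe every field
- Use Literal to restrict a parameter to specific enumerated values
- Use Field constraints such as ge, le, min_length, and max_length to validate input
- Provide default values to make optional parameters easier to use
- Use type hints such as Optional, List, and Dict to express complex types
Recommended Reading
- Pydantic official documentation: https://docs.pydantic.dev/
- LangChain Tools documentation: https://docs.langchain.com/oss/python/langchain/tools
- Python typing module: https://docs.python.org/3/library/typing.html
Quick Reference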
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Literal, Optional, List

class ToolInput(BaseModel):
    """Description of the tool input."""
    # Required field
    required_field: str = Field(description="Field description")
    # Optional field (with a default)
    optional_field: str = Field(default="default", description="Field description")
    # Enumerated values
    enum_field: Literal["option1", "option2"] = Field(description="Field description")
    # Numeric constraints
    number_field: int = Field(default=10, ge=1, le=100, description="Field description")
    # List
    list_field: List[str] = Field(default=[], description="Field description")
    # Nullable type
    nullable_field: Optional[str] = Field(default=None, description="Field description")

@tool(args_schema=ToolInput)
def my_tool(
    required_field: str,
    optional_field: str = "default",
    enum_field: str = "option1",
    number_field: int = 10,
    list_field: List[str] = [],
    nullable_field: Optional[str] = None
) -> str:
    """Description of what the tool does."""
    return "result"
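Exercises
Try building the following tools, defining each input with a Pydantic BaseModel:
- Email sending tool - recipient, subject, body, and priority (enumerated)
- Product search tool - keywords, price range, category, and sort order
- User management tool - operation type (CRUD), user information, and a list of permissions
Happy learning! 🚀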