Defining LangChain Tool Inputs with Pydantic BaseModel

Contents

  1. Why use Pydantic BaseModel
  2. Basic usage
  3. Field parameters in detail
  4. Advanced type definitions
  5. Complete examples
  6. Best practices
  7. Comparison with the plain decorator

Why use Pydantic BaseModel

Limitations of the traditional approach

from langchain.tools import tool

@tool
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
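    # Problems with this signature as written:
    # 1. units is a plain str, so nothing restricts it to "celsius" or "fahrenheit"
    # 2. There is no place for a detailed description of each parameter
    # 3. There is no way to express more complex validation rules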
    pass

Advantages of Pydantic BaseModel

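Type validation - input data types are validated automatically
Field descriptions - every field can carry a detailed explanation that helps the model understand it
Constraints - restrict the allowed values (e.g. enums, minimum/maximum values)
Complex types - nested objects, lists, optional fields, and more
Default values - fields can have defaults
Schema generation - a clean JSON Schema is generated automatically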
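A minimal sketch of these advantages using Pydantic v2 alone (no LangChain needed). The `WeatherInput` model mirrors the one used later in this article; the `days` field is a hypothetical addition, included only to show a numeric range constraint.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class WeatherInput(BaseModel):
    """Input for weather queries."""

    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius", description="Temperature unit preference"
    )
    # Hypothetical field, added only to illustrate a range constraint
    days: int = Field(default=1, ge=1, le=7, description="Forecast days (1-7)")


# Defaults are filled in automatically
ok = WeatherInput(location="Paris")
print(ok.units, ok.days)  # celsius 1

# Enum and range violations are rejected with structured errors
try:
    WeatherInput(location="Paris", units="kelvin", days=30)
except ValidationError as exc:
    for err in exc.errors():
        print(err["loc"], err["type"])
# ('units',) literal_error
# ('days',) less_than_equal

# The same class generates the JSON Schema that describes the tool's arguments
schema = WeatherInput.model_json_schema()
print(schema["required"])  # ['location']
```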


Basic usage

Step 1: Define the input schema

from pydantic import BaseModel, Field
from typing import Literal

class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )

Step 2: Pass it to the tool via the args_schema parameter

from langchain.tools import tool

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get current weather information."""
    temp = 22 if units == "celsius" else 72
    return f"Weather in {location}: {temp}°{units[0].upper()}"

Complete code

from langchain.tools import tool
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import Literal

# 1. Define the input schema
class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )

# 2. Create the tool
@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get current weather information."""
    temp = 22 if units == "celsius" else 72
    return f"Weather in {location}: {temp}°{units[0].upper()}"

# 3. Use the tool (requires an OpenAI API key)
model = ChatOpenAI(model="gpt-4")
agent = create_agent(model, tools=[get_weather])

result = agent.invoke({
    "messages": [{"role": "user", "content": "What's the weather in Paris in fahrenheit?"}]
})

print(result["messages"][-1].content)

Field parameters in detail

Field parameter reference

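from pydantic import BaseModel, Field

# Reference listing only: a real field combines just the parameters that fit its type
class ExampleInput(BaseModel):
    field_name: type = Field(
        default=...,              # default value (... means required)
        description="...",        # field description (important!)
        examples=["..."],         # example values
        min_length=1,             # minimum length (strings, lists)
        max_length=100,           # maximum length (strings, lists)
        ge=0,                     # greater than or equal (numbers)
        le=100,                   # less than or equal (numbers)
        gt=0,                     # greater than (numbers)
        lt=100,                   # less than (numbers)
        pattern="^[A-Z]",         # regular expression (strings)
        alias="fieldName",        # alias (the name expected in input and used in the JSON Schema)
    )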
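Below is a short sketch using Pydantic v2 and a hypothetical `AccountInput` model. It puts the string and number constraints on real fields of matching types and shows the error `type` each violated constraint produces.

```python
from pydantic import BaseModel, Field, ValidationError


class AccountInput(BaseModel):
    """Hypothetical model that combines the string and number constraints above."""

    code: str = Field(
        description="Account code",
        min_length=1,
        max_length=8,
        pattern="^[A-Z]",  # must start with an uppercase letter
    )
    score: int = Field(default=50, ge=0, le=100, description="Score from 0 to 100")
    ratio: float = Field(default=0.5, gt=0, lt=1, description="Strictly between 0 and 1")


def error_types(**kwargs) -> dict:
    """Validate kwargs and map each failing field to its Pydantic error type."""
    try:
        AccountInput(**kwargs)
    except ValidationError as exc:
        return {err["loc"][0]: err["type"] for err in exc.errors()}
    return {}


print(error_types(code="AB12"))
# {}
print(error_types(code="ab", score=101, ratio=1.0))
# {'code': 'string_pattern_mismatch', 'score': 'less_than_equal', 'ratio': 'less_than'}
print(error_types(code="", score=-1, ratio=0))
# {'code': 'string_too_short', 'score': 'greater_than_equal', 'ratio': 'greater_than'}
```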

Common Field examples

1. Required fields
class UserInput(BaseModel):
    user_id: str = Field(description="User unique identifier")

# Or mark the field as required explicitly with ...
class UserInputExplicit(BaseModel):
    user_id: str = Field(..., description="User unique identifier")
2. Optional fields (with defaults)
class SearchInput(BaseModel):
    query: str = Field(description="Search query")
    limit: int = Field(default=10, description="Maximum results")
    offset: int = Field(default=0, description="Skip first N results")
3. String constraints
class UserProfileInput(BaseModel):
    username: str = Field(
        description="Username (3-20 characters)",
        min_length=3,
        max_length=20,
        pattern="^[a-zA-Z0-9_]+$"  # letters, digits, and underscores only
    )
    email: str = Field(
        description="User email address",
        pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$"
    )
4. Numeric constraints
class ProductFilterInput(BaseModel):
    min_price: float = Field(
        default=0.0,
        ge=0.0,  # greater than or equal to 0
        description="Minimum price in USD"
    )
    max_price: float = Field(
        default=1000.0,
        le=10000.0,  # less than or equal to 10000
        description="Maximum price in USD"
    )
    quantity: int = Field(
        default=1,
        gt=0,  # greater than 0
        lt=100,  # less than 100
        description="Quantity to order"
    )
5. Example values
class QueryInput(BaseModel):
    search_term: str = Field(
        description="Search keywords",
        examples=["machine learning", "python tutorial", "AI news"]
    )
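Field constraints check one value at a time, so nothing in `ProductFilterInput` above stops `min_price` from exceeding `max_price`. For rules that span several fields, Pydantic v2 provides `model_validator`. The sketch below adds such a check. The validator is an addition for illustration and is not part of the original example.

```python
from pydantic import BaseModel, Field, ValidationError, model_validator


class ProductFilterInput(BaseModel):
    min_price: float = Field(default=0.0, ge=0.0, description="Minimum price in USD")
    max_price: float = Field(default=1000.0, le=10000.0, description="Maximum price in USD")
    quantity: int = Field(default=1, gt=0, lt=100, description="Quantity to order")

    @model_validator(mode="after")
    def check_price_range(self) -> "ProductFilterInput":
        # Runs only after every per-field constraint has passed
        if self.min_price > self.max_price:
            raise ValueError("min_price must not exceed max_price")
        return self


print(ProductFilterInput(min_price=10, max_price=50).max_price)  # 50.0

try:
    ProductFilterInput(min_price=500, max_price=100)
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # Value error, min_price must not exceed max_price
```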

Advanced type definitions

1. Enumerations (Literal)

from typing import Literal

from pydantic import BaseModel, Field

class WeatherInput(BaseModel):
    location: str = Field(description="City name")
    units: Literal["celsius", "fahrenheit", "kelvin"] = Field(
        default="celsius",
        description="Temperature unit"
    )
    forecast_type: Literal["hourly", "daily", "weekly"] = Field(
        default="daily",
        description="Forecast time range"
    )

When to use: the parameter can only take one of a few fixed values.

2. Optional fields (Optional)

from typing import Optional

from pydantic import BaseModel, Field

class SearchInput(BaseModel):
    query: str = Field(description="Search query (required)")
    category: Optional[str] = Field(
        default=None,
        description="Filter by category (optional)"
    )
    min_rating: Optional[float] = Field(
        default=None,
        description="Minimum rating filter"
    )

When to use: the field may be None or may be omitted entirely (default=None is what makes omitting it allowed).
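One Pydantic v2 detail is worth knowing here. `Optional[X]` only means the value may be `None`; the caller may leave the field out only if it also has a default. The sketch below shows this, using a hypothetical `StrictNullableInput` model, and uses `model_fields_set` to tell "omitted" apart from "explicitly None".

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class SearchInput(BaseModel):
    query: str = Field(description="Search query (required)")
    category: Optional[str] = Field(default=None, description="Filter by category (optional)")


class StrictNullableInput(BaseModel):
    # Hypothetical: no default, so the key must be present, though its value may be None
    category: Optional[str] = Field(description="Category, or null for all")


print(SearchInput(query="laptops").category)        # None
print(StrictNullableInput(category=None).category)  # None

try:
    StrictNullableInput()
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # missing

# model_fields_set records which fields the caller actually supplied
print(SearchInput(query="x").model_fields_set)  # {'query'}
```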

3. List types (List)

from typing import List

from pydantic import BaseModel, Field

class BatchQueryInput(BaseModel):
    queries: List[str] = Field(
        description="List of search queries to process"
    )
    tags: List[str] = Field(
        default=[],
        description="Filter by tags"
    )
    user_ids: List[int] = Field(
        description="List of user IDs to query",
        min_length=1,  # at least one element
        max_length=10  # at most 10 elements
    )
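On a list, `min_length` and `max_length` count elements. The sketch below (Pydantic v2) shows the `user_ids` limits in action. It writes the empty-list default as `default_factory=list`, a common idiom that makes the "fresh list per instance" intent explicit.

```python
from typing import List

from pydantic import BaseModel, Field, ValidationError


class BatchQueryInput(BaseModel):
    queries: List[str] = Field(description="List of search queries to process")
    tags: List[str] = Field(default_factory=list, description="Filter by tags")
    user_ids: List[int] = Field(
        description="List of user IDs to query",
        min_length=1,   # at least one element
        max_length=10,  # at most 10 elements
    )


ok = BatchQueryInput(queries=["python"], user_ids=[1, 2])
print(ok.tags)  # []

for bad_ids in ([], list(range(11))):
    try:
        BatchQueryInput(queries=["python"], user_ids=bad_ids)
    except ValidationError as exc:
        print(len(bad_ids), exc.errors()[0]["type"])
# 0 too_short
# 11 too_long
```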

4. Nested objects

class Address(BaseModel):
    street: str = Field(description="Street address")
    city: str = Field(description="City name")
    country: str = Field(description="Country name")
    postal_code: str = Field(description="Postal code")

class CreateUserInput(BaseModel):
    name: str = Field(description="User's full name")
    email: str = Field(description="User's email address")
    address: Address = Field(description="User's address")
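When a model calls a tool whose schema nests `Address`, the arguments arrive as nested JSON objects, and Pydantic builds the inner model for you. The sketch below shows this with Pydantic v2. It also shows that the nested model is emitted under `$defs` in the generated JSON Schema.

```python
from pydantic import BaseModel, Field


class Address(BaseModel):
    street: str = Field(description="Street address")
    city: str = Field(description="City name")
    country: str = Field(description="Country name")
    postal_code: str = Field(description="Postal code")


class CreateUserInput(BaseModel):
    name: str = Field(description="User's full name")
    email: str = Field(description="User's email address")
    address: Address = Field(description="User's address")


# Tool-call arguments are plain JSON; the nested dict becomes an Address instance
args = {
    "name": "Ada Lovelace",
    "email": "ada@example.com",
    "address": {"street": "1 Main St", "city": "London", "country": "UK", "postal_code": "N1"},
}
user = CreateUserInput.model_validate(args)
print(type(user.address).__name__, user.address.city)  # Address London

schema = CreateUserInput.model_json_schema()
print(list(schema["$defs"]))  # ['Address']
```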

5. Union types (Union)

from typing import Union

from pydantic import BaseModel, Field

class SearchInput(BaseModel):
    identifier: Union[int, str] = Field(
        description="User ID (int) or username (str)"
    )
    date_filter: Union[str, int] = Field(
        description="Date as string (YYYY-MM-DD) or unix timestamp"
    )
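Pydantic v2's default "smart" union mode keeps an input's type when that type exactly matches one of the union members. A JSON number therefore stays an `int` and a JSON string stays a `str`, even when the string contains only digits. A short sketch:

```python
from typing import Union

from pydantic import BaseModel, Field


class SearchInput(BaseModel):
    identifier: Union[int, str] = Field(description="User ID (int) or username (str)")
    date_filter: Union[str, int] = Field(
        description="Date as string (YYYY-MM-DD) or unix timestamp"
    )


by_id = SearchInput(identifier=42, date_filter=1700000000)
by_name = SearchInput(identifier="alice", date_filter="2024-01-31")
digits = SearchInput(identifier="42", date_filter="2024-01-31")

print(repr(by_id.identifier), repr(by_id.date_filter))  # 42 1700000000
print(repr(by_name.identifier))                         # 'alice'
print(repr(digits.identifier))                          # '42' (still a str)
```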

6. Dictionary types (Dict)

from typing import Dict

from pydantic import BaseModel, Field

class MetadataInput(BaseModel):
    filters: Dict[str, str] = Field(
        default={},
        description="Key-value pairs for filtering"
    )
    settings: Dict[str, int] = Field(
        default={"timeout": 30, "retries": 3},
        description="Configuration settings"
    )
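Dictionary values are validated against the declared value type. In its default (lax) mode, Pydantic v2 coerces a numeric string such as `"45"` to an `int` and rejects text that is not a number. The error location includes the offending key. A sketch:

```python
from typing import Dict

from pydantic import BaseModel, Field, ValidationError


class MetadataInput(BaseModel):
    filters: Dict[str, str] = Field(default={}, description="Key-value pairs for filtering")
    settings: Dict[str, int] = Field(
        default={"timeout": 30, "retries": 3},
        description="Configuration settings",
    )


print(MetadataInput().settings)                            # {'timeout': 30, 'retries': 3}
print(MetadataInput(settings={"timeout": "45"}).settings)  # {'timeout': 45}

try:
    MetadataInput(settings={"timeout": "soon"})
except ValidationError as exc:
    err = exc.errors()[0]
    print(err["loc"], err["type"])  # ('settings', 'timeout') int_parsing
```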

7. Complex combinations

from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, Field

class AdvancedSearchInput(BaseModel):
    """Advanced search with multiple filters."""

    # Required field
    query: str = Field(
        description="Main search query",
        min_length=1,
        max_length=200
    )

    # Enum
    sort_by: Literal["relevance", "date", "popularity"] = Field(
        default="relevance",
        description="Sort results by"
    )

    # Optional field
    category: Optional[str] = Field(
        default=None,
        description="Filter by category"
    )

    # List
    tags: List[str] = Field(
        default=[],
        description="Filter by tags"
    )

    # Numeric constraints
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum results"
    )

    # Boolean
    include_archived: bool = Field(
        default=False,
        description="Include archived items"
    )

    # Dictionary
    metadata: Dict[str, str] = Field(
        default={},
        description="Additional metadata filters"
    )
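Tool-call arguments come back from the model as a JSON string. `model_validate_json` (Pydantic v2) parses and validates that string in one step and reports every violation at once, which makes it a convenient way to check what a model produced. The sketch below uses a trimmed copy of `AdvancedSearchInput` that keeps only three of its fields, for brevity.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class AdvancedSearchInput(BaseModel):
    """Advanced search with multiple filters (trimmed copy)."""

    query: str = Field(description="Main search query", min_length=1, max_length=200)
    sort_by: Literal["relevance", "date", "popularity"] = Field(default="relevance")
    limit: int = Field(default=10, ge=1, le=100, description="Maximum results")


good = AdvancedSearchInput.model_validate_json('{"query": "vector databases", "sort_by": "date"}')
print(good.model_dump())
# {'query': 'vector databases', 'sort_by': 'date', 'limit': 10}

raw = '{"query": "", "sort_by": "newest", "limit": 500}'
try:
    AdvancedSearchInput.model_validate_json(raw)
except ValidationError as exc:
    for err in exc.errors():
        print(err["loc"][0], err["type"])
# query string_too_short
# sort_by literal_error
# limit less_than_equal
```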

Complete examples

Example 1: Weather query tool

from langchain.tools import tool
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import Literal, Optional

# Define the input schema
class WeatherInput(BaseModel):
    """Input schema for weather queries."""

    location: str = Field(
        description="City name or coordinates (e.g., 'Paris' or '48.8566,2.3522')",
        examples=["Paris", "New York", "Tokyo"]
    )

    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit preference"
    )

    include_forecast: bool = Field(
        default=False,
        description="Include 5-day forecast"
    )

    days: Optional[int] = Field(
        default=None,
        ge=1,
        le=7,
        description="Number of forecast days (1-7)"
    )

# Create the tool
@tool(args_schema=WeatherInput)
def get_weather(
    location: str,
    units: str = "celsius",
    include_forecast: bool = False,
    days: Optional[int] = None
) -> str:
    """Get current weather and optional forecast for a location.

    This tool provides current weather information and can optionally
    include a multi-day forecast.
    """
    # Simulated weather data
    temp = 22 if units == "celsius" else 72
    result = f"Current weather in {location}: {temp}°{units[0].upper()}, Sunny"

    if include_forecast:
        # Fall back to the 5-day forecast promised by the field description
        forecast_days = days if days is not None else 5
        result += f"\n\n{forecast_days}-day forecast:"
        for i in range(1, forecast_days + 1):
            temp_day = temp + i
            result += f"\n  Day {i}: {temp_day}°{units[0].upper()}, Partly cloudy"

    return result

# Use the tool
if __name__ == "__main__":
    model = ChatOpenAI(model="gpt-4")
    agent = create_agent(model, tools=[get_weather])

    # Test 1: simple query
    result = agent.invoke({
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}]
    })
    print(result["messages"][-1].content)

    # Test 2: with a forecast
    result = agent.invoke({
        "messages": [{"role": "user", "content": "Give me the weather in Tokyo with a 3-day forecast in fahrenheit"}]
    })
    print(result["messages"][-1].content)

Example 2: Database query tool

from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import List, Literal, Optional

class DatabaseQueryInput(BaseModel):
    """Input for database query operations."""

    table_name: str = Field(
        description="Name of the table to query",
        examples=["users", "products", "orders"]
    )

    columns: List[str] = Field(
        default=["*"],
        description="Columns to select (default: all columns)",
        examples=[["id", "name", "email"], ["*"]]
    )

    filters: Optional[dict] = Field(
        default=None,
        description="Filter conditions as key-value pairs",
        examples=[{"status": "active"}, {"age": 25}]
    )

    order_by: Optional[str] = Field(
        default=None,
        description="Column to sort by",
        examples=["created_at", "name"]
    )

    order_direction: Literal["ASC", "DESC"] = Field(
        default="ASC",
        description="Sort direction"
    )

    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum number of results"
    )

@tool(args_schema=DatabaseQueryInput)
def query_database(
    table_name: str,
    columns: List[str] = ["*"],
    filters: Optional[dict] = None,
    order_by: Optional[str] = None,
    order_direction: str = "ASC",
    limit: int = 10
) -> str:
    """Query the database with specified filters and options.

    This tool allows you to retrieve data from database tables
    with flexible filtering and sorting options.
    """
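    # Build the SQL query (illustration only: formatting values straight into SQL
    # is open to SQL injection; real code should use parameterized queries)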
    cols = ", ".join(columns)
    query = f"SELECT {cols} FROM {table_name}"

    if filters:
        conditions = [f"{k} = '{v}'" for k, v in filters.items()]
        query += " WHERE " + " AND ".join(conditions)

    if order_by:
        query += f" ORDER BY {order_by} {order_direction}"

    query += f" LIMIT {limit}"

    # Simulated query result
    return f"Executed query: {query}\n\nResults: [Sample data would appear here]"
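The simulated function above builds SQL by string formatting. That is fine for printing a query, but unsafe against a real database because of SQL injection. The sketch below is a safer alternative using the standard-library `sqlite3` module. Values go through `?` placeholders. Table and column names cannot be bound as placeholders, so they are checked against an allow-list instead. The `ALLOWED` mapping, the `build_query` helper, and the `users` table are all hypothetical.

```python
import sqlite3
from typing import Dict, List, Optional

# Hypothetical allow-list: table name -> permitted column names
ALLOWED: Dict[str, set] = {"users": {"id", "name", "status", "created_at"}}


def build_query(table_name: str, columns: List[str], filters: Optional[dict],
                order_by: Optional[str], order_direction: str, limit: int):
    """Return (sql, params) with identifiers allow-listed and values parameterized."""
    allowed = ALLOWED.get(table_name)
    if allowed is None:
        raise ValueError(f"table not allowed: {table_name}")
    names = (set(columns) - {"*"}) | set(filters or {}) | ({order_by} if order_by else set())
    for name in names:
        if name not in allowed:
            raise ValueError(f"column not allowed: {name}")
    if order_direction not in ("ASC", "DESC"):
        raise ValueError("order_direction must be ASC or DESC")

    sql = f"SELECT {', '.join(columns)} FROM {table_name}"
    params: list = []
    if filters:
        sql += " WHERE " + " AND ".join(f"{k} = ?" for k in filters)
        params.extend(filters.values())
    if order_by:
        sql += f" ORDER BY {order_by} {order_direction}"
    sql += " LIMIT ?"
    params.append(limit)
    return sql, params


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT, status TEXT, created_at TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?, ?, ?)", [
    (1, "ada", "active", "2024-01-01"),
    (2, "bob", "inactive", "2024-02-01"),
    (3, "eve", "active", "2024-03-01"),
])

sql, params = build_query("users", ["id", "name"], {"status": "active"}, "created_at", "DESC", 10)
print(sql)     # SELECT id, name FROM users WHERE status = ? ORDER BY created_at DESC LIMIT ?
print(params)  # ['active', 10]
print(conn.execute(sql, params).fetchall())  # [(3, 'eve'), (1, 'ada')]
```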

Example 3: File operation tool

from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Literal, Optional

class FileOperationInput(BaseModel):
    """Input for file operations."""

    operation: Literal["read", "write", "delete", "list"] = Field(
        description="Type of file operation to perform"
    )

    file_path: str = Field(
        description="Path to the file",
        examples=["/path/to/file.txt", "documents/report.pdf"]
    )

    content: Optional[str] = Field(
        default=None,
        description="Content to write (only for 'write' operation)"
    )

    encoding: str = Field(
        default="utf-8",
        description="File encoding"
    )

    mode: Literal["text", "binary"] = Field(
        default="text",
        description="Read/write mode"
    )

@tool(args_schema=FileOperationInput)
def file_operation(
    operation: str,
    file_path: str,
    content: Optional[str] = None,
    encoding: str = "utf-8",
    mode: str = "text"
) -> str:
    """Perform file operations (read, write, delete, list).

    This tool allows safe file system operations with
    proper encoding and mode handling.
    """
    if operation == "read":
        return f"Reading file: {file_path}\n[File content would appear here]"
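    elif operation == "write":
        # content is Optional in the schema, so guard against None before slicing
        if content is None:
            return "Error: 'content' is required for the 'write' operation"
        return f"Writing to file: {file_path}\nContent: {content[:50]}..."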
    elif operation == "delete":
        return f"Deleting file: {file_path}"
    elif operation == "list":
        return f"Listing directory: {file_path}\n- file1.txt\n- file2.pdf"
    else:
        return f"Unknown operation: {operation}"
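As written, the schema accepts `operation="write"` with no `content`. A Pydantic v2 `model_validator` can reject that combination before the tool function ever runs. The sketch below applies this to a trimmed copy of `FileOperationInput`; the validator is an addition and is not part of the original schema.

```python
from typing import Literal, Optional

from pydantic import BaseModel, Field, ValidationError, model_validator


class FileOperationInput(BaseModel):
    """Input for file operations (trimmed copy with a cross-field check)."""

    operation: Literal["read", "write", "delete", "list"] = Field(
        description="Type of file operation to perform"
    )
    file_path: str = Field(description="Path to the file")
    content: Optional[str] = Field(
        default=None, description="Content to write (required for 'write')"
    )

    @model_validator(mode="after")
    def require_content_for_write(self) -> "FileOperationInput":
        # Cross-field rule: only 'write' needs content
        if self.operation == "write" and self.content is None:
            raise ValueError("'content' is required when operation is 'write'")
        return self


print(FileOperationInput(operation="read", file_path="notes.txt").content)  # None

try:
    FileOperationInput(operation="write", file_path="notes.txt")
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
# Value error, 'content' is required when operation is 'write'
```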

Example 4: API call tool

from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Dict, Optional, Literal, List

class APICallInput(BaseModel):
    """Input for making API requests."""

    url: str = Field(
        description="API endpoint URL",
        pattern="^https?://",  # must start with http:// or https://
        examples=["https://api.example.com/users"]
    )

    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = Field(
        default="GET",
        description="HTTP method"
    )

    headers: Optional[Dict[str, str]] = Field(
        default=None,
        description="HTTP headers",
        examples=[{"Authorization": "Bearer token123"}]
    )

    params: Optional[Dict[str, str]] = Field(
        default=None,
        description="Query parameters",
        examples=[{"page": "1", "limit": "10"}]
    )

    body: Optional[Dict] = Field(
        default=None,
        description="Request body (for POST, PUT, PATCH)"
    )

    timeout: int = Field(
        default=30,
        ge=1,
        le=300,
        description="Request timeout in seconds"
    )

@tool(args_schema=APICallInput)
def api_call(
    url: str,
    method: str = "GET",
    headers: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
    body: Optional[Dict] = None,
    timeout: int = 30
) -> str:
    """Make HTTP API requests with specified parameters.

    This tool allows you to interact with external APIs
    using various HTTP methods and parameters.
    """
    result = "API Call:\n"
    result += f"  Method: {method}\n"
    result += f"  URL: {url}\n"
    if params:
        result += f"  Params: {params}\n"
    if headers:
        result += f"  Headers: {headers}\n"
    if body:
        result += f"  Body: {body}\n"

    # Simulated API response
    result += "\nResponse: [API response would appear here]"
    return result
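Pydantic v2's `HttpUrl` type is an alternative to a `pattern`. It restricts the scheme to http/https and also checks that the rest of the string is a well-formed URL. The sketch below uses a hypothetical `EndpointInput` model. The validated value is a URL object, so you would convert it with `str()` before handing it to an HTTP client.

```python
from pydantic import BaseModel, Field, HttpUrl, ValidationError


class EndpointInput(BaseModel):
    url: HttpUrl = Field(description="API endpoint URL (http or https)")


ok = EndpointInput(url="https://api.example.com/users?page=1")
print(ok.url.scheme, ok.url.host, ok.url.path)  # https api.example.com /users

for bad in ("ftp://files.example.com/data", "not a url"):
    try:
        EndpointInput(url=bad)
    except ValidationError as exc:
        print(bad, "->", exc.errors()[0]["type"])
# ftp://files.example.com/data -> url_scheme
# not a url -> url_parsing
```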

Best practices

1. Always add descriptions

# ❌ Bad: no descriptions
class BadInput(BaseModel):
    location: str
    units: str

# ✅ Good: clear descriptions
class GoodInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: str = Field(description="Temperature unit (celsius or fahrenheit)")

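Why it matters: field descriptions are included in the tool's JSON Schema, so they help the LLM understand when and how to use each parameter.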
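You can check this directly with Pydantic v2. The descriptions and the class docstring are carried into the JSON Schema generated from the model, while the undocumented fields expose only a title and a type.

```python
from pydantic import BaseModel, Field


class BadInput(BaseModel):
    location: str
    units: str


class GoodInput(BaseModel):
    """Input for weather queries."""

    location: str = Field(description="City name or coordinates")
    units: str = Field(description="Temperature unit (celsius or fahrenheit)")


bad = BadInput.model_json_schema()
good = GoodInput.model_json_schema()

print(bad["properties"]["units"])   # {'title': 'Units', 'type': 'string'}
print(good["properties"]["units"])
print(good["description"])          # Input for weather queries.
```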

2. Use Literal to restrict the options

# ❌ Bad: the string can be any value
units: str = Field(description="Temperature unit")

# ✅ Good: restricted to specific values
units: Literal["celsius", "fahrenheit"] = Field(
    description="Temperature unit"
)

3. Provide default values

class SearchInput(BaseModel):
    query: str = Field(description="Search query")
    limit: int = Field(default=10, description="Results limit")  # ✅ has a default
    offset: int = Field(default=0, description="Skip N results")  # ✅ has a default

4. Add examples

class UserInput(BaseModel):
    email: str = Field(
        description="User email address",
        examples=["user@example.com", "john.doe@company.org"]
    )

5. Validate with constraints

class ProductInput(BaseModel):
    name: str = Field(
        description="Product name",
        min_length=3,
        max_length=100
    )
    price: float = Field(
        description="Product price",
        gt=0,  # must be greater than 0
        le=10000  # at most 10000
    )
    quantity: int = Field(
        description="Quantity",
        ge=1,  # at least 1
        le=100  # at most 100
    )

6. Give the class a docstring and every field a description

class WeatherInput(BaseModel):
    """Input schema for weather-related queries.

    This schema defines the parameters needed to query
    weather information for a specific location.
    """

    location: str = Field(
        description="City name or geographic coordinates"
    )

7. Use Optional deliberately

from typing import Optional

class SearchInput(BaseModel):
    query: str = Field(description="Required search query")

    # Optional: may be omitted (defaults to None)
    category: Optional[str] = Field(
        default=None,
        description="Optional category filter"
    )

    # Has a default: also optional to provide
    limit: int = Field(
        default=10,
        description="Maximum results"
    )

Comparison with the plain decorator

Approach 1: Decorator only (simple cases)

@tool
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information.

    Args:
        location: City name or coordinates
        units: Temperature unit (celsius or fahrenheit)
    """
    return f"Weather in {location}: 22°{units[0].upper()}"

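Pros:

  • Quick and simple
  • Good for tools with few parameters

Cons:

  • With plain str annotations as above, parameter values are unconstrained (no enums)
  • Complex validation cannot be expressed in the schema
  • Parameter descriptions live in the docstring, which is less structured (LangChain copies the Args: section into the schema only when @tool(parse_docstring=True) is used)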

Approach 2: Pydantic BaseModel (recommended)

class WeatherInput(BaseModel):
    """Input for weather queries."""
    location: str = Field(description="City name or coordinates")
    units: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit"
    )

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    return f"Weather in {location}: 22°{units[0].upper()}"

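Pros:

  • Type validation and constraints
  • Clear field descriptions
  • Support for complex types
  • A well-formed JSON Schema is generated automatically

Cons:

  • Requires defining an extra BaseModel class
  • Slightly more code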

Which approach to use when

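Scenario | Recommended approach
Few parameters (1-2), simple types | Decorator only
Parameters restricted to enumerated values | BaseModel
Numeric range validation needed | BaseModel
Complex types needed (lists, nested objects) | BaseModel
More than 3 parameters | BaseModel
Detailed per-field descriptions needed | BaseModel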

Debugging tips

Inspecting the generated schema

import json
from typing import Literal

from langchain.tools import tool
from pydantic import BaseModel, Field

class WeatherInput(BaseModel):
    location: str = Field(description="City name")
    units: Literal["celsius", "fahrenheit"] = Field(default="celsius")

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    """Get weather information."""
    return f"Weather in {location}"

# Print the tool's schema (model_json_schema() is the Pydantic v2 method; .schema() is deprecated)
print(json.dumps(get_weather.args_schema.model_json_schema(), indent=2))

Output (the exact key order may vary between Pydantic versions):

{
  "title": "WeatherInput",
  "type": "object",
  "properties": {
    "location": {
      "title": "Location",
      "description": "City name",
      "type": "string"
    },
    "units": {
      "title": "Units",
      "default": "celsius",
      "enum": ["celsius", "fahrenheit"],
      "type": "string"
    }
  },
  "required": ["location"]
}

Testing validation

from pydantic import ValidationError

# Test schema validation
try:
    # Valid input
    valid_input = WeatherInput(location="Paris", units="celsius")
    print(f"✅ Valid: {valid_input}")

    # Invalid input: units is not one of the enum values
    invalid_input = WeatherInput(location="Paris", units="kelvin")
except ValidationError as e:
    print(f"❌ Validation Error: {e}")

FAQ

Q1: Do the parameter names have to match?

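Yes. The field names in the BaseModel must match the function's parameter names: the validated fields are passed to the function as keyword arguments, so a mismatch fails when the tool is called.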

# ✅ Correct: names match
class WeatherInput(BaseModel):
    location: str
    units: str

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    pass

# ❌ Wrong: names do not match
class WeatherInput(BaseModel):
    city: str  # should be location
    temp_unit: str  # should be units

@tool(args_schema=WeatherInput)
def get_weather(location: str, units: str = "celsius") -> str:
    pass
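A mismatch like this is easy to miss because nothing fails until the tool is actually called. Below is a small sanity check you could run in a unit test. It compares the schema's fields with the function signature using only `inspect` and Pydantic v2. `schema_mismatches` is a hypothetical helper, not a LangChain API; for a `@tool`-decorated function you would pass the wrapped callable (exposed as the tool's `.func` attribute).

```python
import inspect

from pydantic import BaseModel


def schema_mismatches(func, schema: type[BaseModel]) -> dict:
    """List schema fields the function lacks and required params the schema lacks."""
    params = inspect.signature(func).parameters
    fields = set(schema.model_fields)
    required_params = {
        name for name, p in params.items()
        if p.default is inspect.Parameter.empty
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    }
    return {
        "unknown_fields": sorted(fields - set(params)),
        "unfilled_params": sorted(required_params - fields),
    }


class WeatherInput(BaseModel):
    city: str       # should be location
    temp_unit: str  # should be units


def get_weather(location: str, units: str = "celsius") -> str:
    return f"Weather in {location}"


print(schema_mismatches(get_weather, WeatherInput))
# {'unknown_fields': ['city', 'temp_unit'], 'unfilled_params': ['location']}
```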

Q2: How do I handle a variable number of arguments?

Use List or Dict:

from typing import Dict, List

from pydantic import BaseModel, Field

class BatchInput(BaseModel):
    items: List[str] = Field(description="List of items to process")
    metadata: Dict[str, str] = Field(default={}, description="Additional metadata")

Q3: Can BaseModels be nested?

Yes:

class Address(BaseModel):
    street: str
    city: str
    country: str

class UserInput(BaseModel):
    name: str
    address: Address  # nested BaseModel

Q4: How do I make a field required?

A field without a default value is required:

class MyInput(BaseModel):
    required_field: str = Field(description="This is required")
    optional_field: str = Field(default="default", description="This is optional")
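The JSON Schema's `required` list confirms which fields the model must always supply. A quick check with Pydantic v2 follows. The `explicit_required` field is a hypothetical addition; it shows that `Field(...)` and a `Field(description=...)` without a default both make a field required.

```python
from pydantic import BaseModel, Field


class MyInput(BaseModel):
    required_field: str = Field(description="This is required")
    explicit_required: str = Field(..., description="Also required, marked with ...")
    optional_field: str = Field(default="default", description="This is optional")


print(MyInput.model_json_schema()["required"])  # ['required_field', 'explicit_required']
print(MyInput.model_fields["optional_field"].is_required())  # False
```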

Summary

Key takeaways

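  1. Use the args_schema parameter to pass a Pydantic BaseModel to the @tool decorator
  2. Always add descriptions: use Field(description="...") on every field
  3. Use Literal to restrict a parameter to specific enumerated values
  4. Use Field constraints such as ge, le, min_length, and max_length to validate input
  5. Provide default values to make optional parameters easier to use
  6. Use type hints such as Optional, List, and Dict to express complex types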

Further reading

  • Pydantic documentation: https://docs.pydantic.dev/
  • LangChain Tools documentation: https://docs.langchain.com/oss/python/langchain/tools
  • Python typing module: https://docs.python.org/3/library/typing.html

Quick reference

from langchain.tools import tool
from pydantic import BaseModel, Field
from typing import Literal, Optional, List

class ToolInput(BaseModel):
    """Tool input description"""

    # Required field
    required_field: str = Field(description="Field description")

    # Optional field (has a default)
    optional_field: str = Field(default="default", description="Field description")

    # Enum
    enum_field: Literal["option1", "option2"] = Field(description="Field description")

    # Numeric constraints
    number_field: int = Field(default=10, ge=1, le=100, description="Field description")

    # List
    list_field: List[str] = Field(default=[], description="Field description")

    # Nullable type
    nullable_field: Optional[str] = Field(default=None, description="Field description")

@tool(args_schema=ToolInput)
def my_tool(
    required_field: str,
    optional_field: str = "default",
    enum_field: str = "option1",
    number_field: int = 10,
    list_field: List[str] = [],
    nullable_field: Optional[str] = None
) -> str:
    """Tool description"""
    return "result"

Exercises

Try building the following tools, defining their inputs with Pydantic BaseModel:

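  1. Email sending tool - recipient, subject, body, priority (enum)
  2. Product search tool - keywords, price range, category, sort order
  3. User management tool - operation type (CRUD), user information, permission list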

Happy learning! 🚀
