一、为什么需要自定义状态?

默认情况下,LangGraph 的状态只包含 messages 字段,用于存储对话历史。但在实际应用中,我们经常需要存储其他业务数据:

  • 用户信息:姓名、生日、联系方式等
  • 业务数据:订单号、商品列表、价格等
  • 流程状态:审批状态、处理进度等
  • 配置信息:用户偏好、系统设置等

自定义状态让你可以在 State 中添加任意字段,实现更复杂的业务逻辑。


二、核心概念

2.1 State 的结构

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages

class State(TypedDict):
    # 标准消息字段(使用 add_messages reducer)
    messages: Annotated[list, add_messages]
    # 自定义字段
    name: str
    birthday: str
    order_id: str
    # ... 任意其他字段

2.2 add_messages 是什么?

add_messages 是一个 reducer 函数,它定义了如何合并新旧消息:

# 原有消息: [msg1, msg2]
# 新增消息: [msg3]
# 合并后: [msg1, msg2, msg3]

messages: Annotated[list, add_messages]

简单说:每次返回新消息时,它会自动追加到现有消息列表后面,而不是覆盖。

2.3 自定义字段的特点

自定义字段(如 name, birthday没有 reducer,所以它们的行为是:

# 原有值: name = "张三"
# 新值: name = "李四"
# 结果: name = "李四"  ← 直接覆盖

三、如何更新自定义状态

有三种方式更新自定义状态:

方式1:在节点中返回

def my_node(state: State) -> dict:
    return {
        "name": "张三",  # 更新自定义字段
        "messages": [{"role": "assistant", "content": "好的"}]
    }

方式2:在工具中使用 Command

from langgraph.types import Command
from langchain_core.tools import InjectedToolCallId, tool

@tool
def my_tool(
    name: str,
    tool_call_id: Annotated[str, InjectedToolCallId]  # 自动注入
) -> Command:
    """工具内部更新状态"""
    return Command(update={
        "name": name,
        "messages": [ToolMessage("已更新", tool_call_id=tool_call_id)]
    })

方式3:手动更新

graph.update_state(config, {"name": "王五"})

四、完整示例:信息验证流程

让我们通过一个完整的例子来理解自定义状态的使用。

4.1 场景描述

用户提交信息 → LLM 调用工具 → 人工验证 → 保存到状态

4.2 状态定义

class State(TypedDict):
    messages: Annotated[list, add_messages]  # 对话历史
    name: str                                 # 姓名
    birthday: str                             # 生日

4.3 工具实现

from langgraph.types import Command, interrupt
from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage

@tool
def human_assistance(
    name: str, 
    birthday: str, 
    tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """请求人工验证信息"""

    # 1. 暂停执行,等待人工输入
    human_response = interrupt({
        "question": "请确认信息是否正确?",
        "name": name,
        "birthday": birthday,
    })

    # 2. 处理人工响应
    if human_response.get("correct", "").lower().startswith("y"):
        verified_name = name
        verified_birthday = birthday
        response = "信息已确认"
    else:
        verified_name = human_response.get("name", name)
        verified_birthday = human_response.get("birthday", birthday)
        response = f"已修正信息"

    # 3. 返回 Command 更新状态
    return Command(update={
        "name": verified_name,
        "birthday": verified_birthday,
        "messages": [ToolMessage(response, tool_call_id=tool_call_id)]
    })

4.4 关键点解释

InjectedToolCallId 是什么?
tool_call_id: Annotated[str, InjectedToolCallId]

这是一个自动注入的参数

  • LLM 调用工具时,系统自动传入工具调用 ID
  • LLM 看不到这个参数(不会在提示词中显示)
  • 用于创建 ToolMessage 时关联原始调用
Command 是什么?
return Command(update={
    "name": "张三",
    "birthday": "1990-01-01"
})

Command 是从工具内部更新状态的方式:

  • update: 更新状态的字段
  • resume: 恢复执行时传入的数据

五、执行流程

5.1 流程图

┌─────────────────────────────────────────────────────────────────┐
│                    自定义状态更新流程                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. 用户输入: "张三的生日是1990年5月15日"                        │
│     │                                                          │
│     ▼                                                          │
│  ┌─────────────┐                                               │
│  │   chatbot   │  LLM 决定调用 human_assistance 工具           │
│  └──────┬──────┘                                               │
│         │                                                      │
│         ▼                                                      │
│  ┌─────────────────────┐                                       │
│  │ human_assistance    │                                       │
│  │   工具被调用        │                                       │
│  │   name="张三"       │                                       │
│  │   birthday="1990.." │                                       │
│  └──────────┬──────────┘                                       │
│             │                                                  │
│             ▼                                                  │
│  ┌─────────────────────┐                                       │
│  │   interrupt()       │  暂停执行,等待人工验证               │
│  └──────────┬──────────┘                                       │
│             │                                                  │
│             │  人工输入: {"correct": "n", "name": "李四", ...}  │
│             │                                                  │
│             ▼                                                  │
│  ┌─────────────────────┐                                       │
│  │ Command(update=...) │  更新状态                             │
│  │   name="李四"       │                                       │
│  │   birthday="1992.." │                                       │
│  └──────────┬──────────┘                                       │
│             │                                                  │
│             ▼                                                  │
│  ┌─────────────┐                                               │
│  │   chatbot   │  LLM 继续处理                                 │
│  └─────────────┘                                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

5.2 代码执行

# 1. 启动流程
config = {"configurable": {"thread_id": "demo"}}
graph.invoke({"messages": [{"role": "user", "content": "..."}]}, config)

# 2. 检查状态(暂停中)
snapshot = graph.get_state(config)
print(snapshot.next)  # ('tools',) 表示在工具节点暂停

# 3. 提供人工输入
human_command = Command(resume={
    "correct": "n",
    "name": "李四",
    "birthday": "1992年8月20日"
})
graph.invoke(human_command, config)

# 4. 查看更新后的状态
snapshot = graph.get_state(config)
print(snapshot.values["name"])      # "李四"
print(snapshot.values["birthday"])  # "1992年8月20日"

六、手动更新状态

除了在工具中更新,你还可以随时手动更新状态:

# 更新单个字段
graph.update_state(config, {"name": "王五"})

# 更新多个字段
graph.update_state(config, {
    "name": "赵六",
    "birthday": "1985-03-10"
})

# 查看更新后的状态
snapshot = graph.get_state(config)
print(snapshot.values["name"])  # "赵六"

6.1 应用场景

  • 修正错误:发现数据有误,手动修正
  • 补充信息:后续添加额外字段
  • 重置状态:清空某些字段重新开始

七、下游节点访问状态

自定义字段的一个重要用途是让下游节点访问:

def downstream_node(state: State) -> dict:
    # 访问之前保存的状态
    name = state.get("name", "未知")
    birthday = state.get("birthday", "未知")

    # 基于这些信息进行处理
    message = f"根据记录,{name} 的生日是 {birthday}"

    return {"messages": [{"role": "assistant", "content": message}]}

7.1 数据流示意图

┌─────────────────────────────────────────────────────────────────┐
│                    状态在节点间传递                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  节点A: 保存信息到状态                                          │
│  ┌─────────────────────────────────────────┐                   │
│  │  state["name"] = "张三"                 │                   │
│  │  state["birthday"] = "1990-01-01"       │                   │
│  └─────────────────────────────────────────┘                   │
│                      │                                          │
│                      ▼                                          │
│  ┌─────────────────────────────────────────┐                   │
│  │           MemorySaver                   │                   │
│  │    保存: name="张三", birthday="..."    │                   │
│  └─────────────────────────────────────────┘                   │
│                      │                                          │
│                      ▼                                          │
│  节点B: 读取状态中的信息                                        │
│  ┌─────────────────────────────────────────┐                   │
│  │  name = state["name"]  # "张三"         │                   │
│  │  birthday = state["birthday"]           │                   │
│  │  # 使用这些信息进行处理...               │                   │
│  └─────────────────────────────────────────┘                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

八、完整代码示例

8.1 示例文件

"""
LangGraph 自定义状态示例 - 演示如何在状态中添加自定义字段

本示例演示:
- 向 State 添加自定义字段(name, birthday)
- 在工具内部使用 Command 更新状态
- 使用 InjectedToolCallId 注入工具调用 ID
- 使用 update_state() 手动更新状态
- 结合人工干预验证信息

图结构: START -> chatbot -> tools_condition -> tools -> chatbot -> END
"""

from typing import Annotated
from typing_extensions import TypedDict

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langchain_openai import ChatOpenAI

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command, interrupt


# ==================== 状态定义 ====================
class State(TypedDict):
    """
    自定义状态结构
    
    关键点:
    - messages: 使用 add_messages reducer 自动合并消息
    - name: 自定义字段,存储姓名
    - birthday: 自定义字段,存储生日
    """
    messages: Annotated[list, add_messages]
    name: str
    birthday: str


# ==================== 工具定义 ====================
@tool
def human_assistance(
    name: str, 
    birthday: str, 
    tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """
    请求人工协助验证信息
    
    关键点:
    - InjectedToolCallId: 自动注入工具调用 ID,不暴露给 LLM
    - interrupt(): 暂停执行等待人工输入
    - Command(update=...): 从工具内部更新状态
    """
    print(f"\n  🔧 human_assistance 工具被调用")
    print(f"     name: {name}")
    print(f"     birthday: {birthday}")
    
    human_response = interrupt(
        {
            "question": "请确认以下信息是否正确?",
            "name": name,
            "birthday": birthday,
        },
    )
    
    print(f"\n  📩 收到人工响应: {human_response}")
    
    if human_response.get("correct", "").lower().startswith("y"):
        verified_name = name
        verified_birthday = birthday
        response = "✅ 信息已确认正确"
    else:
        verified_name = human_response.get("name", name)
        verified_birthday = human_response.get("birthday", birthday)
        response = f"📝 已修正: name={verified_name}, birthday={verified_birthday}"
    
    state_update = {
        "name": verified_name,
        "birthday": verified_birthday,
        "messages": [ToolMessage(response, tool_call_id=tool_call_id)],
    }
    
    print(f"\n  💾 更新状态:")
    print(f"     name: {verified_name}")
    print(f"     birthday: {verified_birthday}")
    
    return Command(update=state_update)


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


# ==================== 初始化大模型 ====================
llm = ChatOpenAI(
    model="qwen-plus",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key="sk-ba8d8ecea7f048e5a3b9f3d2b9d3cc14",
    temperature=0.7,
    timeout=60,
    request_timeout=60,
)

tools = [human_assistance, multiply]
llm_with_tools = llm.bind_tools(tools)


# ==================== 节点函数 ====================
def chatbot(state: State) -> dict:
    """聊天机器人节点"""
    message = llm_with_tools.invoke(state["messages"])
    return {"messages": [message]}


# ==================== 构建图 ====================
graph_builder = StateGraph(State)

graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")

memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)


# ==================== 主程序 ====================
def run_demo():
    """运行演示"""
    print("=" * 70)
    print("📌 LangGraph 自定义状态示例")
    print("=" * 70)
    print("\n图结构: START -> chatbot -> tools_condition -> tools -> chatbot")
    print("\n功能演示:")
    print("  1. 向 State 添加自定义字段 (name, birthday)")
    print("  2. 在工具内部使用 Command 更新状态")
    print("  3. 使用 update_state() 手动更新状态")
    print("=" * 70)
    
    config = {"configurable": {"thread_id": "custom_state_demo"}}
    
    # ==================== 场景1: 使用 human_assistance 工具 ====================
    print("\n" + "=" * 70)
    print("🔄 场景1: 使用 human_assistance 工具验证信息")
    print("=" * 70)
    
    user_input = (
        "请帮我记录一下:张三的生日是1990年5月15日。"
        "请使用 human_assistance 工具来验证这个信息。"
    )
    
    print(f"\n用户输入: {user_input}")
    
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    
    for event in events:
        if "messages" in event:
            last_msg = event["messages"][-1]
            if hasattr(last_msg, 'content') and last_msg.content:
                print(f"\n  🤖 LLM: {last_msg.content[:200]}...")
    
    # 检查是否需要人工干预
    snapshot = graph.get_state(config)
    print(f"\n  📊 当前状态:")
    print(f"     next: {snapshot.next}")
    
    if snapshot.next:
        print("\n  ⏸️ 执行已暂停,等待人工验证...")
        
        # 模拟人工确认(修正信息)
        human_command = Command(
            resume={
                "correct": "n",
                "name": "李四",
                "birthday": "1992年8月20日",
            },
        )
        
        print("\n  👤 人工修正:")
        print("     原信息: 张三, 1990年5月15日")
        print("     修正为: 李四, 1992年8月20日")
        
        events = graph.stream(human_command, config, stream_mode="values")
        for event in events:
            if "messages" in event:
                last_msg = event["messages"][-1]
                if hasattr(last_msg, 'content') and last_msg.content:
                    print(f"\n  🤖 LLM: {last_msg.content[:300]}...")
    
    # 查看更新后的状态
    snapshot = graph.get_state(config)
    print("\n" + "-" * 70)
    print("📊 更新后的状态:")
    print("-" * 70)
    print(f"   name: {snapshot.values.get('name', '未设置')}")
    print(f"   birthday: {snapshot.values.get('birthday', '未设置')}")
    
    # ==================== 场景2: 手动更新状态 ====================
    print("\n" + "=" * 70)
    print("🔄 场景2: 使用 update_state() 手动更新状态")
    print("=" * 70)
    
    print("\n  手动更新 name 字段...")
    graph.update_state(config, {"name": "王五"})
    
    snapshot = graph.get_state(config)
    print(f"\n  📊 更新后的状态:")
    print(f"     name: {snapshot.values.get('name', '未设置')}")
    print(f"     birthday: {snapshot.values.get('birthday', '未设置')}")
    
    print("\n  手动更新 birthday 字段...")
    graph.update_state(config, {"birthday": "1988年3月10日"})
    
    snapshot = graph.get_state(config)
    print(f"\n  📊 更新后的状态:")
    print(f"     name: {snapshot.values.get('name', '未设置')}")
    print(f"     birthday: {snapshot.values.get('birthday', '未设置')}")
    
    # ==================== 场景3: 下游节点访问状态 ====================
    print("\n" + "=" * 70)
    print("🔄 场景3: 在新对话中访问之前保存的状态")
    print("=" * 70)
    
    user_input2 = "我之前记录的人是谁?生日是什么时候?"
    print(f"\n用户输入: {user_input2}")
    
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input2}]},
        config,
        stream_mode="values",
    )
    
    for event in events:
        if "messages" in event:
            last_msg = event["messages"][-1]
            if hasattr(last_msg, 'content') and last_msg.content:
                print(f"\n  🤖 LLM: {last_msg.content}")
    
    # ==================== 总结 ====================
    print("\n" + "=" * 70)
    print("📊 总结:自定义状态关键点")
    print("=" * 70)
    print("""
✅ 自定义状态的优势:
   - 在 State 中添加任意字段存储业务数据
   - 工具内部可以使用 Command 更新状态
   - 下游节点可以访问这些字段
   - 支持手动更新状态 (update_state)

📝 核心 API:
   1. State 定义: class State(TypedDict):
         messages: Annotated[list, add_messages]
         name: str
         birthday: str

   2. 工具更新状态:
         @tool
         def my_tool(..., tool_call_id: Annotated[str, InjectedToolCallId]):
             return Command(update={"name": "xxx"})

   3. 手动更新状态:
         graph.update_state(config, {"name": "xxx"})

   4. 查看状态:
         snapshot = graph.get_state(config)
         snapshot.values["name"]

🎯 适用场景:
   - 存储用户信息、配置等业务数据
   - 多步骤工作流中传递数据
   - 人工验证后保存确认信息
""")


if __name__ == "__main__":
    run_demo()

8.2 运行结果

======================================================================
📌 LangGraph 自定义状态示例
======================================================================

图结构: START -> chatbot -> tools_condition -> tools -> chatbot

功能演示:
  1. 向 State 添加自定义字段 (name, birthday)
  2. 在工具内部使用 Command 更新状态
  3. 使用 update_state() 手动更新状态
======================================================================

======================================================================
🔄 场景1: 使用 human_assistance 工具验证信息
======================================================================

用户输入: 请帮我记录一下:张三的生日是1990年5月15日。请使用 human_assistance 工具来验证这个信息。

  🤖 LLM: 请帮我记录一下:张三的生日是1990年5月15日。请使用 human_assistance 工具来验证这个信息。...

  🔧 human_assistance 工具被调用
     name: 张三
     birthday: 1990年5月15日

  📊 当前状态:
     next: ('tools',)

  ⏸️ 执行已暂停,等待人工验证...

  👤 人工修正:
     原信息: 张三, 1990年5月15日
     修正为: 李四, 1992年8月20日

  🔧 human_assistance 工具被调用
     name: 张三
     birthday: 1990年5月15日

  📩 收到人工响应: {'correct': 'n', 'name': '李四', 'birthday': '1992年8月20日'}

  💾 更新状态:
     name: 李四
     birthday: 1992年8月20日

  🤖 LLM: 📝 已修正: name=李四, birthday=1992年8月20日...

  🤖 LLM: 已记录修正信息:姓名为“李四”,生日为“1992年8月20日”。如需进一步操作,请随时告知!...

----------------------------------------------------------------------
📊 更新后的状态:
----------------------------------------------------------------------
   name: 李四
   birthday: 1992年8月20日

======================================================================
🔄 场景2: 使用 update_state() 手动更新状态
======================================================================

  手动更新 name 字段...

  📊 更新后的状态:
     name: 王五
     birthday: 1992年8月20日

  手动更新 birthday 字段...

  📊 更新后的状态:
     name: 王五
     birthday: 1988年3月10日

======================================================================
🔄 场景3: 在新对话中访问之前保存的状态
======================================================================

用户输入: 我之前记录的人是谁?生日是什么时候?

  🤖 LLM: 我之前记录的人是谁?生日是什么时候?

  🤖 LLM: 您之前记录的是“李四”,生日是“1992年8月20日”。

======================================================================
📊 总结:自定义状态关键点
======================================================================

✅ 自定义状态的优势:
   - 在 State 中添加任意字段存储业务数据
   - 工具内部可以使用 Command 更新状态
   - 下游节点可以访问这些字段
   - 支持手动更新状态 (update_state)

📝 核心 API:
   1. State 定义: class State(TypedDict):
         messages: Annotated[list, add_messages]
         name: str
         birthday: str

   2. 工具更新状态:
         @tool
         def my_tool(..., tool_call_id: Annotated[str, InjectedToolCallId]):
             return Command(update={"name": "xxx"})

   3. 手动更新状态:
         graph.update_state(config, {"name": "xxx"})

   4. 查看状态:
         snapshot = graph.get_state(config)
         snapshot.values["name"]

🎯 适用场景:
   - 存储用户信息、配置等业务数据
   - 多步骤工作流中传递数据
   - 人工验证后保存确认信息

8.3 核心代码片段

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langgraph.types import Command, interrupt
from langchain_core.tools import InjectedToolCallId, tool

# 1. 定义状态
class State(TypedDict):
    messages: Annotated[list, add_messages]
    name: str
    birthday: str

# 2. 定义工具(更新状态)
@tool
def human_assistance(
    name: str, 
    birthday: str, 
    tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    human_response = interrupt({"name": name, "birthday": birthday})

    return Command(update={
        "name": human_response.get("name", name),
        "birthday": human_response.get("birthday", birthday),
        "messages": [ToolMessage("已更新", tool_call_id=tool_call_id)]
    })

# 3. 构建图
graph = StateGraph(State)
graph.add_node("chatbot", chatbot_node)
graph.add_node("tools", ToolNode(tools=[human_assistance]))
# ... 添加边 ...

# 4. 编译(需要 MemorySaver)
memory = MemorySaver()
graph = graph.compile(checkpointer=memory)

# 5. 执行
config = {"configurable": {"thread_id": "xxx"}}
graph.invoke({"messages": [...]}, config)

# 6. 恢复执行
graph.invoke(Command(resume={...}), config)

# 7. 手动更新
graph.update_state(config, {"name": "新名字"})

九、API 速查表

API 用途 示例
State 定义状态结构 class State(TypedDict): name: str
add_messages 消息合并 reducer messages: Annotated[list, add_messages]
InjectedToolCallId 注入工具调用 ID tool_call_id: Annotated[str, InjectedToolCallId]
Command(update=...) 工具内更新状态 return Command(update={"name": "xxx"})
graph.update_state() 手动更新状态 graph.update_state(config, {"name": "xxx"})
graph.get_state() 查看状态 snapshot.values["name"]

十、最佳实践

10.1 状态字段设计

# ✅ 好的设计:字段职责清晰
class State(TypedDict):
    messages: Annotated[list, add_messages]
    user_name: str           # 用户姓名
    user_birthday: str       # 用户生日
    order_status: str        # 订单状态

# ❌ 不好的设计:字段过于笼统
class State(TypedDict):
    messages: Annotated[list, add_messages]
    data: dict  # 什么都往里塞,难以维护

10.2 状态更新时机

场景 推荐方式
节点处理结果 节点返回 {"field": value}
工具处理结果 Command(update={...})
人工修正 update_state()
外部系统同步 update_state()

10.3 注意事项

  1. 必须有 MemorySaver:自定义状态需要持久化
  2. 使用 thread_id:区分不同会话的状态
  3. 避免过多字段:状态太复杂会影响性能
  4. 字段类型明确:使用 TypedDict 提供类型提示

十一、延伸阅读

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐