智能体工作流的透明化:LangGraph 动态执行路径的实时渲染

在构建复杂的智能体(AI Agent)系统时,我们常常面临一个核心挑战:理解其内部运作机制。当智能体开始执行任务、做出决策、使用工具,并根据环境反馈进行迭代时,其内部路径和状态变化往往是“黑箱”的。这种不透明性不仅增加了调试难度,也使得终端用户难以信任或理解智能体的行为。LangGraph作为一个强大的框架,用于构建有状态、循环和多参与者的智能体工作流,其固有的复杂性使得对执行路径的实时可视化变得尤为重要。

本讲座将深入探讨“Collaboration State Visualization”这一主题,即如何实时渲染LangGraph的动态执行路径,并将其呈现给终端用户。我们将从LangGraph的执行模型出发,逐步拆解实现实时可视化的核心组件,并通过具体的代码示例展示如何在后端捕获事件,以及如何在前端构建一个交互式的可视化界面。

1. LangGraph 执行模型概览

LangGraph的核心是一个有向图,其中节点(nodes)代表计算单元(如调用LLM、执行工具、条件判断),边(edges)定义了节点之间的流转路径。LangGraph的独特之处在于其支持状态(state)管理和循环(cycles),这使得构建复杂的、具备记忆和迭代能力的智能体成为可能。

LangGraph 的关键概念:

  • State (状态): LangGraph的核心。它是一个在整个图执行过程中共享和更新的字典或对象。每个节点接收当前状态作为输入,并返回一个更新后的状态。
  • Nodes (节点): 执行具体操作的单元。可以是LLM调用、工具调用、自定义Python函数等。
  • Edges (边): 定义了状态在节点之间的传递路径。可以是固定的(add_edge),也可以是条件性的(add_conditional_edges),根据状态决定下一个要执行的节点。
  • Checkpoints (检查点): LangGraph可以存储和恢复执行状态,这对于长时间运行的智能体和调试非常有用。
  • Cycles (循环): 允许智能体在特定条件下重复访问某些节点,形成迭代或决策循环。

执行流程:

当LangGraph被调用时,它从一个初始状态开始,并从图的起点节点开始执行。每个节点完成其操作后,会根据定义的边和条件将控制权传递给下一个节点,同时更新全局状态。这个过程持续进行,直到达到一个终止节点(如END)或满足某个退出条件。

可视化LangGraph的动态执行路径,本质上就是要追踪这个状态流转和节点激活的过程,并将其以图形化的方式呈现出来。这意味着我们需要:

  1. 捕获事件: 在每个节点被激活、完成、状态更新时,捕获这些事件。
  2. 传输事件: 将捕获到的事件实时发送到客户端(前端)。
  3. 渲染路径: 在前端根据接收到的事件动态更新图的显示,高亮当前活跃节点,展示状态变化。

2. 核心组件:构建实时可视化系统

要实现LangGraph的实时可视化,我们需要以下核心组件协同工作:

  1. LangGraph 仪表化 (Backend/Server-side): 负责捕获LangGraph执行过程中的关键事件和状态变化。
  2. 事件发射机制 (Backend/Server-side): 负责将捕获到的事件实时推送给前端客户端。
  3. 前端可视化界面 (Client-side): 负责接收事件、动态渲染图结构、并展示执行路径和状态。
2.1 LangGraph 仪表化:事件捕获

LangChain和LangGraph提供了强大的回调机制,允许我们在智能体执行的各个阶段插入自定义逻辑。BaseCallbackHandler是我们的主要工具。

LangChain CallbackHandler 的作用:

BaseCallbackHandler是一个抽象基类,定义了一系列可以在LLM调用、链(chain)执行、工具使用等事件发生时被触发的方法。通过继承并实现这些方法,我们可以捕获到LangGraph执行的细粒度事件。

对于LangGraph的执行可视化,我们主要关注以下几类事件:

  • 节点进入/退出: 当执行流进入或离开一个节点时。
  • 工具调用/返回: 当智能体决定使用工具或工具返回结果时。
  • LLM 调用/返回: 当智能体与语言模型交互时。
  • 状态更新: 最关键的是LangGraph的全局状态如何变化。

自定义 LangGraphVisualizationCallback

我们将创建一个自定义的回调处理器,用于收集和结构化这些事件。

from langgraph.graph import StateGraph
from langgraph.checkpoint import MemorySaver
from langchain_core.runnables import RunnableConfig
from langchain_core.callbacks import BaseCallbackHandler
from typing import Any, Dict, List, Optional
import asyncio
import json
import time

# 假设我们有一个事件队列或直接的WebSocket连接发送器
# 这里用一个简单的函数模拟事件发送
async def send_event_to_client(event_data: Dict):
    """
    模拟将事件数据发送到客户端的函数。
    在实际应用中,这会通过WebSocket或SSE发送。
    """
    # print(f"Sending event: {json.dumps(event_data, indent=2)}")
    # 实际场景中,这里会将event_data推送到一个WebSocket连接或SSE流
    pass

class LangGraphVisualizationCallback(BaseCallbackHandler):
    """
    自定义LangGraph回调处理器,用于捕获执行事件并将其结构化发送。
    """
    def __init__(self, session_id: str, event_sender_func: callable):
        self.session_id = session_id
        self.event_sender_func = event_sender_func
        self._current_chain_id: Optional[str] = None
        self._node_call_stack: List[Dict[str, Any]] = []

    async def _emit_event(self, event_type: str, payload: Dict):
        """通用事件发射器"""
        event = {
            "session_id": self.session_id,
            "timestamp": time.time(),
            "event_type": event_type,
            "payload": payload
        }
        await self.event_sender_func(event)

    async def on_chain_start(
        self, serialized: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Called when a chain starts running."""
        # LangGraph的节点(node)在LangChain中被视为RunnableSequence或RunnableParallel,
        # 它们会触发on_chain_start/end。我们需要区分是LangGraph本身的运行还是内部节点的运行。

        # 识别当前运行的是哪个LangGraph节点
        # LangGraph内部的节点通常会有"name"或"id"字段
        node_name = serialized.get("name") or serialized.get("id")

        # 检查是否是LangGraph本身的顶级运行,或者是一个内部节点
        if node_name == "Graph": # Top-level LangGraph run
            self._current_chain_id = kwargs.get("run_id")
            await self._emit_event(
                "graph_start",
                {
                    "run_id": str(kwargs.get("run_id")),
                    "inputs": kwargs.get("inputs"),
                    "graph_name": serialized.get("name", "LangGraph")
                }
            )
        elif self._current_chain_id: # Inside a LangGraph run, this is likely a node
            # Push node to stack to track nested calls
            self._node_call_stack.append({
                "run_id": str(kwargs.get("run_id")),
                "node_name": node_name,
                "type": serialized.get("lc_serializable", False) and serialized["lc_serializable"][0],
                "start_time": time.time()
            })
            await self._emit_event(
                "node_start",
                {
                    "run_id": str(kwargs.get("run_id")),
                    "parent_run_id": str(kwargs.get("parent_run_id")),
                    "node_name": node_name,
                    "inputs": kwargs.get("inputs"),
                    "node_type": serialized.get("lc_serializable", False) and serialized["lc_serializable"][0],
                }
            )

    async def on_chain_end(
        self, outputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Called when a chain ends running."""
        if self._current_chain_id == str(kwargs.get("run_id")): # Top-level LangGraph run ends
            await self._emit_event(
                "graph_end",
                {
                    "run_id": str(kwargs.get("run_id")),
                    "outputs": outputs
                }
            )
            self._current_chain_id = None
        elif self._node_call_stack and self._node_call_stack[-1]["run_id"] == str(kwargs.get("run_id")):
            node_info = self._node_call_stack.pop()
            await self._emit_event(
                "node_end",
                {
                    "run_id": str(kwargs.get("run_id")),
                    "node_name": node_info["node_name"],
                    "outputs": outputs,
                    "duration": time.time() - node_info["start_time"]
                }
            )

            # Additional logic to capture state changes after node end (requires accessing graph state)
            # This is tricky with standard callbacks as `state` is not directly passed here.
            # We would need to either:
            # 1. Access the `graph.get_state()` if within the graph run context (difficult with async callbacks)
            # 2. Modify the LangGraph source slightly to pass state to callbacks (more invasive)
            # 3. Use LangGraph's checkpointer to read the latest state (adds latency)
            # For simplicity, we'll assume state diffs are sent separately or inferred on frontend.
            # A more robust solution might involve wrapping LangGraph's `_execute_node` method.
            # For this example, we'll simulate state updates separately.

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
        """Called when a tool starts running."""
        tool_name = serialized.get("name", "unknown_tool")
        await self._emit_event(
            "tool_start",
            {
                "run_id": str(kwargs.get("run_id")),
                "parent_run_id": str(kwargs.get("parent_run_id")),
                "tool_name": tool_name,
                "input": input_str
            }
        )

    async def on_tool_end(
        self, output: Any, **kwargs: Any
    ) -> None:
        """Called when a tool ends running."""
        tool_name = kwargs.get("name", "unknown_tool") # Tool name might be in kwargs
        await self._emit_event(
            "tool_end",
            {
                "run_id": str(kwargs.get("run_id")),
                "parent_run_id": str(kwargs.get("parent_run_id")),
                "tool_name": tool_name,
                "output": output
            }
        )

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when an LLM starts running."""
        await self._emit_event(
            "llm_start",
            {
                "run_id": str(kwargs.get("run_id")),
                "parent_run_id": str(kwargs.get("parent_run_id")),
                "model_name": serialized.get("name", "unknown_llm"),
                "prompts": prompts
            }
        )

    async def on_llm_end(
        self, response: Any, **kwargs: Any
    ) -> None:
        """Called when an LLM ends running."""
        await self._emit_event(
            "llm_end",
            {
                "run_id": str(kwargs.get("run_id")),
                "parent_run_id": str(kwargs.get("parent_run_id")),
                "response": response.dict() if hasattr(response, 'dict') else str(response)
            }
        )

    # 还需要一个机制来捕获LangGraph的状态变化。
    # 这不是BaseCallbackHandler直接提供的,通常需要我们自行在LangGraph的节点函数中
    # 或者通过定制LangGraph的执行逻辑来实现。
    # 一个可行的方法是,在每个节点函数执行完毕后,手动发送当前Graph状态的差异。
    async def emit_state_update(self, run_id: str, current_state: Dict, previous_state: Dict):
        """
        手动发送状态更新事件。
        这个方法需要在LangGraph的每个节点完成其逻辑后被调用。
        """
        state_diff = self._calculate_state_diff(previous_state, current_state)
        await self._emit_event(
            "state_update",
            {
                "run_id": run_id,
                "current_state": current_state,
                "state_diff": state_diff
            }
        )

    def _calculate_state_diff(self, old_state: Dict, new_state: Dict) -> Dict:
        """计算状态差异"""
        diff = {}
        for key, new_value in new_state.items():
            old_value = old_state.get(key)
            if new_value != old_value:
                diff[key] = {"old": old_value, "new": new_value}
        for key in old_state:
            if key not in new_state:
                diff[key] = {"old": old_value, "new": None} # Key removed
        return diff

关于状态更新的补充说明:

捕获LangGraph的state变化是可视化的关键。BaseCallbackHandler本身没有直接的on_state_update方法。要获取每次节点执行后的状态,一种常见且相对简洁的方法是:

  1. 在LangGraph的StateGraph中定义每个节点时,将一个包装器函数作为节点逻辑。
  2. 这个包装器函数在调用实际的节点逻辑之前和之后,记录并比较状态,然后触发emit_state_update

例如,如果一个节点 my_node 的原始函数是 my_node_logic(state)

async def wrapped_node_logic(state: Dict, config: RunnableConfig):
    # 获取当前LangGraph的run_id,这需要从config中提取
    run_id = str(config.get("run_id", "unknown"))

    # 模拟获取前一个状态 (需要一个机制来存储或从checkpointer获取)
    # 实际应用中,如果使用MemorySaver,可以通过run_id访问
    # 或者如果LangGraph的节点函数能访问到checkpointer
    # For simplicity, let's assume `previous_state` is passed or retrieved
    previous_state = state.copy() # Simplistic: assumes input state is previous state

    # 执行原始节点逻辑
    result = await my_node_logic(state) # my_node_logic returns updated state or dict for state update

    # 获取节点执行后的新状态
    # 如果my_node_logic返回的是更新字典,LangGraph会合并它
    # 所以 `result` 已经是更新后的部分状态,需要与 `previous_state` 合并得到 `current_state`
    # LangGraph的节点函数签名是 `Callable[[State], PartialState]`
    # 所以 `result` 是 `PartialState`

    # 这里我们简化,假设 `result` 是完整的新状态。
    # 真实情况中,需要从checkpointer或LangGraph内部获取整个状态。
    # 鉴于LangGraph的执行机制,`state` 参数就是当前的完整状态。
    # 节点函数返回的是对这个状态的 *更新*。

    # 捕获状态更新的最佳时机是在LangGraph的内部执行逻辑中。
    # 如果无法修改LangGraph核心,那么在节点函数前后记录是最直接的方式。

    # 模拟在节点函数内部捕获状态变化
    # 注意:这里的 `callback_handler` 实例需要被传递到节点函数中。
    # 这是一个挑战,因为LangGraph的节点通常不直接接收回调实例。
    # 通常的解决方案是将callback_handler作为全局变量或通过闭包/偏函数传递。

    # 假设callback_handler是可访问的
    # await callback_handler.emit_state_update(run_id, current_full_state, previous_full_state)

    return result

上述的LangGraphVisualizationCallback已经足够捕获节点和工具的启动/结束事件。对于状态的实时差异,我们可能需要更深入地集成或在业务逻辑层面手动触发。在许多情况下,仅仅显示节点激活和完成,并在侧边栏展示当前完整状态就已足够。

2.2 事件发射机制:实时数据传输

捕获到事件后,我们需要一个高效且低延迟的方式将其从后端发送到前端。

1. Server-Sent Events (SSE):

  • 特点: 单向通信(服务器到客户端),基于HTTP,简单易用,浏览器原生支持。
  • 优点: 易于实现,无需复杂协议,自动重连,适合推送通知、实时日志等场景。
  • 缺点: 只能从服务器推送,客户端无法主动发送数据给服务器(除非结合传统HTTP请求)。

后端实现 (FastAPI + SSE):

# app_backend.py (FastAPI)
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import uuid
import json
from collections import defaultdict

# ... (Previous LangGraphVisualizationCallback and LangGraph definition) ...

app = FastAPI()

# 允许跨域请求,前端通常运行在不同端口
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # 生产环境应限制为特定域名
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 存储所有活跃的SSE连接
sse_connections = defaultdict(list) # session_id -> list of Queue

async def event_generator(session_id: str):
    """为每个连接生成SSE事件"""
    queue = asyncio.Queue()
    sse_connections[session_id].append(queue)
    try:
        while True:
            event_data = await queue.get()
            yield f"data: {json.dumps(event_data)}nn"
    except asyncio.CancelledError:
        print(f"SSE connection for session {session_id} cancelled.")
    finally:
        sse_connections[session_id].remove(queue)
        print(f"SSE connection for session {session_id} closed.")

# 用于将事件发送到所有订阅了特定session_id的SSE连接
async def broadcast_event(session_id: str, event_data: Dict):
    for queue in sse_connections[session_id]:
        await queue.put(event_data)

@app.get("/events/{session_id}")
async def sse_endpoint(session_id: str, request: Request):
    """SSE端点,客户端通过此连接接收事件"""
    return StreamingResponse(event_generator(session_id), media_type="text/event-stream")

# 模拟LangGraph执行的函数
async def run_langgraph_simulation(session_id: str, input_message: str):
    print(f"Starting LangGraph run for session {session_id} with input: {input_message}")

    # 初始化回调处理器,传入广播函数
    callback_handler = LangGraphVisualizationCallback(session_id, lambda event: broadcast_event(session_id, event))

    # 假设LangGraph定义在这里
    # For demonstration, let's define a simple graph:
    # State: {"messages": List[str], "next_node": str}
    class GraphState(Dict):
        messages: List[str] = []
        next_node: str = "start"

    builder = StateGraph(GraphState)

    def node_start(state: GraphState):
        messages = state["messages"] + ["Node: Start"]
        return {"messages": messages, "next_node": "process"}

    def node_process(state: GraphState):
        messages = state["messages"] + ["Node: Process"]
        return {"messages": messages, "next_node": "finish"}

    def node_finish(state: GraphState):
        messages = state["messages"] + ["Node: Finish"]
        return {"messages": messages, "next_node": "END"}

    builder.add_node("start_node", node_start)
    builder.add_node("process_node", node_process)
    builder.add_node("finish_node", node_finish)

    builder.set_entry_point("start_node")
    builder.add_edge("start_node", "process_node")
    builder.add_edge("process_node", "finish_node")
    builder.add_edge("finish_node", "END")

    graph = builder.compile()

    # 运行LangGraph,并将回调处理器传入
    initial_state = {"messages": [input_message], "next_node": "start"}

    # 模拟LangGraph执行中的状态更新
    # 这是一个手动模拟,真实场景中需要更精细的集成
    current_state = initial_state
    previous_state = {}
    await callback_handler.emit_state_update(session_id, current_state, previous_state)

    for i, step_state in enumerate(graph.stream(initial_state, config={"callbacks": [callback_handler]})):
        # print(f"Step {i} state: {step_state}")
        previous_state = current_state.copy()
        current_state.update(step_state) # LangGraph streams partial state updates
        await callback_handler.emit_state_update(session_id, current_state, previous_state)
        await asyncio.sleep(0.5) # Simulate work

    print(f"LangGraph run for session {session_id} finished.")

@app.post("/start_graph_run/{session_id}")
async def start_graph_run_endpoint(session_id: str, input_data: Dict):
    """触发LangGraph运行的API"""
    input_message = input_data.get("message", "Default input")
    # 在后台启动LangGraph运行,不阻塞HTTP响应
    asyncio.create_task(run_langgraph_simulation(session_id, input_message))
    return {"message": "LangGraph run started", "session_id": session_id}

# 示例:获取初始图结构(静态定义)
@app.get("/graph_definition")
async def get_graph_definition():
    # 假设我们有一个预定义的图结构,用于前端初始化
    # 真实的LangGraph对象可以被序列化为字典或使用graphviz等工具生成
    simple_graph_definition = {
        "nodes": [
            {"id": "start_node", "label": "Start"},
            {"id": "process_node", "label": "Process"},
            {"id": "finish_node", "label": "Finish"},
            {"id": "END", "label": "END"}
        ],
        "edges": [
            {"id": "e1", "source": "start_node", "target": "process_node"},
            {"id": "e2", "source": "process_node", "target": "finish_node"},
            {"id": "e3", "source": "finish_node", "target": "END"}
        ]
    }
    return simple_graph_definition

2. WebSockets:

  • 特点: 双向全双工通信,基于TCP,提供更低的延迟和更高的效率。
  • 优点: 客户端和服务器可以随时互相发送数据,适合需要双向交互的场景(如暂停/恢复、调试指令)。
  • 缺点: 比SSE更复杂,需要处理连接管理、心跳、断线重连等。

后端实现 (FastAPI + WebSockets):

# app_backend.py (FastAPI - WebSocket version)
# ... (Imports and CORS middleware as above) ...

from fastapi import WebSocket, WebSocketDisconnect

# 存储所有活跃的WebSocket连接
active_websocket_connections: Dict[str, List[WebSocket]] = defaultdict(list)

# 用于将事件发送到所有订阅了特定session_id的WebSocket连接
async def broadcast_websocket_event(session_id: str, event_data: Dict):
    if session_id in active_websocket_connections:
        for connection in active_websocket_connections[session_id]:
            try:
                await connection.send_json(event_data)
            except RuntimeError as e:
                print(f"Error sending to WebSocket {connection}: {e}")
            except WebSocketDisconnect:
                print(f"WebSocket {connection} disconnected during broadcast.")
                active_websocket_connections[session_id].remove(connection)

@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    await websocket.accept()
    active_websocket_connections[session_id].append(websocket)
    try:
        while True:
            # WebSocket也可以接收客户端消息,例如暂停/恢复指令
            data = await websocket.receive_text()
            print(f"Received message from client {session_id}: {data}")
            # 处理客户端指令...
    except WebSocketDisconnect:
        active_websocket_connections[session_id].remove(websocket)
        print(f"WebSocket {session_id} disconnected.")
    except Exception as e:
        print(f"WebSocket error for session {session_id}: {e}")
    finally:
        if not active_websocket_connections[session_id]:
            del active_websocket_connections[session_id]

# run_langgraph_simulation 函数需要修改,将event_sender_func改为broadcast_websocket_event
# ... (Rest of the FastAPI app including /start_graph_run and /graph_definition) ...

对于LangGraph的执行,由于它本质上是服务器驱动的,SSE通常是一个更简单的起点。如果需要客户端的实时交互(如用户点击暂停、查看更详细的历史状态等),那么WebSocket是更合适的选择。在本次讲座中,我们将主要以WebSocket为例进行前端的讲解,因为它提供了最大的灵活性。

2.3 前端可视化界面:React + React Flow

前端负责接收后端推送的事件,并动态更新图的可视化。我们将使用React作为UI框架,并结合React Flow作为图渲染库。React Flow是一个强大的库,专为构建交互式节点编辑器和图表而设计,非常适合我们的需求。

技术栈选择:

  • UI 框架: React (也可选Vue, Svelte)
  • 图渲染库: React Flow (也可选D3.js, vis.js Network)
  • 数据流管理: React Hooks (useState, useEffect, useRef)
  • 布局算法: dagre (或 elkjs) 用于自动布局

前端核心功能:

  1. 初始化图结构: 从后端获取初始的静态图定义(节点和边)。
  2. 建立 WebSocket 连接: 连接到后端 /ws/{session_id} 端点。
  3. 事件监听与处理: 监听 WebSocket 消息,根据事件类型更新图状态。
  4. 动态渲染: 根据事件实时更新节点的颜色、边的高亮、显示输入/输出/状态变化。
  5. 自动布局: 确保图在更新时保持清晰的布局。

前端代码示例 (React + React Flow):

首先,确保你已经安装了必要的依赖:
npm install react react-dom @react-flow/core @react-flow/minimap @react-flow/controls dagre

// src/App.js
import React, { useState, useEffect, useRef, useCallback } from 'react';
import ReactFlow, {
  Controls,
  MiniMap,
  Background,
  applyNodeChanges,
  applyEdgeChanges,
  addEdge,
  useReactFlow,
} from '@react-flow/core';
import { MarkerType } from '@react-flow/core'; // For custom arrow markers
import dagre from 'dagre'; // For automatic layout
import './App.css'; // Basic CSS for styling
import { v4 as uuidv4 } from 'uuid'; // For generating unique session IDs

// --- Dagre 布局函数 ---
const dagreGraph = new dagre.graphlib.Graph();
dagreGraph.setDefaultEdgeLabel(() => ({}));

const nodeWidth = 170;
const nodeHeight = 36;

const getLayoutedElements = (nodes, edges, direction = 'TB') => {
  const isHorizontal = direction === 'LR';
  dagreGraph.setGraph({ rankdir: direction });

  nodes.forEach((node) => {
    dagreGraph.setNode(node.id, { width: nodeWidth, height: nodeHeight });
  });

  edges.forEach((edge) => {
    dagreGraph.setEdge(edge.source, edge.target);
  });

  dagre.layout(dagreGraph);

  nodes.forEach((node) => {
    const nodeWithPosition = dagreGraph.node(node.id);
    node.targetPosition = isHorizontal ? 'left' : 'top';
    node.sourcePosition = isHorizontal ? 'right' : 'bottom';

    // We are shifting the dagre node position (anchor=center) to the top-left
    // so it matches React Flow's node anchor point.
    node.position = {
      x: nodeWithPosition.x - nodeWidth / 2,
      y: nodeWithPosition.y - nodeHeight / 2,
    };

    return node;
  });

  return { nodes, edges };
};
// --- End Dagre Layout ---

function FlowVisualizer() {
  const [nodes, setNodes] = useState([]);
  const [edges, setEdges] = useState([]);
  const [session_id, setSessionId] = useState('');
  const [currentStatus, setCurrentStatus] = useState({}); // { node_id: "active" | "completed" | "error" }
  const [nodeDetails, setNodeDetails] = useState({}); // { node_id: { input, output, ... } }
  const [latestGraphState, setLatestGraphState] = useState({});
  const [activeNodeId, setActiveNodeId] = useState(null);

  const reactFlowWrapper = useRef(null);
  const { fitView } = useReactFlow();

  const websocket = useRef(null);

  // 1. 初始化会话ID并建立WebSocket连接
  useEffect(() => {
    const newSessionId = uuidv4();
    setSessionId(newSessionId);

    // 获取初始图定义
    fetch('http://localhost:8000/graph_definition')
      .then(res => res.json())
      .then(initialGraph => {
        const initialNodes = initialGraph.nodes.map(node => ({
          id: node.id,
          data: { label: node.label },
          position: { x: 0, y: 0 }, // Initial dummy position
          style: { background: '#eee', color: '#333', border: '1px solid #ccc' },
        }));
        const initialEdges = initialGraph.edges.map(edge => ({
          id: edge.id,
          source: edge.source,
          target: edge.target,
          type: 'default',
          markerEnd: { type: MarkerType.ArrowClosed, color: '#666' },
          style: { strokeWidth: 1.5, stroke: '#666' },
        }));

        const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements(initialNodes, initialEdges);
        setNodes(layoutedNodes);
        setEdges(layoutedEdges);
        fitView(); // Fit graph to view after initial layout
      });

    // 建立WebSocket连接
    websocket.current = new WebSocket(`ws://localhost:8000/ws/${newSessionId}`);

    websocket.current.onopen = () => {
      console.log('WebSocket connected');
    };

    websocket.current.onmessage = (event) => {
      const eventData = JSON.parse(event.data);
      console.log('Received event:', eventData);
      handleGraphEvent(eventData);
    };

    websocket.current.onclose = () => {
      console.log('WebSocket disconnected');
    };

    websocket.current.onerror = (error) => {
      console.error('WebSocket error:', error);
    };

    return () => {
      websocket.current?.close();
    };
  }, [fitView]);

  // 2. 事件处理函数
  const handleGraphEvent = useCallback((event) => {
    const { event_type, payload } = event;

    setNodes((nds) => {
      let newNodes = nds.map((node) => {
        let newNode = { ...node };

        // Reset all nodes to default before applying current active/completed states
        newNode.style = { background: '#eee', color: '#333', border: '1px solid #ccc' };
        newNode.data.label = nds.find(n => n.id === node.id)?.data.label || node.id; // Keep original label

        // Apply dynamic styles based on currentStatus
        if (currentStatus[node.id] === 'active') {
          newNode.style = { ...newNode.style, background: '#a7e9af', border: '1px solid #4CAF50' }; // Active green
        } else if (currentStatus[node.id] === 'completed') {
          newNode.style = { ...newNode.style, background: '#cceeff', border: '1px solid #3498db' }; // Completed blue
        } else if (currentStatus[node.id] === 'error') {
          newNode.style = { ...newNode.style, background: '#ffcccc', border: '1px solid #e74c3c' }; // Error red
        }
        return newNode;
      });

      if (event_type === 'node_start') {
        const nodeId = payload.node_name;
        setActiveNodeId(nodeId);
        setCurrentStatus((prev) => ({ ...prev, [nodeId]: 'active' }));
        setNodeDetails((prev) => ({
          ...prev,
          [nodeId]: { ...(prev[nodeId] || {}), inputs: payload.inputs },
        }));
        // Update node style for active node
        newNodes = newNodes.map((node) =>
          node.id === nodeId
            ? { ...node, style: { background: '#a7e9af', border: '2px solid #4CAF50' } } // Highlight active
            : node
        );
      } else if (event_type === 'node_end') {
        const nodeId = payload.node_name;
        setActiveNodeId(null);
        setCurrentStatus((prev) => ({ ...prev, [nodeId]: 'completed' }));
        setNodeDetails((prev) => ({
          ...prev,
          [nodeId]: { ...(prev[nodeId] || {}), outputs: payload.outputs },
        }));
        // Update node style for completed node
        newNodes = newNodes.map((node) =>
          node.id === nodeId
            ? { ...node, style: { background: '#cceeff', border: '2px solid #3498db' } } // Mark completed
            : node
        );
      } else if (event_type === 'tool_start') {
        // You might want to show tool calls within the parent node's details
        const parentNodeId = payload.parent_run_id; // This needs mapping to node_name
        // For simplicity, let's assume parent_run_id is the current active node's run_id
        // A more robust mapping is needed for nested runs
        console.log(`Tool ${payload.tool_name} started with input: ${payload.input}`);
      } else if (event_type === 'state_update') {
        setLatestGraphState(payload.current_state);
        // Optionally update node labels or data based on state changes if relevant
        // e.g., if a node's label should reflect a state variable.
      } else if (event_type === 'graph_start') {
        setCurrentStatus({}); // Reset status for a new graph run
        setNodeDetails({});
        setLatestGraphState({});
      }

      return newNodes;
    });

    setEdges((eds) => {
      // Logic to highlight active edges based on node transitions
      // This is more complex as edge activation is inferred from node_start/end sequence.
      // For simplicity, we'll keep edges static in this example.
      // A more advanced approach would track the last active node and highlight the edge leading to the new active node.
      return eds.map(edge => {
        // Reset edge style
        edge.style = { strokeWidth: 1.5, stroke: '#666' };
        edge.markerEnd = { type: MarkerType.ArrowClosed, color: '#666' };

        // Example: Highlight edge if source is completed and target is active
        if (currentStatus[edge.source] === 'completed' && activeNodeId === edge.target) {
            edge.style = { strokeWidth: 2.5, stroke: '#FFD700' }; // Gold for active path
            edge.markerEnd = { type: MarkerType.ArrowClosed, color: '#FFD700' };
        }
        return edge;
      });
    });

  }, [currentStatus, activeNodeId]); // Dependencies for useCallback

  const onNodesChange = useCallback(
    (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
    [setNodes]
  );
  const onEdgesChange = useCallback(
    (changes) => setEdges((eds) => applyEdgeChanges(changes, eds)),
    [setEdges]
  );
  const onConnect = useCallback(
    (connection) => setEdges((eds) => addEdge(connection, eds)),
    [setEdges]
  );

  // 3. 触发后端LangGraph执行
  const startGraphRun = async () => {
    if (!session_id) return;
    try {
      const response = await fetch(`http://localhost:8000/start_graph_run/${session_id}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ message: "Hello, LangGraph!" }),
      });
      const data = await response.json();
      console.log(data);
    } catch (error) {
      console.error('Failed to start graph run:', error);
    }
  };

  return (
    <div style={{ display: 'flex', height: '100vh', width: '100vw' }}>
      <div style={{ flexGrow: 1, height: '100%' }} ref={reactFlowWrapper}>
        <ReactFlow
          nodes={nodes}
          edges={edges}
          onNodesChange={onNodesChange}
          onEdgesChange={onEdgesChange}
          onConnect={onConnect}
          fitView
        >
          <Controls />
          <MiniMap />
          <Background variant="dots" gap={12} size={1} />
        </ReactFlow>
      </div>
      <div style={{ width: '300px', padding: '10px', borderLeft: '1px solid #eee', overflowY: 'auto' }}>
        <h2>LangGraph Visualizer</h2>
        <p>Session ID: {session_id}</p>
        <button onClick={startGraphRun} style={{ padding: '10px 20px', fontSize: '16px', cursor: 'pointer' }}>
          Start Graph Run
        </button>
        <hr />
        <h3>Current Graph State</h3>
        <pre style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-all', fontSize: '12px', background: '#f5f5f5', padding: '5px' }}>
          {JSON.stringify(latestGraphState, null, 2)}
        </pre>
        <hr />
        <h3>Node Details</h3>
        {Object.entries(nodeDetails).map(([nodeId, details]) => (
          <div key={nodeId} style={{ marginBottom: '10px', padding: '8px', border: '1px solid #ddd', borderRadius: '4px' }}>
            <h4>Node: {nodeId} ({currentStatus[nodeId] || 'idle'})</h4>
            {details.inputs && (
              <>
                <strong>Inputs:</strong>
                <pre style={{ fontSize: '10px', background: '#e9e9e9', padding: '3px' }}>
                  {JSON.stringify(details.inputs, null, 2)}
                </pre>
              </>
            )}
            {details.outputs && (
              <>
                <strong>Outputs:</strong>
                <pre style={{ fontSize: '10px', background: '#e9e9e9', padding: '3px' }}>
                  {JSON.stringify(details.outputs, null, 2)}
                </pre>
              </>
            )}
          </div>
        ))}
      </div>
    </div>
  );
}

export default function App() {
  return (
    <ReactFlowProvider>
      <FlowVisualizer />
    </ReactFlowProvider>
  );
}

运行指导:

  1. 后端:
    • 将上述FastAPI代码保存为 app_backend.py
    • 安装依赖:pip install fastapi uvicorn websockets langchain_core langgraph
    • 运行:uvicorn app_backend:app --reload
  2. 前端:
    • 使用 create-react-app 创建一个新的React项目:npx create-react-app langgraph-viz-frontend
    • 进入项目目录:cd langgraph-viz-frontend
    • 安装依赖:npm install @react-flow/core @react-flow/minimap @react-flow/controls dagre uuid
    • 将上述React代码替换 src/App.js 的内容。
    • 添加基础CSS到 src/App.css (如果需要):
      body { margin: 0; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif; }
      .react-flow__node {
          padding: 8px 12px;
          border-radius: 5px;
          font-size: 14px;
          text-align: center;
          min-width: 120px;
      }
    • 运行:npm start

现在,当你打开前端页面,点击“Start Graph Run”按钮时,后端会启动LangGraph执行,并通过WebSocket实时将事件推送给前端。前端的React Flow图将动态高亮当前活跃的节点,并在侧边栏显示最新的全局状态和每个节点的输入/输出。

3. 高级考量与增强

上述实现提供了一个基本框架,但实时可视化系统可以有许多高级特性:

  1. 状态差异可视化:
    • 不仅仅显示完整的最新状态,而是突出显示每次节点执行后状态的具体变化(增、删、改)。
    • 可以使用一个专门的UI组件来展示JSON diff。
  2. 历史记录与回放:
    • 后端可以持久化所有事件到数据库(如PostgreSQL, MongoDB)。
    • 前端提供时间轴或步骤控制,允许用户“倒带”或“快进”执行过程,以便详细分析。
    • 这需要前端维护一个事件列表,并在回放时根据事件重建图的状态。
  3. 用户交互与调试:
    • 暂停/恢复: 通过WebSocket从前端发送指令到后端,控制LangGraph的执行。
    • 检查节点: 点击节点显示更详细的运行时信息(如LLM的完整提示和响应、工具的详细日志)。
    • 注入修改: 在某些节点暂停时,允许用户修改状态变量,然后继续执行,实现“实时调试”。
  4. 图的复杂性管理:
    • 对于大型LangGraph,图可能会非常庞大。可以实现:
      • 子图折叠: 将一组相关节点折叠成一个高层节点。
      • 过滤: 根据节点类型、状态或执行历史过滤显示。
      • 搜索: 快速定位特定节点。
  5. 性能优化:
    • 事件节流/去抖: 防止在短时间内收到大量事件导致前端渲染卡顿。
    • 虚拟化: 对于大量节点和边,使用React Flow的虚拟化功能只渲染视口内的元素。
    • 后端事件队列: 使用Kafka、RabbitMQ等消息队列解耦LangGraph执行和事件发送,提高系统吞吐量和可靠性。
  6. 错误处理与日志:
    • 清晰地在UI中显示节点执行失败的原因和错误信息。
    • 集成后端日志到前端,方便调试。
  7. 部署与扩展:
    • 将后端和前端分别打包为Docker容器,使用Kubernetes进行部署和扩展。
    • 利用云服务(如AWS Lambda, Google Cloud Run)无服务器地运行LangGraph执行。

表格:事件类型及其可视化效果

事件类型 触发时机 可视化效果 侧边栏/详情区内容
graph_start LangGraph 工作流开始 重置所有节点状态,清除旧的详细信息 显示整个工作流的初始输入
node_start 执行流进入某个节点 节点高亮为“活跃”颜色(如绿色),可能显示加载动画 节点的输入参数(来自 LangGraph 状态)
tool_start 智能体调用工具 (如果工具是独立节点)高亮工具节点;(如果工具是LLM内部)在LLM节点或侧边栏显示工具调用详情 工具名称、调用参数
llm_start 智能体调用 LLM 关联的 LLM 节点高亮为“活跃”颜色 LLM 的提示(prompt)
llm_end LLM 调用返回结果 关联的 LLM 节点变为“完成”颜色 LLM 的响应(response),可能包含工具调用指令
tool_end 工具执行完成并返回结果 (如果工具是独立节点)工具节点变为“完成”颜色 工具的输出结果
node_end 执行流离开某个节点 节点高亮为“完成”颜色(如蓝色),移除活跃状态 节点的输出结果,可能显示状态变更摘要
state_update LangGraph 全局状态发生变化 (无直接节点变化) 显示 current_statestate_diff
graph_end LangGraph 工作流结束 所有节点保持其最终状态 显示整个工作流的最终输出
node_error 某个节点执行失败 节点高亮为“错误”颜色(如红色) 错误信息、堆栈追踪

4. 总结与展望

通过将LangGraph的强大状态管理与实时Web技术相结合,我们成功构建了一个能够透明化智能体工作流的系统。这种“Collaboration State Visualization”不仅极大地提升了开发者的调试效率,也让终端用户能够直观地理解智能体的决策过程和内部状态演变。从后端事件捕获、到实时事件传输,再到前端的动态图渲染,每一个环节都至关重要。

随着AI智能体变得日益复杂,对这些系统的可解释性和可控性的需求也愈发迫切。实时可视化提供了一个强大的窗口,让我们能够深入洞察智能体的“思维”,从而更好地构建、优化和信任这些下一代AI应用。未来的工作将集中于提高可视化系统的交互性、可扩展性,并集成更高级的分析功能,以满足不断进化的智能体生态系统的需求。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐