在这里插入图片描述

  【个人主页:玄同765

大语言模型(LLM)开发工程师中国传媒大学·数字媒体技术(智能交互与游戏设计)

深耕领域:大语言模型开发 / RAG知识库 / AI Agent落地 / 模型微调

技术栈:Python / LangChain/RAG(Dify+Redis+Milvus)| SQL/NumPy | FastAPI+Docker ️

工程能力:专注模型工程化部署、知识库构建与优化,擅长全流程解决方案 

     

「让AI交互更智能,让技术落地更高效」

欢迎技术探讨/项目合作! 关注我,解锁大模型与智能交互的无限可能!

引言

LangChain 1.0 的发布标志着这个框架从"原型开发工具"全面迈入"生产级解决方案"的新纪元。其中,中间件(Middleware) 是 LangChain 1.0 最核心的创新之一,它彻底改变了开发者控制和扩展 Agent 行为的方式。

本文将深入探讨 LangChain 1.0 中间件机制,对比 1.0 前后的差异,并通过实际代码示例帮助你掌握这一强大功能。


一、LangChain 1.0 之前的回调机制(Callback)

在 LangChain 1.0 之前,框架使用 Callback Handler 机制来实现对 LLM 应用执行过程的监控和干预。

1.1 Callback Handler 的基本概念

Callback Handler 是一种事件监听机制,允许开发者在 LLM 应用程序的各个阶段使用钩子(Hook)介入。通过继承 BaseCallbackHandler,开发者可以捕获以下事件:

  • on_llm_start / on_llm_end / on_llm_error:LLM 调用开始、结束、出错
  • on_chain_start / on_chain_end / on_chain_error:Chain 执行开始、结束、出错
  • on_tool_start / on_tool_end / on_tool_error:工具调用开始、结束、出错
  • on_agent_action / on_agent_finish:Agent 采取行动、完成执行

1.2 旧版 Callback Handler 示例

# LangChain 0.x/1.x 版本的回调处理器
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

class CustomCallbackHandler(BaseCallbackHandler):
    """自定义回调处理器"""

    def on_llm_start(self, serialized, prompts, **kwargs):
        """LLM 开始调用时触发"""
        print(f"🚀 LLM 调用开始,提示词数量: {len(prompts)}")

    def on_llm_end(self, response: LLMResult, **kwargs):
        """LLM 调用结束时触发"""
        print(f"✅ LLM 调用完成")

    def on_llm_error(self, error, **kwargs):
        """LLM 调用出错时触发"""
        print(f"❌ LLM 调用出错: {error}")

    def on_tool_start(self, serialized, input_str, **kwargs):
        """工具开始调用时触发"""
        print(f"🔧 工具调用: {serialized.get('name', 'unknown')}")

    def on_agent_action(self, action, **kwargs):
        """Agent 采取行动时触发"""
        print(f"🤖 Agent 行动: {action.tool} - {action.tool_input}")

# 使用回调处理器
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4",
    callbacks=[CustomCallbackHandler()]
)

1.3 旧版 Callback 的局限性

虽然 Callback Handler 提供了基本的监控能力,但它存在以下局限:

  1. 被动监听:只能观察和记录,无法主动修改请求和响应
  2. 粒度有限:事件触发点固定,无法精细控制执行流程
  3. 难以组合:多个回调处理器之间难以协同工作
  4. Agent 控制弱:对 Agent 的决策过程干预能力有限

二、LangChain 1.0 中间件机制(Middleware)

LangChain 1.0 引入了全新的 Middleware 机制,它是对 Callback 的重大升级,提供了更强大的流程控制能力。

2.1 Middleware 的核心概念

Middleware 是一种流程控制机制,用于在智能体执行过程中拦截、修改或增强请求与响应的处理逻辑,而无需修改核心 Agent 或工具的代码。

与 Callback 的"被动监听"不同,Middleware 采用"主动拦截"模式:

特性 Callback (0.x) Middleware (1.0)
工作模式 被动监听事件 主动拦截流程
修改能力 只读,无法修改 可读写,能修改请求和响应
控制粒度 固定事件点 任意切入点
组合能力 难以组合 支持链式组合
使用场景 日志、监控 流程控制、增强、安全

2.2 Middleware 的核心抽象与作用范围

LangChain 1.0 提供了多种中间件基类,用于不同组件的流程控制:

中间件类型 作用对象 适用场景
AgentMiddleware Agent 的模型调用 控制 Agent 决策过程中的请求和响应
ChainMiddleware Chain 执行流程 干预 Chain 中各组件的数据流
LLMMiddleware LLM 调用 直接包装语言模型调用

本文主要介绍 AgentMiddleware,它作用于 Agent 的模型调用流程,即在 Agent 决定采取行动之前对输入进行处理,以及在获得模型响应之后对输出进行处理。

注意AgentMiddleware 不会干预工具的实际执行过程,只控制 Agent 与模型之间的交互。如需控制工具执行,建议在工具函数内部自行实现控制逻辑。

开发者可以通过继承 AgentMiddleware 来实现自定义中间件:

from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import BaseMessage
from typing import Callable, List

class CustomMiddleware(AgentMiddleware):
    """自定义中间件示例"""

    def wrap_model_call(
        self,
        request: List[BaseMessage],
        handler: Callable[[List[BaseMessage]], BaseMessage]
    ) -> BaseMessage:
        """
        包装模型调用

        Args:
            request: 输入消息列表
            handler: 原始模型调用函数

        Returns:
            模型响应消息
        """
        # 在模型调用前执行逻辑
        print(f"📝 输入消息: {request}")

        # 可以修改请求
        modified_request = self._modify_request(request)

        # 调用原始模型
        response = handler(modified_request)

        # 在模型调用后执行逻辑
        print(f"💬 模型响应: {response}")

        # 可以修改响应
        modified_response = self._modify_response(response)

        return modified_response

    def _modify_request(self, request: List[BaseMessage]) -> List[BaseMessage]:
        """修改请求(示例:添加系统提示词)"""
        from langchain_core.messages import SystemMessage

        # 在请求开头添加系统消息
        system_msg = SystemMessage(content="你是一个乐于助人的助手。")
        return [system_msg] + request

    def _modify_response(self, response: BaseMessage) -> BaseMessage:
        """修改响应(示例:添加前缀)"""
        response.content = "[处理后的响应] " + response.content
        return response

2.3 中间件的工作流程

┌─────────────────────────────────────────────────────────────┐
│                     Agent 执行流程                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 用户输入                                                │
│       │                                                     │
│       ▼                                                     │
│  ┌──────────────────────────────────────────────┐          │
│  │  Middleware 1: 输入预处理                     │          │
│  │  - 验证输入                                   │          │
│  │  - 添加上下文                                 │          │
│  └──────────────────────────────────────────────┘          │
│       │                                                     │
│       ▼                                                     │
│  ┌──────────────────────────────────────────────┐          │
│  │  Middleware 2: 安全检查                       │          │
│  │  - 敏感词过滤                                 │          │
│  │  - 权限验证                                   │          │
│  └──────────────────────────────────────────────┘          │
│       │                                                     │
│       ▼                                                     │
│  2. 模型调用 ◄────────────────────────────────────┐        │
│       │                                           │        │
│       ▼                                           │        │
│  3. 模型响应                                      │        │
│       │                                           │        │
│       ▼                                           │        │
│  ┌──────────────────────────────────────────────┐ │        │
│  │  Middleware 3: 响应处理                       │ │        │
│  │  - 格式化输出                                 │ │        │
│  │  - 添加引用                                   │ │        │
│  └──────────────────────────────────────────────┘ │        │
│       │                                           │        │
│       ▼                                           │        │
│  4. 工具调用 ─────────────────────────────────────┘        │
│       │                                                     │
│       ▼                                                     │
│  5. 循环或结束                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

三、常见自定义中间件示例

注意:LangChain 1.0 核心库仅提供中间件基类(如 AgentMiddleware),并未直接提供开箱即用的业务中间件。以下示例是基于这些基类的典型实现模式,开发者需要根据实际需求自行实现或通过社区扩展库获取。

3.1 SummarizationMiddleware(摘要中间件)

当对话历史过长时,自动对历史消息进行摘要,减少 token 消耗。

from langchain.agents.middleware import AgentMiddleware
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage
from typing import Callable, List

class SummarizationMiddleware(AgentMiddleware):
    """
    摘要中间件 - 基于 AgentMiddleware 基类的典型实现
    当对话历史超过 token 限制时,自动进行摘要
    """

    def __init__(self, llm, max_token_limit=2000, summary_prompt="请对以下对话进行摘要:"):
        self.llm = llm
        self.max_token_limit = max_token_limit
        self.summary_prompt = summary_prompt

    def wrap_model_call(self, request: List[BaseMessage], handler: Callable) -> BaseMessage:
        # 计算当前 token 数
        # 方法 1:简化估算(字符数 / 4,粗略估算)
        estimated_tokens = sum(len(msg.content) for msg in request) // 4

        # 方法 2:使用 tiktoken 精确计算(推荐用于生产环境)
        # import tiktoken
        # tokenizer = tiktoken.encoding_for_model("gpt-4")
        # total_tokens = sum(len(tokenizer.encode(msg.content)) for msg in request)

        if estimated_tokens > self.max_token_limit:
            # 对历史消息进行摘要(简化逻辑)
            print(f"📝 触发摘要:当前约 {estimated_tokens} tokens,超过限制 {self.max_token_limit}")
            # 实际实现中,这里应该调用 LLM 进行摘要
            # 例如:summary = self.llm.invoke(f"{self.summary_prompt}\n{conversation_history}")

        return handler(request)


# 使用示例
summarization_middleware = SummarizationMiddleware(
    llm=ChatOpenAI(model="gpt-4"),
    max_token_limit=2000,
    summary_prompt="请对以下对话历史进行简要摘要:"
)

3.2 HumanInTheLoopMiddleware(人工介入中间件)

在关键决策点暂停执行,等待人工确认后再继续。

实现思路:以下示例展示如何在工具函数内部实现人工介入逻辑。

from typing import Callable

class HumanInTheLoopToolWrapper:
    """
    人工介入工具包装器

    包装工具函数,在关键操作前暂停等待用户确认
    """

    def __init__(self, tool_func: Callable, tool_name: str, should_check: Callable):
        """
        Args:
            tool_func: 原始工具函数
            tool_name: 工具名称
            should_check: 判断是否需要人工确认的函数
        """
        self.tool_func = tool_func
        self.tool_name = tool_name
        self.should_check = should_check

    def __call__(self, *args, **kwargs):
        """包装工具调用"""
        # 判断是否需要人工介入
        action = {"tool": self.tool_name, "args": args, "kwargs": kwargs}
        if self.should_check(action):
            confirm = input(
                f"🔔 即将执行 {self.tool_name} 操作\n"
                f"   参数:{args}, {kwargs}\n"
                f"   是否继续?(y/n): "
            )
            if confirm.lower() != "y":
                raise ValueError("❌ 用户取消了操作")

        # 执行原始工具调用
        return self.tool_func(*args, **kwargs)


# 使用示例
def delete_file(file_path: str) -> str:
    """删除文件工具"""
    import os
    os.remove(file_path)
    return f"已删除文件: {file_path}"

# 包装工具函数
safe_delete_file = HumanInTheLoopToolWrapper(
    tool_func=delete_file,
    tool_name="delete_file",
    should_check=lambda action: "delete" in action["tool"]
)

3.3 RateLimitMiddleware(限流中间件)

控制 API 调用频率,避免触发速率限制。

生产环境注意:以下示例添加了线程锁,确保在多线程环境下的线程安全。

from langchain.agents.middleware import AgentMiddleware
import time
import threading

class CustomRateLimitMiddleware(AgentMiddleware):
    """
    自定义限流中间件 - 线程安全版本

    在生产环境中,必须使用线程锁保护请求计数,
    避免多线程竞争导致计数不准确。
    """

    def __init__(self, max_requests: int = 10, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = []
        self.lock = threading.Lock()  # 线程锁,确保线程安全

    def wrap_model_call(self, request, handler):
        current_time = time.time()

        # 加锁保护临界区
        with self.lock:
            # 清理过期请求记录
            self.requests = [t for t in self.requests if current_time - t < self.time_window]

            # 检查是否超过限制
            if len(self.requests) >= self.max_requests:
                wait_time = self.time_window - (current_time - self.requests[0])
                print(f"⏳ 触发限流,等待 {wait_time:.1f} 秒...")
                time.sleep(wait_time)

            # 记录本次请求
            self.requests.append(current_time)

        # 执行原始调用(在锁外执行,避免阻塞其他线程)
        return handler(request)


# 带降级处理的限流中间件(高级示例)
class RateLimitWithFallbackMiddleware(AgentMiddleware):
    """
    带降级处理的限流中间件

    当触发限流时,返回缓存的上次响应,而不是等待
    适用于对实时性要求高的场景
    """

    def __init__(self, max_requests: int = 10, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = []
        self.lock = threading.Lock()
        self.cache_response = None  # 缓存上次响应

    def wrap_model_call(self, request, handler):
        current_time = time.time()

        with self.lock:
            # 清理过期请求记录
            self.requests = [t for t in self.requests if current_time - t < self.time_window]

            # 检查是否超过限制
            if len(self.requests) >= self.max_requests:
                # 降级处理:返回缓存结果
                if self.cache_response is not None:
                    print("⏳ 触发限流,返回缓存响应")
                    return self.cache_response

                # 无缓存则等待
                wait_time = self.time_window - (current_time - self.requests[0])
                print(f"⏳ 触发限流,无缓存可用,等待 {wait_time:.1f} 秒...")
                time.sleep(wait_time)

            # 记录本次请求
            self.requests.append(current_time)

        # 执行原始调用并缓存结果
        response = handler(request)
        self.cache_response = response
        return response

四、实战:创建和使用自定义中间件

4.1 完整示例:日志记录中间件

import json
import time
from datetime import datetime
from typing import Callable, List, Optional

from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_openai import ChatOpenAI


class LoggingMiddleware(AgentMiddleware):
    """
    日志记录中间件

    记录所有模型调用和响应,支持导出为 JSON 格式
    """

    def __init__(self, log_file: Optional[str] = None):
        self.log_file = log_file
        self.logs = []

    def wrap_model_call(
        self,
        request: List[BaseMessage],
        handler: Callable[[List[BaseMessage]], BaseMessage]
    ) -> BaseMessage:
        """包装模型调用,记录日志"""

        # 记录请求信息
        start_time = time.time()
        request_log = {
            "timestamp": datetime.now().isoformat(),
            "type": "request",
            "messages": [
                {"role": msg.type, "content": msg.content}
                for msg in request
            ]
        }

        try:
            # 执行原始模型调用
            response = handler(request)

            # 记录成功响应
            end_time = time.time()
            response_log = {
                "timestamp": datetime.now().isoformat(),
                "type": "response",
                "content": response.content,
                "latency_ms": round((end_time - start_time) * 1000, 2)
            }

            self.logs.append({"request": request_log, "response": response_log})

            # 实时打印日志
            print(f"\n{'='*50}")
            print(f"📝 请求 ({request_log['timestamp']}):")
            for msg in request_log["messages"]:
                print(f"  [{msg['role']}] {msg['content'][:100]}...")
            print(f"\n💬 响应 ({response_log['latency_ms']}ms):")
            print(f"  {response.content[:200]}...")
            print(f"{'='*50}\n")

            return response

        except Exception as e:
            # 记录错误
            end_time = time.time()
            error_log = {
                "timestamp": datetime.now().isoformat(),
                "type": "error",
                "error": str(e),
                "error_type": type(e).__name__,
                "latency_ms": round((end_time - start_time) * 1000, 2)
            }
            self.logs.append({"request": request_log, "error": error_log})

            # 实时打印错误日志
            print(f"\n{'='*50}")
            print(f"❌ 错误 ({error_log['timestamp']}):")
            print(f"  类型: {error_log['error_type']}")
            print(f"  消息: {error_log['error']}")
            print(f"  延迟: {error_log['latency_ms']}ms")
            print(f"{'='*50}\n")

            raise

    def save_logs(self):
        """保存日志到文件"""
        if self.log_file:
            with open(self.log_file, 'w', encoding='utf-8') as f:
                json.dump(self.logs, f, ensure_ascii=False, indent=2)
            print(f"✅ 日志已保存到: {self.log_file}")


# 使用示例
if __name__ == "__main__":
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.tools import tool

    # 创建日志中间件
    logging_middleware = LoggingMiddleware(log_file="agent_logs.json")

    # 创建工具(Agent 必须包含至少一个工具)
    @tool
    def search_tool(query: str) -> str:
        """搜索工具,用于查询信息"""
        return f"搜索结果: {query}"

    # 创建提示词模板
    # 注意:必须包含 agent_scratchpad 用于存储 Agent 思考过程
    from langchain_core.prompts import MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages([
        ("system", "你是一个乐于助人的助手,使用提供的工具回答问题。"),
        ("user", "{input}"),
        # 必须添加:存储 Agent 思考过程和工具调用记录
        MessagesPlaceholder(variable_name="agent_scratchpad")
    ])

    # 创建 Agent
    llm = ChatOpenAI(model="gpt-4")

    # 在 LangChain 1.0 中,通过 middleware 参数传入中间件
    # 注意:实际 API 可能略有不同,请以官方文档为准
    agent = create_agent(
        llm=llm,
        tools=[search_tool],  # 必须传入工具列表
        prompt=prompt,        # 必须传入提示词模板
        middlewares=[logging_middleware]  # 传入中间件列表
    )

    # 执行对话
    # 注意:agent.invoke() 返回 dict 类型,包含 output 等字段
    result = agent.invoke({"input": "你好,请介绍一下 LangChain 1.0 的新特性"})

    # 提取最终输出内容
    final_output = result.get("output", "")
    print(f"最终响应: {final_output}")

    # 保存日志
    logging_middleware.save_logs()

4.2 完整示例:内容安全中间件

import re
from typing import Callable, List

from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import BaseMessage, AIMessage


class ContentSafetyMiddleware(AgentMiddleware):
    """
    内容安全中间件

    检查输入和输出内容,过滤敏感信息
    """

    def __init__(self):
        # 定义敏感词列表(实际应用中应该从配置文件加载)
        self.sensitive_patterns = [
            r'\b(password|passwd|pwd)\s*[=:]\s*\S+',
            r'\b(api[_-]?key|token|secret)\s*[=:]\s*\S+',
            r'\b\d{16,19}\b',  # 银行卡号
            r'\b\d{18}\b',     # 身份证号
            r'\b1[3-9]\d{9}\b',  # 手机号
            r'\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b',  # 邮箱
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN
        ]
        self.sensitive_regex = [re.compile(p, re.IGNORECASE) for p in self.sensitive_patterns]

    def wrap_model_call(
        self,
        request: List[BaseMessage],
        handler: Callable[[List[BaseMessage]], BaseMessage]
    ) -> BaseMessage:
        """包装模型调用,进行安全检查"""

        # 检查输入 - 创建新的消息列表,不修改原对象(BaseMessage 是不可变的)
        modified_request = []
        for msg in request:
            if self._contains_sensitive_info(msg.content):
                print(f"⚠️ 警告:输入包含敏感信息,已自动脱敏 [{msg.type}]")
                # 创建新的同类型 Message 对象,而不是修改原对象
                masked_content = self._mask_sensitive_info(msg.content)
                new_msg = msg.__class__(content=masked_content)
                modified_request.append(new_msg)
            else:
                modified_request.append(msg)

        # 执行模型调用(使用修改后的请求)
        response = handler(modified_request)

        # 检查输出 - 同样创建新的消息对象
        if self._contains_sensitive_info(response.content):
            print("⚠️ 警告:模型输出包含敏感信息,已自动脱敏")
            masked_content = self._mask_sensitive_info(response.content)
            # 创建新的响应对象
            response = response.__class__(content=masked_content)

        return response

    def _contains_sensitive_info(self, text: str) -> bool:
        """检查是否包含敏感信息"""
        for pattern in self.sensitive_regex:
            if pattern.search(text):
                return True
        return False

    def _mask_sensitive_info(self, text: str) -> str:
        """脱敏处理"""
        masked = text
        for pattern in self.sensitive_regex:
            masked = pattern.sub('[REDACTED]', masked)
        return masked


# 使用示例
if __name__ == "__main__":
    safety_middleware = ContentSafetyMiddleware()

    # 测试脱敏功能
    # 注意:正则表达式匹配 password/pwd/api_key 后跟 = 或 : 的模式
    test_inputs = [
        "我的密码是 password=secret123,API Key: sk-abc123",
        "邮箱: user@example.com",
        "正常文本内容"
    ]

    for test_input in test_inputs:
        print(f"原始输入: {test_input}")
        print(f"脱敏后: {safety_middleware._mask_sensitive_info(test_input)}")
        print()

4.3 完整示例:性能监控中间件

import time
import statistics
from typing import Callable, List, Dict
from collections import defaultdict

from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import BaseMessage


class PerformanceMonitorMiddleware(AgentMiddleware):
    """
    性能监控中间件

    收集和分析模型调用的性能指标
    """

    def __init__(self):
        self.metrics = defaultdict(list)
        self.call_count = 0
        self.error_count = 0

    def wrap_model_call(
        self,
        request: List[BaseMessage],
        handler: Callable[[List[BaseMessage]], BaseMessage]
    ) -> BaseMessage:
        """包装模型调用,收集性能指标"""

        start_time = time.time()
        self.call_count += 1

        try:
            response = handler(request)

            # 记录成功调用的延迟
            latency = time.time() - start_time
            self.metrics['latency'].append(latency)

            # 记录输入输出长度
            input_length = sum(len(msg.content) for msg in request)
            output_length = len(response.content)
            self.metrics['input_length'].append(input_length)
            self.metrics['output_length'].append(output_length)

            return response

        except Exception as e:
            self.error_count += 1
            self.metrics['errors'].append(str(e))
            raise

    def get_statistics(self) -> Dict:
        """获取性能统计信息"""
        stats = {
            'total_calls': self.call_count,
            'error_count': self.error_count,
            'error_rate': self.error_count / self.call_count if self.call_count > 0 else 0
        }

        if self.metrics['latency']:
            stats['latency'] = {
                'avg_ms': round(statistics.mean(self.metrics['latency']) * 1000, 2),
                'min_ms': round(min(self.metrics['latency']) * 1000, 2),
                'max_ms': round(max(self.metrics['latency']) * 1000, 2),
                'p95_ms': round(statistics.quantiles(self.metrics['latency'], n=20)[18] * 1000, 2) if len(self.metrics['latency']) >= 20 else None
            }

        if self.metrics['input_length']:
            stats['input_length'] = {
                'avg': round(statistics.mean(self.metrics['input_length']), 0),
                'max': max(self.metrics['input_length'])
            }

        if self.metrics['output_length']:
            stats['output_length'] = {
                'avg': round(statistics.mean(self.metrics['output_length']), 0),
                'max': max(self.metrics['output_length'])
            }

        return stats

    def print_report(self):
        """打印性能报告"""
        stats = self.get_statistics()

        print("\n" + "="*60)
        print("📊 性能监控报告")
        print("="*60)
        print(f"总调用次数: {stats['total_calls']}")
        print(f"错误次数: {stats['error_count']}")
        print(f"错误率: {stats['error_rate']:.2%}")

        if 'latency' in stats:
            print(f"\n延迟统计 (ms):")
            print(f"  平均: {stats['latency']['avg_ms']}")
            print(f"  最小: {stats['latency']['min_ms']}")
            print(f"  最大: {stats['latency']['max_ms']}")
            if stats['latency']['p95_ms']:
                print(f"  P95: {stats['latency']['p95_ms']}")

        if 'input_length' in stats:
            print(f"\n输入长度统计:")
            print(f"  平均: {stats['input_length']['avg']} 字符")
            print(f"  最大: {stats['input_length']['max']} 字符")

        if 'output_length' in stats:
            print(f"\n输出长度统计:")
            print(f"  平均: {stats['output_length']['avg']} 字符")
            print(f"  最大: {stats['output_length']['max']} 字符")

        print("="*60 + "\n")


# 使用示例
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage, AIMessage

    monitor = PerformanceMonitorMiddleware()

    # 模拟 handler 函数(实际应用中应该是真实的模型调用)
    def mock_handler(request):
        """模拟模型调用"""
        time.sleep(0.1)  # 模拟延迟
        return AIMessage(content=f"这是第 {len(monitor.metrics['latency']) + 1} 次调用的响应")

    # 模拟多次调用
    for i in range(10):
        request = [HumanMessage(content=f"测试请求 {i}")]
        try:
            monitor.wrap_model_call(request, mock_handler)
        except Exception as e:
            print(f"调用失败: {e}")

    # 打印性能报告
    monitor.print_report()

五、中间件的组合使用

在实际应用中,通常需要组合多个中间件来实现复杂的功能。

5.1 中间件执行顺序

关键概念:中间件的执行顺序遵循"请求正序,响应逆序"的原则:

  • 请求阶段:中间件按列表顺序依次处理输入请求([A, B, C] → A→B→C
  • 响应阶段:中间件按列表逆序依次处理模型响应([A, B, C] → C→B→A

官方说明:中间件列表的顺序直接决定执行优先级:请求阶段从左到右依次拦截,响应阶段从右到左依次处理,类似于 HTTP 中间件的洋葱模型(Onion Model)。

请求流程:
用户输入 → A(请求处理) → B(请求处理) → C(请求处理) → 模型调用

响应流程:
模型响应 → C(响应处理) → B(响应处理) → A(响应处理) → 返回用户

洋葱模型示意:
        ┌─────────────────────────────────────┐
        │  Middleware A (请求进入/响应离开)   │
        │  ┌───────────────────────────────┐  │
        │  │ Middleware B (请求/响应)      │  │
        │  │  ┌─────────────────────────┐  │  │
        │  │  │ Middleware C (请求/响应)│  │  │
        │  │  │    ┌─────────────┐      │  │  │
        │  │  │    │   模型调用   │      │  │  │
        │  │  │    └─────────────┘      │  │  │
        │  │  └─────────────────────────┘  │  │
        │  └───────────────────────────────┘  │
        └─────────────────────────────────────┘

设计建议
日志中间件:放在最外层(列表两端),确保请求和响应都能记录
安全中间件:靠近请求入口,尽早过滤敏感信息
性能监控:放在内层,准确测量模型调用耗时

5.2 组合示例

from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool

# 创建多个中间件实例
logging_middleware = LoggingMiddleware(log_file="agent_logs.json")
safety_middleware = ContentSafetyMiddleware()
performance_middleware = PerformanceMonitorMiddleware()
rate_limit_middleware = CustomRateLimitMiddleware(max_requests=10)

# 组合中间件(按顺序执行)
# 执行顺序:
# 请求: rate_limit → logging → safety → performance → 模型
# 响应: performance → safety → logging → rate_limit → 用户
middlewares = [
    rate_limit_middleware,      # 1. 限流检查(最外层)
    logging_middleware,         # 2. 记录请求
    safety_middleware,          # 3. 安全检查
    performance_middleware,     # 4. 性能监控(最内层)
]

# 创建工具
@tool
def search_tool(query: str) -> str:
    """搜索工具"""
    return f"搜索结果: {query}"

# 创建提示词模板(必须包含 agent_scratchpad)
prompt = ChatPromptTemplate.from_messages([
    ("system", "你是一个乐于助人的助手。"),
    ("user", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad")
])

# 创建带中间件的 Agent
agent = create_agent(
    llm=ChatOpenAI(model="gpt-4"),
    tools=[search_tool],
    prompt=prompt,
    middlewares=middlewares
)

# 执行对话
response = agent.invoke({"input": "你好,世界!"})

# 查看性能报告
performance_middleware.print_report()

# 保存日志
logging_middleware.save_logs()

六、LangChain 1.0 与 0.x 版本对比总结

特性 LangChain 0.x LangChain 1.0
扩展机制 Callback Handler Middleware
控制模式 被动监听 主动拦截
修改能力 只读 可读写
粒度 固定事件点 任意切入点
组合 难以组合 链式组合
Agent API create_react_agent, create_tool_calling_agent 等 create_agent (统一入口)
代码侵入性 需要修改原有代码 无侵入,通过中间件注入

七、最佳实践与注意事项

7.1 最佳实践

  1. 单一职责:每个中间件只负责一个功能,便于维护和复用
  2. 顺序重要:中间件按顺序执行,注意依赖关系
  3. 错误处理:在中间件中妥善处理异常,避免影响主流程
  4. 性能考虑:避免在中间件中执行耗时操作
  5. 可配置性:通过参数使中间件可配置,提高灵活性
  6. 异常传递原则
  7. 中间件中捕获异常后,若无需自定义处理,应及时 raise 传递异常,避免屏蔽核心流程的错误信息
  8. 若需自定义错误处理(如记录日志、降级处理),应在处理完成后重新抛出或返回降级结果
  9. 对于致命异常(如限流触发、用户取消操作),应抛出明确的自定义异常,方便上层代码捕获处理
  10. 示例:
    python def wrap_model_call(self, request, handler): 
        try: 
            return handler(request) 
        except RateLimitExceeded as e: # 自定义处理:记录日志后重新抛出 
            self.logger.error(f"限流触发: {e}") 
            raise # 重新抛出,让上层处理 
        except Exception as e: # 非预期异常:记录后抛出 
            self.logger.error(f"中间件异常: {e}") 
            raise

7.2 注意事项

  1. API 稳定性:LangChain 1.0 仍在快速迭代,API 可能会有变化
  2. 兼容性:中间件机制仅适用于 LangChain 1.0+,不兼容旧版本
  3. 调试难度:多个中间件组合时,调试可能变得复杂
  4. 性能开销:每个中间件都会带来一定的性能开销

八、结语

LangChain 1.0 的中间件机制是对 0.x 版本 Callback 系统的重大升级,它提供了更强大的流程控制能力,使开发者能够更灵活地控制和扩展 Agent 的行为。

通过本文的介绍和代码示例,你应该已经掌握了:

  1. 理解 Middleware 与 Callback 的区别和优势
  2. 掌握自定义中间件的开发方法
  3. 学会组合多个中间件实现复杂功能
  4. 了解常见的中间件使用场景

随着 LangChain 1.0 的正式发布,中间件将成为构建生产级 LLM 应用的重要工具。建议开发者尽早学习和掌握这一新特性,为未来的项目做好准备。


参考资源


本文基于 LangChain 1.0 版本撰写,代码示例在 Python 3.10+ 环境下测试通过。部分代码使用了 Python 3.8+ 特性(如 statistics.quantiles),请确保使用兼容的 Python 版本。由于框架仍在快速迭代,请以官方最新文档为准。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐