01. RunnableWithMessageHistory 使用示例

在前面的示例中,我们将历史消息显式地传递给链,在链外单独处理历史消息的记忆存储。这是一种完全可接受的方法,除此之外,LangChain 还提供了一个名为 RunnableWithMessageHistory 的类/包装器,能让链自动处理这个过程(填充+存储)。

类构造函数接收的参数如下:

  1. runnable:需要包装的链或者可运行的组件。
  2. get_session_history:一个工厂函数,它返回给定会话ID的消息历史记录。这样,您的链就可以通过加载不同对话的不同消息来同时处理多个用户。
  3. input_messages_key:人类的输入是哪一个键,用于指定输入的哪个部分应该在聊天历史中被跟踪和存储。
  4. output_messages_key:AI 的输出是哪一个键,指定要将哪个输出存储为历史记录。
  5. history_messages_key:历史消息使用哪一个键,用于指定以前的消息使用特定的变量在模板中格式化。

使用 RunnableWithMessageHistory 包装链后,就可以像正常链一样调用了,除此之外,还可以增加一个运行时配置来指定传递给工厂函数的 session_id 是什么,从哪里获取存储的历史消息。

示例代码如下

import dotenv

from langchain_community.chat_message_histories import FileChatMessageHistory

from langchain_core.chat_history import BaseChatMessageHistory

from langchain_core.output_parsers import StrOutputParser

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_core.runnables.history import RunnableWithMessageHistory

from langchain_openai import ChatOpenAI

dotenv.load_dotenv()

# 1. In-memory registry mapping each session id to its chat-history backend.
store = {}


# 2. Factory: look up (or lazily create) the chat history for a given session.
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the file-backed message history for *session_id*, creating it on first use."""
    history = store.get(session_id)
    if history is None:
        history = FileChatMessageHistory(f"chat_history_{session_id}.txt")
        store[session_id] = history
    return history

# 3. Build the prompt template and the chat model.
#    `MessagesPlaceholder("history")` is where prior turns get injected.
prompt = ChatPromptTemplate.from_messages([
    ("system", "你是一个强大的聊天机器人,请根据用户的需求回复问题。"),
    MessagesPlaceholder("history"),
    ("human", "{query}"),
])

llm = ChatOpenAI(model="gpt-3.5-turbo-16k")

# 4. Compose the chain: prompt -> model -> plain string output.
chain = prompt | llm | StrOutputParser()

# 5. Wrap the chain so history is loaded before and stored after each run.
with_message_chain = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="query",
    history_messages_key="history",
)

# 6. Simple REPL: stream answers until the user types "q".
while True:
    query = input("Human: ")
    if query == "q":
        break  # leave the loop cleanly instead of killing the process with exit(0)

    # Run the wrapped chain; session_id selects which history file is used.
    response = with_message_chain.stream(
        {"query": query},
        config={"configurable": {"session_id": "123456"}}
    )
    print("AI: ", flush=True, end="")
    for chunk in response:
        print(chunk, flush=True, end="")
    # BUG FIX: this newline must be printed after *each* streamed answer.
    # In the original it sat outside the loop and was unreachable because
    # the loop only terminated via exit(0), so streamed output ran straight
    # into the next "Human: " prompt.
    print("")

02. RunnableWithMessageHistory 运行流程

RunnableWithMessageHistory 内部通过传递的运行时配置 session_id 获取到对应的消息历史实例,然后将消息历史实例组装进用户输入字典,拼接到原始的 Runnable 可运行链应用中,并为新的 Runnable 可运行链应用添加 callback 回调处理器,用于处理 LLM 生成的内容,并将其存储到消息历史记忆中。

整体的运行流程如下

def __init__(
    self,
    runnable: Union[
        Runnable[
            Union[MessagesOrDictWithMessages],
            Union[str, BaseMessage, MessagesOrDictWithMessages],
        ],
        LanguageModelLike,
    ],
    get_session_history: GetSessionHistoryCallable,
    *,
    input_messages_key: Optional[str] = None,
    output_messages_key: Optional[str] = None,
    history_messages_key: Optional[str] = None,
    history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None,
    **kwargs: Any,
) -> None:
    """Wrap *runnable* so chat history is loaded before and saved after each run.

    Args:
        runnable: The chain (or language model) to wrap.
        get_session_history: Factory returning the message history for a session.
        input_messages_key: Key in the input dict holding the new human message(s).
        output_messages_key: Key in the output holding the AI message(s) to store.
        history_messages_key: Template variable name under which prior messages
            are injected into the input.
        history_factory_config: Config field specs whose values are forwarded to
            the factory; defaults to a single shared ``session_id`` field.
    """
    # Step that loads prior messages before the wrapped runnable runs
    # (sync and async entry points — see _enter_history/_aenter_history).
    history_chain: Runnable = RunnableLambda(
        self._enter_history, self._aenter_history
    ).with_config(run_name="load_history")
    # Prefer the dedicated history key; otherwise prepend history onto the
    # same key that carries the user input.
    messages_key = history_messages_key or input_messages_key
    if messages_key:
        # Merge the loaded history into the input dict under `messages_key`,
        # passing every other input key through untouched.
        history_chain = RunnablePassthrough.assign(
            **{messages_key: history_chain}
        ).with_config(run_name="insert_history")
    # After the wrapped runnable finishes, `_exit_history` persists the new
    # input/output messages back into the session's history store.
    bound = (
        history_chain | runnable.with_listeners(on_end=self._exit_history)
    ).with_config(run_name="RunnableWithMessageHistory")
    if history_factory_config:
        _config_specs = history_factory_config
    else:
        # If not provided, fall back to a single default `session_id` field.
        _config_specs = [
            ConfigurableFieldSpec(
                id="session_id",
                annotation=str,
                name="Session ID",
                description="Unique identifier for a session.",
                default="",
                is_shared=True,
            ),
        ]
    super().__init__(
        get_session_history=get_session_history,
        input_messages_key=input_messages_key,
        output_messages_key=output_messages_key,
        bound=bound,
        history_messages_key=history_messages_key,
        history_factory_config=_config_specs,
        **kwargs,
    )

def _exit_history(self, run: Run, config: RunnableConfig) -> None:
    """on_end listener: append this run's input and output messages to the session history."""
    # The history instance was stashed into the config by `_merge_configs`.
    hist: BaseChatMessageHistory = config["configurable"]["message_history"]
    # Get the input messages
    inputs = load(run.inputs)
    input_messages = self._get_input_messages(inputs)
    # If historic messages were prepended to the input messages, remove them to
    # avoid adding duplicate messages to history.
    if not self.history_messages_key:
        historic_messages = config["configurable"]["message_history"].messages
        input_messages = input_messages[len(historic_messages) :]
    # Get the output messages
    output_val = load(run.outputs)
    output_messages = self._get_output_messages(output_val)
    # Persist the new turn (user input + AI output) in one call.
    hist.add_messages(input_messages + output_messages)

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
    """Merge configs, validate the required configurable keys, and attach the
    session's message history instance under config["configurable"]["message_history"]."""
    config = super()._merge_configs(*configs)
    # Keys the caller must provide (defaults to just "session_id").
    expected_keys = [field_spec.id for field_spec in self.history_factory_config]
    configurable = config.get("configurable", {})
    missing_keys = set(expected_keys) - set(configurable.keys())
    if missing_keys:
        # Build an example invocation so the error is actionable.
        example_input = {self.input_messages_key: "foo"}
        example_configurable = {
            missing_key: "[your-value-here]" for missing_key in missing_keys
        }
        example_config = {"configurable": example_configurable}
        raise ValueError(
            f"Missing keys {sorted(missing_keys)} in config['configurable'] "
            f"Expected keys are {sorted(expected_keys)}."
            f"When using via .invoke() or .stream(), pass in a config; "
            f"e.g., chain.invoke({example_input}, {example_config})"
        )
    parameter_names = _get_parameter_names(self.get_session_history)
    if len(expected_keys) == 1:
        # If arity = 1, then invoke the factory with a single positional argument
        message_history = self.get_session_history(configurable[expected_keys[0]])
    else:
        # otherwise verify that names of keys match and invoke by named arguments
        if set(expected_keys) != set(parameter_names):
            raise ValueError(
                f"Expected keys {sorted(expected_keys)} do not match parameter "
                f"names {sorted(parameter_names)} of get_session_history."
            )
        message_history = self.get_session_history(
            **{key: configurable[key] for key in expected_keys}
        )
    # Downstream steps (the history loader and `_exit_history`) read the
    # instance back from this config slot.
    config["configurable"]["message_history"] = message_history
    return config

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐