LangChain中的语言模型是一个Runnable对象,目前我们已经了解了作为输入的LanguageModelInput和作为输出LanguageModelOutputVar,现在我们关注它作为Runnable对象是如何执行的。所有的语言模型类型(包括基于文本补齐的Completion模型和基于多角色参与的Chat模型)都继承自抽象类BaseLanguageModel,派生于它的BaseChatModel是所有Chat模型的基类。基于OpenAI的ChatOpenAI的基类BaseChatOpenAI就继承自BaseChatModel。这篇文章着重介绍BaseLanguageModel和BaseChatModel这两个基类的设计,下篇文章将会介绍ChatOpenAI。

1. BaseLanguageModel

BaseLanguageModel继承自RunnableSerializable,这意味着它能成为LCEL链上的一环,我们先来看看它的字段和属性成员的定义。作为BaseLanguageModel的输入,它可以是一个PromptValue对象,字符串文本或者MessageLikeRepresentation序列。MessageLikeRepresentation则针对BaseMessage、字符串(视为模板)、字符串列表(视为多个模板)、二元组(视为角色+模板)和字典类型的(“role”和“content”作为Key,对应角色和模板)联合。作为输出的LanguageModelOutputVar则限制为AIMessage(针对Chat模型)或者字符串(针对Completion模型)。

class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC)

LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str)
MessageLikeRepresentation = (
    BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
)

它的cache字段用于控制针对响应结果的缓存,我们可以指定一个实现了缓存读写功能的BaseCache对象。如果是True,意味着使用全局默认缓存。设置为None相当于关闭了缓存的功能。verbose字段是开启冗余跟踪调试模式的开关。callbacks字段可以用于注册在整个处理流程响应步骤自动调用的回调。tagsmetadata字段提供的标签和元数据会被添加到跟踪记录中,用于进一步描述当前的应用场景和执行环境。custom_get_token_ids字段提供了一个可执行对象作为分词器,它将指定的作为输入的文本转换成一系列的Token,并返回这些Token的标识。

class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC):
    cache: BaseCache | bool | None = Field(default=None, exclude=True)
    verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
    callbacks: Callbacks = Field(default=None, exclude=True)
    tags: list[str] | None = Field(default=None, exclude=True)
    metadata: dict[str, Any] | None = Field(default=None, exclude=True)
    custom_get_token_ids: Callable[[str], list[int]] | None = Field(
        default=None, exclude=True)	

generate_prompt/agenerate_prompt是两个最为核心的抽象方法,它以PromptValue作为输入,并将其转换成兼容参数对模型实施调用,最后将返回内容转换成一个LLMResult对象。我们可以从LLMResult对象中提取模型生成的文本、元数据及Token消耗统计数据。

class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC):
    @abstractmethod
    def generate_prompt(
        self,
        prompts: list[PromptValue],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:

    @abstractmethod
    async def agenerate_prompt(
        self,
        prompts: list[PromptValue],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult

BaseLanguageModel还提供了如下四个辅助方法。with_structured_output方法通过指定JSON Schema或者Pydantic Model的方式来控制输出的结构。该方法默认会抛出NotImplementedError异常,支持“结构化输出”特性的子类需要重写此方法。get_token_ids方法对指定文本实施分词,如果对custom_get_token_ids字段进行了设置,此方法会直接调用设置的可执行对象。否则会使用_get_token_ids_default_method字段设置的可执行对象作为兜底分词器。它的get_num_tokensget_num_tokens_from_messages方法用于计算指定文本和消息对应的Token数量,我们主要使用它们来确定是否超出模型上下文窗口限制,这两个方法返回的统计结果来源于针对于针对get_token_ids方法的调用。

class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC):
    def with_structured_output(
        self, schema: dict | type, **kwargs: Any
    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
        raise NotImplementedError    

    def get_token_ids(self, text: str) -> list[int]:
        if self.custom_get_token_ids is not None:
            return self.custom_get_token_ids(text)
        return _get_token_ids_default_method(text)

    def get_num_tokens(self, text: str) -> int:
        return len(self.get_token_ids(text))

    def get_num_tokens_from_messages(
        self,
        messages: list[BaseMessage],
        tools: Sequence | None = None,
    ) -> int:
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)

BaseLanguageModel还定义了如下这个名为_identifying_params的属性,它返回一个标识该模型实例特征的字典。该字典主要用于缓存键的生成,确保不同配置的模型调用不会互相覆盖缓存。此属性默认返回lc_attributes属性。

class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC):
    @property
    def _identifying_params(self) -> Mapping[str, Any]:	
        return self.lc_attributes	

2. LLMResult

传统的Completion模型以单纯的字符串作为输入和输出,Chat模型则采用绑定为具体角色的消息作为输入和输出,具体的输出体现为一个AIMessage对象。接收到的模型响应结果后,最终会被转换成一个LLMResult对象,作为输出的字符串或AIMessage是根据这个LLMResult对象生成的。

LLMResult其实并非对“单一模型调用结果”的封装,都是对“批量模型调用结果”的封装。它也具有专属的类型“LLMResult”,run字段返回的RunInfo列表用于描述模型调用,每个RunInfo对象只包含表示调用标识的run_id字段。generations字段是LLMResult中最重要的部分,存储的是模型生成的实际文本或消息。它返回一个两层列表,外层列表对应输入批次,如果一次性传了 3 个 提示词,外层列表长度就是 3。内层列表对应单次输入的候选结果(n)。如果你设置n > 1(要求模型返回多个备选答案),内层列表就会包含多个Generation对象。llm_output字段返回一个字典,存储了特定于供应商的非内容信息。它不属于生成的文本,而是关于生成过程的统计。

class LLMResult(BaseModel):
    generations: list[
        list[Generation | ChatGeneration | GenerationChunk | ChatGenerationChunk]
    ]
    llm_output: dict | None = None
    run: list[RunInfo] | None = None
    type: Literal["LLMResult"] = "LLMResult"
    def flatten(self) -> list[LLMResult]

class RunInfo(BaseModel):
    run_id: UUID

我们可以调用flatten方法将一个LLMResult对象“扁平化”为一个LLMResult列表,扁平化后的每个LLMResult对应一个单一的调用,所以其generations字段的第一层列表长度总是为1。

针对Completion模型的生成文本封装在Generation/GenerationChunk对象中,Generation通过继承Serializable被赋予了可被序列化和反序列化的能力,其专属的类型为“Generation”。生成的字符串文本被存储在text字段上,generation_info以字典的形式存储额外的元数据,比如“终止原因”和“对数概率”等。服务于“流式输出”的GenerationChunk类型继承自Generation,它通过重写__add__方法实现“拼接”功能。

class Generation(Serializable):
    text: str
    generation_info: dict[str, Any] | None = None
    type: Literal["Generation"] = "Generation"

class GenerationChunk(Generation):
    def __add__(self, other: GenerationChunk) -> GenerationChunk

Chat模型的生成的是绑定为某个特定角色的消息,对应于ChatGeneration/ChatGenerationChunck类型的message字段,类型分别为BaseMessage/BaseMessageChunktext字段返回的是消息的文本表示。这两个类型具有各自专属类型“ChatGeneration”和“ChatGenerationChunk”。

class ChatGeneration(Generation):
    text: str = ""
    message: BaseMessage
    type: Literal["ChatGeneration"] = "ChatGeneration"

class ChatGenerationChunk(ChatGeneration):
    message: BaseMessageChunk
    type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk"
    def __add__(
        self, other: ChatGenerationChunk | list[ChatGenerationChunk]
    ) -> ChatGenerationChunk 

3. BaseChatModel

作为Chat模型的基类,BaseChatModel继承自BaseLanguageModel。从如下的定义可以看出,它将输出类型固定为AIMessage

class BaseChatModel(BaseLanguageModel[AIMessage], ABC)

3.1 配置字段/属性

BaseChatModel在基类上添加了四个额外字段,disable_streaming字段作为“流式输出”的开关,除了设置布尔值之外,如果设置为“tool_calling” 意味着仅在工具调用时关闭流式输出。output_version字段可以控制AIMessage的内容格式版本(如v1提供更标准化的多模态内容块),它默认会采用环境变量“LC_OUTPUT_VERSION”的值。表示LLM类型的属性_llm_type被定义成抽象方法,具体的子类需要实现它来提供对应的LLM类型。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    rate_limiter: BaseRateLimiter | None = Field(default=None, exclude=True)
    disable_streaming: bool | Literal["tool_calling"] = False
    output_version: str | None = Field(
        default_factory=from_env("LC_OUTPUT_VERSION", default=None)
    )
    profile: ModelProfile | None = Field(default=None, exclude=True)

    @property
    @abstractmethod
    def _llm_type(self) -> str

利用rate_limiter字段设置的“限流器”可以在调用高频接口时防止触发服务商的限流阈值。该字段返回的类型BaseRateLimiter是限流器的基类。我们可以使用InMemoryRateLimiter这个默认实现,它采用经典且工业级的令牌桶限流算法。在调用构造函数创建InMemoryRateLimiter的时候,我们可以利用参数requests_per_secondmax_bucket_sizecheck_every_n_seconds指定生成令牌的速率、令牌桶的容量和轮询间隔。

class BaseRateLimiter(abc.ABC):
    @abc.abstractmethod
    def acquire(self, *, blocking: bool = True) -> bool
    @abc.abstractmethod
    async def aacquire(self, *, blocking: bool = True) -> bool

class InMemoryRateLimiter(BaseRateLimiter):
    def __init__(
        self,
        *,
        requests_per_second: float = 1,
        check_every_n_seconds: float = 0.1,
        max_bucket_size: float = 1,
    ) -> None

BaseChatModeprofile字段返回的ModelProfile对象可以作为描述模型的“画像”,利用它可以帮助框架在运行时自动判断模型能力。如下面的代码片段所示,ModelProfile提供了针对输入/输出的限制,针对工具调用特性的描述和结构化输出能力的指示。作为注册表的ModelProfileRegistry本质上就是一个将ModelProfile作为Value的字典。

class ModelProfile(TypedDict, total=False):
    # --- Input constraints ---
    max_input_tokens: int
    image_inputs: bool
    image_url_inputs: bool
    pdf_inputs: bool
    audio_inputs: bool
    video_inputs: bool
    image_tool_message: bool
    pdf_tool_message: bool

    # --- Output constraints ---
    max_output_tokens: int
    reasoning_output: bool
    image_outputs: bool
    audio_outputs: bool
    video_outputs: bool

    # --- Tool calling ---
    tool_calling: bool
    tool_choice: bool

    # --- Structured output ---
    structured_output: bool

ModelProfileRegistry = dict[str, ModelProfile]

3.2 两种调用方式

BaseChatModel重写了invoke/ainvokestream/astream方法。重写的invoke方法会调用generate_prompt方法,并从返回的LLMResult对象中提取第一个ChatGeneration对象,最后将其转换成返回的AIMessage。调用generate_prompt方法所需的参数均来源于提供的RunnableConfig配置。重写的ainvoke方法采用了类似的实现。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    @override
    def invoke(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> AIMessage:
        config = ensure_config(config)
        return cast(
            "AIMessage",
            cast(
                "ChatGeneration",
                self.generate_prompt(
                    [self._convert_input(input)],
                    stop=stop,
                    callbacks=config.get("callbacks"),
                    tags=config.get("tags"),
                    metadata=config.get("metadata"),
                    run_name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                    **kwargs,
                ).generations[0][0],
            ).message,
        )

    @override
    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> AIMessage:
        config = ensure_config(config)
        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.pop("run_id", None),
            **kwargs,
        )
        return cast(
            "AIMessage", cast("ChatGeneration", llm_result.generations[0][0]).message
        )

BaseChatModel将流式处理的核心实现下放给它的子类,具体体现在它的_stream方法上。该方法返回一个针对ChatGenerationChunk迭代器,支持流式输出的子类需要重写这个方法。stream/atream方法最终会调用此方法,并将得到的ChatGenerationChunk转换成AIMessageChunk。除此之外,注册到RunnableConfig中的很多回调也会在此方法中调用。值得一提的是,如果不支持流式输出,这两个方法也不会抛出异常,它们会使用invoke/avinoke方法作为兜底。利用rate_limiter字段设置的限流器也在这里被应用。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    @override
    def stream(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> Iterator[AIMessageChunk]:
    @override
    async def astream(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[AIMessageChunk]

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError

3.3 工具绑定与结构化输出

BaseChatModelbind_tools方法是让对话模型具备“工具调用” 能力的核心方法。它的本质是将可用的工具“注册”给模型,让模型在需要时决定调用哪个工具。 绑定的工具可以采用多种表示,可以是字典、Pydantic模型类、可执行对象(比如函数)和一个BaseTool对象。tool_choice参数体现了针对目标工具的选择策略,具体的子类一般会有针对性的定义。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    def bind_tools(
        self,
        tools: Sequence[builtins.dict[str, Any] | type | Callable | BaseTool],
        *,
        tool_choice: str | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        raise NotImplementedError

    def with_structured_output(
        self,
        schema: builtins.dict[str, Any] | type,
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, builtins.dict[str, Any] | BaseModel]

BaseChatModel重写了用于实现“结构化输出”的with_structured_output,它会根据schema参数提供的类型先择相应的JsonOutputToolsParser将原始输出转换成Schema对应的对象。最终返回的Runnable是当前对象和这个JsonOutputToolsParser组成的LCEL链。

3.4 执行流程

作为一个Runnable对象,其核心功能体现在它实现的invoke/ainvokestream/astream两组核心方法中。我们接下来大致介绍一下BaseChatModelinvoke方法的执行流程。如下面的代码片段所示,重写的invoke方法会调用私有方法_convert_input将输入统一转换成PromptValue对象。具体的转换规则很明确:如果传入的是字符串,则转换成一个StringPromptValue;如果是一个序列,则转换成ChatPromptValue对象。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    @override
    def invoke(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> AIMessage:
        config = ensure_config(config)
        return cast(
            "AIMessage",
            cast(
                "ChatGeneration",
                self.generate_prompt(
                    [self._convert_input(input)],
                    stop=stop,
                    callbacks=config.get("callbacks"),
                    tags=config.get("tags"),
                    metadata=config.get("metadata"),
                    run_name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                    **kwargs,
                ).generations[0][0],
            ).message,
        )

    def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
        if isinstance(model_input, PromptValue):
            return model_input
        if isinstance(model_input, str):
            return StringPromptValue(text=model_input)
        if isinstance(model_input, Sequence):
            return ChatPromptValue(messages=convert_to_messages(model_input))
        msg = (
            f"Invalid input type {type(model_input)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )
        raise ValueError(msg)

    @override
    def generate_prompt(
        self,
        prompts: list[PromptValue],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

得到的PromptValue被作为参数调用重写的generate_prompt方法,返回的LLMResult对象的第一个ChatGeneration被用于生成返回的AIMessage对象。至于重写的generate_prompt方法,它会调用每个PromptValueto_messages方法将其转换成消息列表,并将此列表和其他参数作为输入调用generate方法。

这个generate方法才是BaseChatModel最核心的部分,它“批量”处理传入的消息。除了作为核心输入的消息列表外,我们还可以传入一系列额外的参数,包括作为“终止词”的stop参数,当LLM在进行文本生成过程中,遇到它就会立即终止。callbacks参数提供回调以钩子的形式参与到整个处理流程,tagsmetadata参数提供的标签和元数据会被写入输出的跟踪记录中。除此之外,我们还可以利用run_namerun_id参数为当前调用设置一个具有可读性的名称和唯一标识。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    def generate(
        self,
        messages: list[list[BaseMessage]],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        *,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        run_name: str | None = None,
        run_id: uuid.UUID | None = None,
        **kwargs: Any,
    ) -> LLMResult

BaseChatModel会对生成的内容基于输入的提示词(字符串)实施缓存。输入的消息被转换成规范化的字符串形式的提示词后,会作为Key的一部分用于提取缓存的结果。只有在缓存项目不存在的情况下,才会实施后续的处理流程。这一操作实现在如下这个私有的_generate_with_cache中。generate方法仅仅遍历输入的每个消息,然后将单个消息作为输入调用此方法,该方法返回的ChatResult最终用来生成LLMResult对象。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    def _generate_with_cache(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    cache: BaseCache | bool | None = Field(default=None, exclude=True)

class BaseCache(ABC):
    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None:
    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:

    @abstractmethod
    def clear(self, **kwargs: Any) -> None
    async def alookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None
    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
    async def aclear(self, **kwargs: Any) -> None

class InMemoryCache(BaseCache)

通过抽象类BaseCache表示的缓存组件保存在BaseChatModelcache字段上。该抽象类定义了相应的抽象方法用于缓存条目的更新、读取和清除,InMemoryCache是我们常用的实现类型。值得一提的是,这两个类型是由langchain_core.caches这个包提供的,langgraph.cache.baselanggraph.cache.memory这两个包中也定义了同名的类型,不要混肴。作为“缓存键”的组成,除了作为提示词的字符串外,还包括区分LLM的标识,这很好理解,不同的LLM针对相同的提示词会生成不同的内容。

再回到_generate_with_cache方法,当它根据生成的规范化的提示词和目标LLM的标识确定对应的缓存条目存的情况下,它会直接返回缓存的内容。如果缓存内容不存在,在执行“流式输出”的情况下,它会调用_stream方法并得到一组ChatGenerationChunk对象,并将其转换成返回的ChatResult对象。否则会调用另一个_generate方法生成返回的ChatResult对象,ChatResultLLMResult具有类似的定义。

class ChatResult(BaseModel):
    generations: list[ChatGeneration]
    llm_output: dict | None = None

“非流式”内容生成的_generate被定义成抽象方法。由于“流式输出”并非“必需”的处理方式,所以_stream并非抽象方法,但是它会直接抛出NotImplementedError异常。最终针对LLM的调用和对结果的解析以这种方式下放给继承BaseChatModel的子类。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    @abstractmethod
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError

上面我们介绍了invoke方法大致的处理流程,异步版本的ainvoke的实现与之类似,只是它们调用异步版本的agenerate_prompt_agenerate_astream方法。由于这些异步方法最终还是转移到针对对应的同步方法上,子类可以通过重写它们实现真正意义上的异步。

class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
    @override
    async def agenerate_prompt(
        self,
        prompts: list[PromptValue],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult
    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐