How to build a custom LangChain chat model on top of BaseChatModel
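When an existing LLM is used in a LangChain + MCP setup, you keep running into AK (access key) compatibility problems, and the calling code is hard to modify.
The approach here is to wrap the existing LLM in a subclass of LangChain's base class BaseChatModel. LangChain then sees only the standard chat-model interface, while the way the underlying model is called (including authentication) stays in your own code. This removes those problems and leaves the wrapper highly customizable.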
1 Key points of a custom model
1) Built-in message types
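Message type | Description |
---|---|
SystemMessage | Primes the AI's behavior; usually passed as the first message in the sequence. |
HumanMessage | A message from the user chatting with the model. |
AIMessage | A message from the model; either text or a tool-call request. |
FunctionMessage / ToolMessage | Passes a tool call's result back to the model (FunctionMessage is the legacy OpenAI-functions form). |
AIMessageChunk / HumanMessageChunk / ... | The chunked variant of each message type. |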
```python
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
```
2) Streaming variants
Every chat message type has a streaming variant whose name ends in Chunk.
```python
from langchain_core.messages import (
    AIMessageChunk,
    FunctionMessageChunk,
    HumanMessageChunk,
    SystemMessageChunk,
    ToolMessageChunk,
)
```
These chunks are used when a chat model streams its output, and they all support addition (+):
```python
AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
```
AIMessageChunk(content='Hello World!')
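2 Creating the chat model
We implement a chat model that echoes back the first n characters of the last message in the prompt.
It subclasses BaseChatModel and implements the following: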
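Method/attribute | Description | Required/optional |
---|---|---|
_generate | Generates a chat result from the prompt. | Required |
_llm_type (property) | Uniquely identifies the model type; used for logging. | Required |
_identifying_params (property) | Describes the model's parameterization for tracing. | Optional |
_stream | Implements streaming. | Optional |
_agenerate | Implements a native async version of _generate. | Optional |
_astream | Implements an async version of _stream. | Optional |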
Chat model code example
```python
from collections.abc import Iterator, Sequence
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field


class ChatParrotLink(BaseChatModel):
    """A custom chat model that echoes the first `parrot_buffer_length` characters
    of the input.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying models documentation or API.

    Example:

        .. code-block:: python

            model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    model_name: str = Field(alias="model")
    """The name of the model"""
    parrot_buffer_length: int
    """The number of characters from the last message of the prompt to be echoed."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 2

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Replace this with actual logic to generate a response from a list
        # of messages.
        last_message = messages[-1]
        tokens = last_message.content[: self.parrot_buffer_length]
        ct_input_tokens = sum(len(message.content) for message in messages)
        ct_output_tokens = len(tokens)
        message = AIMessage(
            content=tokens,
            additional_kwargs={},  # Used to add additional payload to the message
            response_metadata={  # Use for response metadata
                "time_in_seconds": 3,
                "model_name": self.model_name,
            },
            usage_metadata={
                "input_tokens": ct_input_tokens,
                "output_tokens": ct_output_tokens,
                "total_tokens": ct_input_tokens + ct_output_tokens,
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def bind_tools(
        self,
        tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,  # noqa: PYI051
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: If provided, which tool for model to call. **This parameter
                is currently ignored by this model.**
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        # The bound tools reach _generate/_stream as kwargs["tools"].
        return super().bind(tools=formatted_tools, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = str(last_message.content[: self.parrot_buffer_length])
        ct_input_tokens = sum(len(message.content) for message in messages)

        for token in tokens:
            usage_metadata = UsageMetadata(
                {
                    "input_tokens": ct_input_tokens,
                    "output_tokens": 1,
                    "total_tokens": ct_input_tokens + 1,
                }
            )
            # Report input tokens only once, on the first chunk.
            ct_input_tokens = 0
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
            )

            if run_manager:
                # This is optional in newer versions of LangChain
                # The on_llm_new_token will be called automatically
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

        # Let's add some other information (e.g., response metadata)
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
            )
        )
        if run_manager:
            # This is optional in newer versions of LangChain
            # The on_llm_new_token will be called automatically
            # Pass "" here: `token` is undefined if the prompt was empty.
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes and makes it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": self.model_name,
        }
```
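3 Testing the chat model
1) Model example
```python
model = ChatParrotLink(parrot_buffer_length=3, model="my_custom_model")

model.invoke(
    [
        HumanMessage(content="hello!"),
        AIMessage(content="Hi there human!"),
        HumanMessage(content="Meow!"),
    ]
)
```
Output (input_tokens counts characters across all three messages: 6 + 15 + 5 = 26):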
AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3, 'model_name': 'my_custom_model'}, id='run--9b2fe63a-a109-4286-9c7c-d6d49e645e3a-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})
2) invoke
```python
model.invoke("hello")
```
Output (a plain string is treated as a single HumanMessage):
AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3, 'model_name': 'my_custom_model'}, id='run--bf910f1d-37dc-4133-a6fd-7d6fe3e0a5e2-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})
3) batch
```python
model.batch(["hello", "goodbye"])
```
Output:
[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3, 'model_name': 'my_custom_model'}, id='run--84be948e-6c7d-4018-8485-e9edf71c791b-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),
AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3, 'model_name': 'my_custom_model'}, id='run--1a1f5df3-a3cb-4dd5-85fb-2592a9dedd5c-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]
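4) stream
```python
for chunk in model.stream("cat"):
    print(chunk.content, end="|")
```
Output (the trailing empty segment is the final chunk, which carries only response_metadata):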
c|a|t||
5) astream (async for must run inside a coroutine or a notebook cell)
```python
async for chunk in model.astream("cat"):
    print(chunk.content, end="|")
```
Output:
c|a|t||
6) astream_events
```python
async for event in model.astream_events("cat", version="v1"):
    print(event)
```
Output:
{'event': 'on_chat_model_start', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run--502e4eeb-9755-4366-85ef-ee660cb0993b', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run--502e4eeb-9755-4366-85ef-ee660cb0993b', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run--502e4eeb-9755-4366-85ef-ee660cb0993b', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3, 'model_name': 'my_custom_model'}, id='run--502e4eeb-9755-4366-85ef-ee660cb0993b')}, 'parent_ids': []}
{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '502e4eeb-9755-4366-85ef-ee660cb0993b', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3, 'model_name': 'my_custom_model'}, id='run--502e4eeb-9755-4366-85ef-ee660cb0993b', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}
References
---
How to create a custom chat model class
https://python.langchain.ac.cn/docs/how_to/custom_chat_model/
ChatOllama definition (langchain-ollama)
https://github.com/langchain-ai/langchain/tree/master/libs/partners/ollama
langchain-ollama 0.3.6 release
https://github.com/langchain-ai/langchain/releases/tag/langchain-ollama%3D%3D0.3.6