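The LangChain documentation provides ChatParrotLink as an example of a custom chat model, but ChatParrotLink cannot work with MCP tools. Taking langchain_ollama.ChatOllama as a reference, this post extends ChatParrotLink so that it can. To keep things simple, streaming mode is left out for now.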

How to create a custom chat model class | 🦜️🔗 LangChain

1 MCP + LLM

1) MCP Prompt

Following the CSDN post "How @mcp.tool maps a function definition to LLM system input", the MCP tools are fed to the LLM as part of the system message.
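To make that mapping concrete, here is a simplified, stdlib-only sketch of how a typed Python function can be turned into a tool description. In this setup the work is really split in two (FastMCP derives an input schema from the signature and docstring, and `bind_tools` later converts the tool to the OpenAI format with `convert_to_openai_tool`); the function below collapses both into one step. `function_to_tool_schema` and `_JSON_TYPES` are hypothetical names, and only flat, basic parameter types are handled.

```python
import inspect
import json
from typing import Any, Callable, get_type_hints

# Simplified mapping from Python annotations to JSON Schema types (assumption).
_JSON_TYPES = {int: "integer", float: "number", str: "string", bool: "boolean"}


def function_to_tool_schema(fn: Callable[..., Any]) -> dict:
    """Describe `fn` as an OpenAI-style tool: name, docstring, parameter schema."""
    hints = get_type_hints(fn)
    properties: dict[str, dict] = {}
    required: list[str] = []
    for name, param in inspect.signature(fn).parameters.items():
        # Unannotated or unknown types fall back to "string".
        properties[name] = {"type": _JSON_TYPES.get(hints.get(name), "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)  # no default value, so the LLM must supply it
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn) or "",
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


def add(a: int, b: int) -> int:
    """Add two numbers"""
    return a + b


print(json.dumps(function_to_tool_schema(add), ensure_ascii=False))
```

The docstring becomes the description the LLM reads, which is why tool docstrings matter: they are the only hint the model gets about when to call the tool.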

First, write a system_prompt template that asks the LLM to pick suitable tool calls from the supplied mcp_tools, using tool_call_examples as the output format. An example:

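"""You are an expert in function composition. Given a question and a set of possible functions, make one or more function/tool calls, according to the intent of the question, to achieve the goal.
If the user's question cannot be answered by composing functions, answer it normally.
Otherwise:
    1) If no function is available, say so explicitly, then answer normally;
    2) If the question lacks parameters a function requires, point that out too, then answer normally;
    3) If you decide to call any function, return only the function calls in the tool-call part and include no other text in the response. You must strictly follow this format: ```json
{tool_call_examples}
```.


The functions you can call are listed below in JSON format:
```json
{mcp_tools}
```
"""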
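A small sketch of filling such a template. The template here is a shortened stand-in (an assumption, not the full prompt above), and the example call and tool are made up for illustration. It shows that `str.format` inserts the JSON text verbatim: braces inside the substituted values are not re-parsed as placeholders.

```python
import json

# Shortened stand-in for the system prompt template (assumption).
TEMPLATE = (
    "If you decide to call functions, reply only with:\n"
    "```json\n{tool_call_examples}\n```\n"
    "Available functions:\n```json\n{mcp_tools}\n```"
)

examples = [{"function": {"name": "get_weather", "arguments": {"city": "Seattle"}}}]
tools = [{"type": "function", "function": {"name": "add", "description": "Add two numbers"}}]

# The JSON strings are substituted as-is; their braces are not treated as fields.
system_prompt = TEMPLATE.format(
    tool_call_examples=json.dumps(examples, ensure_ascii=False),
    mcp_tools=json.dumps(tools, ensure_ascii=False),
)
print(system_prompt)
```

If you instead paste literal JSON into the template string itself, its braces must be doubled (`{{` and `}}`), otherwise `str.format` treats them as placeholders.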

2) Message role conversion

Following ChatOllama, ChatParrotLink calls the LLM through _generate -> _chat_stream_with_aggregation.

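Before the LLM is called, LangChain-format messages have to be converted into the target LLM's message format. To keep the analysis simple, tool results (ToolMessage) are converted into user messages: they are produced on the user's side, so treating them as user turns does not affect correctness. A tool call itself is not a separate message; LangChain stores it on the AIMessage (whose content here is the ```json block the model emitted), so it keeps the assistant role. The conversion looks like this: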

        role_msgs = []
        for msg in messages:
            role = None
            if isinstance(msg, HumanMessage):
                role = "user"
            elif isinstance(msg, SystemMessage):
                role = "system"
            elif isinstance(msg, AIMessage):
                # AI messages that carry tool calls keep the assistant role
                role = "assistant"
            elif isinstance(msg, ToolMessage):
                # Tool results are fed back to the LLM as user turns
                role = "user"
            elif isinstance(msg, FunctionMessage):
                role = "function"
            if role:
                role_msgs.append({"role": role, "content": f"{msg.content}"})
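The if/elif chain can also be written as a lookup table. The sketch below assumes messages expose a `.content` attribute and looks roles up by class name, walking the MRO so that subclasses such as AIMessageChunk resolve to AIMessage; keying on names lets it run without langchain_core installed. `ROLE_BY_CLASS` and `to_role_messages` are hypothetical names, and unlike the snippet above, an unknown message type raises an error instead of being dropped silently.

```python
# Role for each LangChain message class, keyed by class name (assumption:
# every message object has a `.content` attribute).
ROLE_BY_CLASS = {
    "HumanMessage": "user",
    "SystemMessage": "system",
    "AIMessage": "assistant",
    "ToolMessage": "user",  # tool results are fed back as user turns
    "FunctionMessage": "function",
}


def to_role_messages(messages) -> list[dict]:
    role_msgs = []
    for msg in messages:
        # Walk the MRO so subclasses (e.g. AIMessageChunk) find their base entry.
        role = next(
            (ROLE_BY_CLASS[c.__name__] for c in type(msg).__mro__ if c.__name__ in ROLE_BY_CLASS),
            None,
        )
        if role is None:
            raise ValueError(f"Unsupported message type: {type(msg).__name__}")
        role_msgs.append({"role": role, "content": str(msg.content)})
    return role_msgs
```

Failing loudly makes it obvious when a new message type shows up in the conversation, rather than the LLM silently missing part of the context.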

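3) MCP tool integration

Before actually calling the LLM, the MCP tool prompt also has to be added to the system message; only then can the LLM see the tools defined via MCP. For example: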

system_content = sys_template.format(json.dumps(out_ex, ensure_ascii=False), json.dumps(mcp_tools, ensure_ascii=False))
msgs = [{'role': 'system', 'content': system_content}] + messages
content = llm_impl.main_run(msgs)
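One detail worth handling: if the conversation passed in already begins with a system message (for instance an agent-level system prompt), prepending the tool prompt produces two system messages, which not every chat API handles well (an assumption worth checking against your backend). A minimal sketch that merges them instead; `with_tool_system_prompt` is a hypothetical helper, and `sys_template` is assumed to have two positional `{}` fields like the one in the full listing.

```python
import json


def with_tool_system_prompt(sys_template: str, examples: list, tools: list, messages: list[dict]) -> list[dict]:
    """Put the rendered tool prompt into a single leading system message.

    If `messages` already starts with a system message, its text is kept in
    front of the tool prompt instead of sending two system messages.
    """
    system_content = sys_template.format(
        json.dumps(examples, ensure_ascii=False),
        json.dumps(tools, ensure_ascii=False),
    )
    rest = list(messages)  # copy so the caller's list is not modified
    if rest and rest[0].get("role") == "system":
        system_content = rest.pop(0)["content"] + "\n\n" + system_content
    return [{"role": "system", "content": system_content}] + rest


demo = with_tool_system_prompt(
    "Examples: {}\nTools: {}",
    [{"function": {"name": "add", "arguments": {"a": 1, "b": 2}}}],
    [{"type": "function", "function": {"name": "add", "description": "Add two numbers"}}],
    [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "1 + 2?"}],
)
print(demo[0]["content"])
```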

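The complete ChatParrotLink definition follows (the agent code in section 3 imports it as the module llm_chat_model; llm_impl.main_run is the author's own wrapper around the actual LLM API and is not shown):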

import ast
import json
import logging
from collections.abc import Iterator, Mapping, Sequence
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Union,
    cast,
)
from uuid import uuid4

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import Field


log = logging.getLogger(__name__)



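sys_template = """You are an expert in function composition. Given a question and a set of possible functions, make one or more function/tool calls, according to the intent of the question, to achieve the goal.
If the user's question cannot be answered by composing functions, answer it normally.
Otherwise:
    1) If no function is available, say so explicitly, then answer normally;
    2) If the question lacks parameters a function requires, point that out too, then answer normally;
    3) If you decide to call any function, return only the function calls in the tool-call part and include no other text in the response. You must strictly follow this format: ```json
{}
```.


The functions you can call are listed below in JSON format:
```json
{}
```
"""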

out_ex = [
    {"function": {"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}},
    {"function":  {"name": "get_weather", "arguments": {"city": "Seattle", "metric": "celsius"}}}
]

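def llm_call(messages, mcp_tools):
    """Prepend the tool prompt, call the LLM, and parse a ```json tool-call block."""
    system_content = sys_template.format(json.dumps(out_ex, ensure_ascii=False), json.dumps(mcp_tools, ensure_ascii=False))
    msgs = [{'role': 'system', 'content': system_content}] + messages

    log.debug("msgs: %s", msgs)

    # llm call: llm_impl is the author's own wrapper around the LLM API (not shown)
    import llm_impl
    content = llm_impl.main_run(msgs)

    # Parse tool calls. str.lstrip/rstrip remove a *set of characters*, not a
    # prefix/suffix, so slice out the text between "```json" and the next fence.
    tool_calls = None
    start = content.find("```json")
    if start != -1:
        body_start = start + len("```json")
        end = content.find("```", body_start)
        call_body = content[body_start:end if end != -1 else len(content)].strip()
        try:
            tool_calls = json.loads(call_body)
        except json.JSONDecodeError:
            log.warning("Could not parse tool calls: %s", call_body)
        if isinstance(tool_calls, dict):
            tool_calls = [tool_calls]  # a single call without the surrounding list
    return {
        "message": {
            "content": content,
            "tool_calls": tool_calls
        }
    }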


def _get_usage_metadata_from_generation_info(
    generation_info: Optional[Mapping[str, Any]],
) -> Optional[UsageMetadata]:
    """Get usage metadata from ollama generation info mapping."""
    if generation_info is None:
        return None
    input_tokens: Optional[int] = generation_info.get("prompt_eval_count")
    output_tokens: Optional[int] = generation_info.get("eval_count")
    if input_tokens is not None and output_tokens is not None:
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )
    return None


def _parse_json_string(
    json_string: str,
    *,
    raw_tool_call: dict[str, Any],
    skip: bool,
) -> Any:
    """Attempt to parse a JSON string for tool calling.

    It first tries to use the standard json.loads. If that fails, it falls
    back to ast.literal_eval to safely parse Python literals, which is more
    robust against models using single quotes or containing apostrophes.

    Args:
        json_string: JSON string to parse.
        raw_tool_call: Raw tool call to include in error message.
        skip: Whether to ignore parsing errors and return the value anyways.

    Returns:
        The parsed JSON string or Python literal.

    Raises:
        OutputParserException: If the string is invalid and skip=False.
    """
    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        try:
            # Use ast.literal_eval to safely parse Python-style dicts
            # (e.g. with single quotes)
            return ast.literal_eval(json_string)
        except (SyntaxError, ValueError) as e:
            # If both fail, and we're not skipping, raise an informative error.
            if skip:
                return json_string
            msg = (
                f"Function {raw_tool_call['function']['name']} arguments:\n\n"
                f"{raw_tool_call['function']['arguments']}"
                "\n\nare not valid JSON or a Python literal. "
                f"Received error: {e}"
            )
            raise OutputParserException(msg) from e
    except TypeError as e:
        if skip:
            return json_string
        msg = (
            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
            f"{raw_tool_call['function']['arguments']}\n\nare not a string or a "
            f"dictionary. Received TypeError {e}"
        )
        raise OutputParserException(msg) from e



def _parse_arguments_from_tool_call(
    raw_tool_call: dict[str, Any],
) -> Optional[dict[str, Any]]:
    """Parse arguments by trying to parse any shallowly nested string-encoded JSON.

    Band-aid fix for issue in Ollama with inconsistent tool call argument structure.
    Should be removed/changed if fixed upstream.
    See https://github.com/ollama/ollama/issues/6155
    """
    if "function" not in raw_tool_call:
        return None
    arguments = raw_tool_call["function"]["arguments"]
    parsed_arguments: dict = {}
    if isinstance(arguments, dict):
        for key, value in arguments.items():
            if isinstance(value, str):
                parsed_value = _parse_json_string(
                    value, skip=True, raw_tool_call=raw_tool_call
                )
                if isinstance(parsed_value, (dict, list)):
                    parsed_arguments[key] = parsed_value
                else:
                    parsed_arguments[key] = value
            else:
                parsed_arguments[key] = value
    else:
        parsed_arguments = _parse_json_string(
            arguments, skip=False, raw_tool_call=raw_tool_call
        )
    return parsed_arguments


def _get_tool_calls_from_response(
    response: Mapping[str, Any],
) -> list[ToolCall]:
    """Get tool calls from ollama response."""
    tool_calls = []
    if "message" in response and (
        raw_tool_calls := response["message"].get("tool_calls")
    ):
        tool_calls.extend(
            [
                tool_call(
                    id=str(uuid4()),
                    name=tc["function"]["name"],
                    args=_parse_arguments_from_tool_call(tc) or {},
                )
                for tc in raw_tool_calls
            ]
        )
    return tool_calls


def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict:
    """Convert a LangChain tool call to an OpenAI tool call format."""
    return {
        "type": "function",
        "id": tool_call_["id"],
        "function": {
            "name": tool_call_["name"],
            "arguments": tool_call_["args"],
        },
    }


def _get_image_from_data_content_block(block: dict) -> str:
    """Format standard data content block to format expected by Ollama."""
    if block["type"] == "image":
        if block["source_type"] == "base64":
            return block["data"]
        error_message = "Image data only supported through in-line base64 format."
        raise ValueError(error_message)

    error_message = f"Blocks of type {block['type']} not supported."
    raise ValueError(error_message)


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


class ChatParrotLink(BaseChatModel):
    """A custom chat model extended from the LangChain ChatParrotLink example.

    _generate sends the conversation, together with any tools bound through
    bind_tools(), to the LLM via llm_call and turns a ```json tool-call block in
    the reply into AIMessage.tool_calls. _stream is still the original parrot
    example: calls routed through it (e.g. .stream()) only echo the first
    `parrot_buffer_length` characters of the last message.

    Example:

        .. code-block:: python

            model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                 [HumanMessage(content="world")]])
    """

    model_name: str = Field(alias="model")
    """The name of the model"""
    parrot_buffer_length: int
    """The number of characters from the last message of the prompt to be echoed."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 2
    reasoning: Optional[bool] = None

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Delegate to the non-streaming LLM call and wrap the result in a ChatResult.

        final_chunk = self._chat_stream_with_aggregation(
            messages, stop, run_manager, verbose=self.verbose, **kwargs
        )
        generation_info = final_chunk.generation_info
        chat_generation = ChatGeneration(
            message=AIMessage(
                content=final_chunk.text,
                usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
                tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
                additional_kwargs=final_chunk.message.additional_kwargs,
            ),
            generation_info=generation_info,
        )
        return ChatResult(generations=[chat_generation])


    def bind_tools(
        self,
        tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,  # noqa: PYI051
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: If provided, which tool for model to call. **This parameter
                is currently ignored as it is not supported by Ollama.**
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)


    def _chat_stream_with_aggregation(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,  # noqa: FBT001, FBT002
        **kwargs: Any,
    ) -> ChatGenerationChunk:
        role_msgs = []
        for msg in messages:
            role = None
            if isinstance(msg, HumanMessage):
                role = "user"
            elif isinstance(msg, SystemMessage):
                role = "system"
            elif isinstance(msg, AIMessage):
                # AI messages that carry tool calls keep the assistant role
                role = "assistant"
            elif isinstance(msg, ToolMessage):
                # Tool results are fed back to the LLM as user turns
                role = "user"
            elif isinstance(msg, FunctionMessage):
                role = "function"
            if role:
                role_msgs.append({"role": role, "content": f"{msg.content}"})
        # llm call (non-streaming); "tools" is only in kwargs after bind_tools()
        stream_resp = llm_call(role_msgs, kwargs.get("tools", []))
        content = stream_resp["message"]["content"]
        final_chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content=content,
                tool_calls=_get_tool_calls_from_response(stream_resp),
            ),
            generation_info=None,
        )
        return final_chunk


    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = str(last_message.content[: self.parrot_buffer_length])
        ct_input_tokens = sum(len(message.content) for message in messages)

        for token in tokens:
            usage_metadata = UsageMetadata(
                {
                    "input_tokens": ct_input_tokens,
                    "output_tokens": 1,
                    "total_tokens": ct_input_tokens + 1,
                }
            )
            ct_input_tokens = 0
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
            )

            if run_manager:
                # This is optional in newer versions of LangChain
                # The on_llm_new_token will be called automatically
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

        # Let's add some other information (e.g., response metadata)
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
            )
        )
        if run_manager:
            # This is optional in newer versions of LangChain
            # The on_llm_new_token will be called automatically
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes make it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": self.model_name,
        }
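Model output is not always as tidy as the transcript in section 3: the JSON may be preceded by reasoning text, emitted as a single object rather than a list, or carry its arguments as a JSON string. A more defensive, stdlib-only sketch of the parsing step follows. Assumptions: the reply uses the ```json format from the prompt, and a deepseek-r1 style `<think>…</think>` block may appear inline if the backend passes it through. `parse_tool_calls` is a hypothetical name; its output uses the same dict shape (name, args, id, type="tool_call") that langchain_core's `tool_call()` helper produces.

```python
import json
import re
import uuid

# First ```json ... ``` block anywhere in the reply (non-greedy, spans newlines).
_JSON_BLOCK = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)
# Inline reasoning block as emitted by deepseek-r1 (assumption).
_THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)


def parse_tool_calls(content: str) -> list[dict]:
    """Return LangChain-style tool-call dicts found in a reply, or [] if none."""
    text = _THINK_BLOCK.sub("", content)
    match = _JSON_BLOCK.search(text)
    if not match:
        return []
    try:
        raw = json.loads(match.group(1))
    except json.JSONDecodeError:
        return []  # malformed JSON: treat the reply as a plain answer
    if isinstance(raw, dict):
        raw = [raw]  # a single call without the surrounding list
    if not isinstance(raw, list):
        return []
    calls = []
    for item in raw:
        fn = item.get("function") if isinstance(item, dict) else None
        if not isinstance(fn, dict) or "name" not in fn:
            continue  # skip entries that do not follow the prompt's format
        args = fn.get("arguments", {})
        if isinstance(args, str):  # arguments encoded as a JSON string
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                args = {}
        if not isinstance(args, dict):
            args = {}
        calls.append({"name": fn["name"], "args": args, "id": str(uuid.uuid4()), "type": "tool_call"})
    return calls
```

Returning an empty list on anything unparseable means the reply is simply treated as a final answer, which is the safer failure mode for an agent loop.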


2 MCP tool definitions

Again following the CSDN post "How @mcp.tool maps a function definition to LLM system input", we define MCP tools such as check_child_study_situation and query_student_scores:


import os
import sys

from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Tom's tools")

@mcp.tool()
def check_child_study_situation(name: str) -> str:
    """Check a child's recent study situation"""
    db = {
        "小明": "学习很努力 from Michael阿明老师点评",
        "小红": "学习一般",
        "小刚": "学习不太好",
    }
    # With the stdio transport, stdout carries the MCP protocol, so log to stderr
    print(f"Checking study status for {name}", file=sys.stderr)
    return db.get(name, "没有找到这个小孩的学习记录")

@mcp.tool()
def query_student_scores(name: str) -> str:
    """Query a student's scores.

    Args:
        name: the student's name

    Returns:
        The student's scores, or a message saying none were found
    """
    file_path = os.path.join(os.path.dirname(__file__), "score_points.txt")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Parse the file: "name:" header lines followed by "subject: score" lines
        students = {}
        current_student = None

        for line in content.split("\n"):
            line = line.strip()
            if not line:
                continue

            if line.endswith(":"):  # student name
                current_student = line[:-1]  # drop the trailing colon
                students[current_student] = {}
            elif current_student and ":" in line:  # subject score
                subject, score = line.split(":", 1)
                students[current_student][subject.strip()] = score.strip()

        # Return the student's scores
        if name in students:
            result = f"{name}的成绩:\n"
            for subject, score in students[name].items():
                result += f"- {subject}: {score}\n"
            return result
        else:
            return f"没有找到{name}的成绩记录"
    except Exception as e:
        return f"查询成绩出错: {str(e)}"

@mcp.tool()
def add(a: int, b: int) -> int:
    """Add two numbers"""
    return a + b

@mcp.tool()
def multiply(a: int, b: int) -> int:
    """Multiply two numbers"""
    return a * b

if __name__ == "__main__":
    mcp.run(transport="stdio")
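The file score_points.txt is not shown; the layout below is inferred from the parser and from the transcript in section 3, so treat it as an assumption (if the real file uses the full-width colon "：", the checks would need to accept it as well). Pulling the parsing into a pure function, here the hypothetical `parse_scores`, makes it testable without starting the MCP server.

```python
def parse_scores(text: str) -> dict[str, dict[str, str]]:
    """Parse 'name:' header lines followed by 'subject: score' lines (assumed layout)."""
    students: dict[str, dict[str, str]] = {}
    current = None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.endswith(":"):  # header line: student name
            current = line[:-1]
            students[current] = {}
        elif current is not None and ":" in line:  # "subject: score" line
            subject, score = line.split(":", 1)  # split once so scores may contain ":"
            students[current][subject.strip()] = score.strip()
    return students


SAMPLE = """
小明:
语文: 80
数学: 95
"""
print(parse_scores(SAMPLE))
```

The tool function would then only read the file, call the parser, and format the result.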

3 Agent integration

Here the custom ChatParrotLink serves as the agent's chat model (a BaseChatModel), and the MCP functions defined above are the tools the agent can use:

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from langchain_mcp_adapters.tools import load_mcp_tools
from langgraph.prebuilt import create_react_agent
import asyncio


import llm_chat_model  # the module that defines ChatParrotLink (section 1)

model = llm_chat_model.ChatParrotLink(parrot_buffer_length=81920, model="my_custom_model")

server_params = StdioServerParameters(
    command="python",
    args=["server.py"],  # the MCP server from section 2
)

async def run_agent():
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Wrap the server's MCP tools as LangChain tools
            tools = await load_mcp_tools(session)
            # create_react_agent binds the tools through model.bind_tools()
            agent = create_react_agent(model, tools)
            agent_response = await agent.ainvoke({
                "messages": [
                    {"role": "user", "content": "小明最近学习成绩?"},
                ]
            })
            return agent_response

if __name__ == "__main__": 
    result = asyncio.run(run_agent())
    messages = result["messages"]
    print(len(messages))
    print(messages)

With this, the agent code runs. Asked "小明最近学习成绩?" (roughly, "How have Xiaoming's grades been lately?"), the agent responds with:

[
HumanMessage(content='小明最近学习成绩?', additional_kwargs={}, response_metadata={}, id='1603671f-4d40-4152-9dd1-1ec3a6e8b6be'),

AIMessage(content='\n\n```json\n[{"function": {"name": "query_student_scores", "arguments": {"name": "小明"}}}]\n```', additional_kwargs={}, response_metadata={}, id='run--a37af7af-385e-4563-9baf-5532e27e9220-0', tool_calls=[{'name': 'query_student_scores', 'args': {'name': '小明'}, 'id': '0ecaa97e-8f87-4277-bcad-800a36ad0d37', 'type': 'tool_call'}]),

ToolMessage(content='小明的成绩:\n- 语文: 80\n- 数学: 95\n', name='query_student_scores', id='6b9ee9f2-f3d4-47cc-8e24-a0a29c960827', tool_call_id='0ecaa97e-8f87-4277-bcad-800a36ad0d37'),

AIMessage(content='\n\n小明的成绩如下:\n- 语文:80分\n- 数学:95分\n\n数学成绩尤其出色,表现很不错!需要分析具体科目的学习情况吗?', additional_kwargs={}, response_metadata={}, id='run--ed4dff97-5f36-4439-bb6f-0ca5ed8afb41-0')
]

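Based on the context and the user's question, the agent picks a suitable MCP tool on its own: here it calls query_student_scores, the result comes back as a ToolMessage, and the model then writes the final answer.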
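The transcript follows the ReAct-style loop that create_react_agent drives: the model replies with tool calls, the tools run, their results are appended as tool messages, and the model is called again until it answers without tool calls. The sketch below is a framework-free illustration of that control flow only, under simplifying assumptions (synchronous calls, tools looked up by name, no error handling); it is not LangGraph's implementation, and `run_react_loop` is a hypothetical name.

```python
from typing import Any, Callable


def run_react_loop(
    call_model: Callable[[list[dict]], dict],
    tools: dict[str, Callable[..., Any]],
    question: str,
    max_steps: int = 5,
) -> list[dict]:
    """Simplified agent loop: model -> tool calls -> tool results -> model."""
    messages: list[dict] = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        reply = call_model(messages)
        tool_calls = reply.get("tool_calls") or []
        messages.append({"role": "assistant", "content": reply["content"], "tool_calls": tool_calls})
        if not tool_calls:
            return messages  # no tool requested: this reply is the final answer
        for call in tool_calls:
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "tool_call_id": call["id"], "content": str(result)})
    return messages  # step budget exhausted


# Demo with a scripted stand-in for the model (no real LLM involved).
scripted_replies = iter([
    {"content": "", "tool_calls": [{"name": "add", "args": {"a": 2, "b": 3}, "id": "call-1"}]},
    {"content": "2 + 3 = 5"},
])
history = run_react_loop(lambda msgs: next(scripted_replies), {"add": lambda a, b: a + b}, "What is 2 + 3?")
for m in history:
    print(m["role"], repr(m["content"]))
```

The user, assistant (tool call), tool, assistant (answer) sequence mirrors the HumanMessage, AIMessage, ToolMessage, AIMessage sequence in the transcript above.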

Here, ChatParrotLink ultimately calls deepseek-r1 (through llm_impl.main_run).

References

---

How to create a custom chat model class

https://python.langchain.ac.cn/docs/how_to/custom_chat_model/

langchain_ollama/chat_models.py

https://github.com/langchain-ai/langchain/blob/master/libs/partners/ollama/langchain_ollama/chat_models.py

A complete guide to tool calling in LLMs: the key to unlocking the potential of large models

https://zhuanlan.zhihu.com/p/1909186977066128479

Tool-calling principles: from reasoning to training

https://csrayz.github.io/post/tool-calling-principle-from-reasoning-to-training.html

Demystifying function calling: how LLMs call tools under the hood, plus four optimizations to improve agent performance

https://blog.csdn.net/fufan_LLM/article/details/147234519

How to build a custom LangChain chat model on BaseChatModel

https://blog.csdn.net/liliang199/article/details/150399931
