Preface

The previous two articles covered the basic execution flow of the OpenManus agent and the implementation of the four classes directly involved in that flow. Beyond those, a number of utility classes also participate in the flow, providing essential capabilities such as LLM access, memory, conversation messages, and tools. This article analyzes the implementation of these utility classes in some detail.

LLM

The LLM class encapsulates the logic for interacting with the large language model. It provides three methods, ask, ask_with_images, and ask_tool, each of which interacts with the model in a different way.

Key Methods

ask

The ask method interacts with the model through the OpenAI SDK. Simplified code below:

# The retry decorator implements an automatic retry mechanism
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type(
        (OpenAIError, Exception, ValueError)
    ),  # Don't retry TokenLimitExceeded
)
async def ask(
    self,
    messages: List[Union[dict, Message]],
    system_msgs: Optional[List[Union[dict, Message]]] = None,
    stream: bool = True,
    temperature: Optional[float] = None,
) -> str:
    """
    Send a prompt to the LLM and get the response.

    Args:
        messages: List of conversation messages
        system_msgs: Optional system messages to prepend
        stream (bool): Whether to stream the response
        temperature (float): Sampling temperature for the response

    Returns:
        str: The generated response

    Raises:
        TokenLimitExceeded: If token limits are exceeded
        ValueError: If messages are invalid or response is empty
        OpenAIError: If API call fails after retries
        Exception: For unexpected errors
    """
    try:
        # Check whether the model supports image input
        supports_images = self.model in MULTIMODAL_MODELS

        # Format the messages into the structure expected by multimodal models
        if system_msgs:
            system_msgs = self.format_messages(system_msgs, supports_images)
            messages = system_msgs + self.format_messages(messages, supports_images)
        else:
            messages = self.format_messages(messages, supports_images)

        # Count the input tokens
        input_tokens = self.count_message_tokens(messages)

        # Check whether the input tokens exceed the limit
        if not self.check_token_limit(input_tokens):
            error_message = self.get_limit_error_message(input_tokens)
            raise TokenLimitExceeded(error_message)

        params = {
            "model": self.model,
            "messages": messages,
        }

        if self.model in REASONING_MODELS:
            params["max_completion_tokens"] = self.max_tokens
        else:
            params["max_tokens"] = self.max_tokens
            params["temperature"] = (
                temperature if temperature is not None else self.temperature
            )

        # Non-streaming call
        if not stream:
            response = await self.client.chat.completions.create(
                **params, stream=False
            )

            if not response.choices or not response.choices[0].message.content:
                raise ValueError("Empty or invalid response from LLM")

            # Update the input and output token counts
            self.update_token_count(
                response.usage.prompt_tokens, response.usage.completion_tokens
            )

            return response.choices[0].message.content

        # Streaming call
        self.update_token_count(input_tokens)

        response = await self.client.chat.completions.create(**params, stream=True)

        collected_messages = []
        completion_text = ""
        async for chunk in response:
            chunk_message = chunk.choices[0].delta.content or ""
            collected_messages.append(chunk_message)
            completion_text += chunk_message
            print(chunk_message, end="", flush=True)

        print()  # Newline after streaming
        # Join the collected chunks into the complete response
        full_response = "".join(collected_messages).strip()
        if not full_response:
            raise ValueError("Empty response from streaming LLM")

        # Count the output tokens
        completion_tokens = self.count_tokens(completion_text)
        logger.info(
            f"Estimated completion tokens for streaming response: {completion_tokens}"
        )
        self.total_completion_tokens += completion_tokens

        return full_response

    # Error handling omitted
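
For reference, here is a minimal sketch of calling ask. The module paths (app.llm, app.schema) and the Message helpers (covered later in this article) are assumptions based on the repository layout, not verbatim OpenManus code:

import asyncio

from app.llm import LLM
from app.schema import Message

async def main():
    llm = LLM()  # assumed to pick up the default model configuration
    reply = await llm.ask(
        messages=[Message.user_message("Summarize what OpenManus does.")],
        system_msgs=[Message.system_message("You are a helpful assistant.")],
        stream=False,
    )
    print(reply)

asyncio.run(main())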

ask_tool

ask 方法的基础上,支持 Function Call,简化后代码如下:

# The retry decorator implements an automatic retry mechanism
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type(
        (OpenAIError, Exception, ValueError)
    ),  # Don't retry TokenLimitExceeded
)
async def ask_tool(
    self,
    messages: List[Union[dict, Message]],
    system_msgs: Optional[List[Union[dict, Message]]] = None,
    timeout: int = 300,
    tools: Optional[List[dict]] = None,
    tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
    temperature: Optional[float] = None,
    **kwargs,
) -> ChatCompletionMessage | None:
    """
    Ask LLM using functions/tools and return the response.

    Args:
        messages: List of conversation messages
        system_msgs: Optional system messages to prepend
        timeout: Request timeout in seconds
        tools: List of tools to use
        tool_choice: Tool choice strategy
        temperature: Sampling temperature for the response
        **kwargs: Additional completion arguments

    Returns:
        ChatCompletionMessage: The model's response

    Raises:
        TokenLimitExceeded: If token limits are exceeded
        ValueError: If tools, tool_choice, or messages are invalid
        OpenAIError: If API call fails after retries
        Exception: For unexpected errors
    """
    try:
        # Validate the tool_choice argument; it must be one of none, auto, or required
        if tool_choice not in TOOL_CHOICE_VALUES:
            raise ValueError(f"Invalid tool_choice: {tool_choice}")

        # Check whether the model supports image input
        supports_images = self.model in MULTIMODAL_MODELS

        # Format the messages into the structure expected by multimodal models
        if system_msgs:
            system_msgs = self.format_messages(system_msgs, supports_images)
            messages = system_msgs + self.format_messages(messages, supports_images)
        else:
            messages = self.format_messages(messages, supports_images)

        # Count the input tokens
        input_tokens = self.count_message_tokens(messages)

        # Count the tokens consumed by the tool descriptions
        tools_tokens = 0
        if tools:
            for tool in tools:
                tools_tokens += self.count_tokens(str(tool))

        input_tokens += tools_tokens

        # Check whether the input tokens exceed the limit
        if not self.check_token_limit(input_tokens):
            error_message = self.get_limit_error_message(input_tokens)
            raise TokenLimitExceeded(error_message)

        # Validate the tools
        if tools:
            for tool in tools:
                if not isinstance(tool, dict) or "type" not in tool:
                    raise ValueError("Each tool must be a dict with 'type' field")

        params = {
            "model": self.model,
            "messages": messages,
            "tools": tools,
            "tool_choice": tool_choice,
            "timeout": timeout,
            **kwargs,
        }

        if self.model in REASONING_MODELS:
            params["max_completion_tokens"] = self.max_tokens
        else:
            params["max_tokens"] = self.max_tokens
            params["temperature"] = (
                temperature if temperature is not None else self.temperature
            )

        params["stream"] = False  # Always use non-streaming for tool requests
        response: ChatCompletion = await self.client.chat.completions.create(
            **params
        )

        if not response.choices or not response.choices[0].message:
            print(response)
            return None

        # Update the input and output token counts
        self.update_token_count(
            response.usage.prompt_tokens, response.usage.completion_tokens
        )

        return response.choices[0].message

    # Error handling omitted
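
A minimal sketch of an ask_tool call follows. The tool dict matches the to_param format described later; ToolChoice and the module paths are assumptions based on the repository layout:

import asyncio

from app.llm import LLM
from app.schema import Message, ToolChoice

async def main():
    llm = LLM()
    # A hypothetical weather tool, described in OpenAI Function Call format
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    msg = await llm.ask_tool(
        messages=[Message.user_message("What's the weather in Beijing?")],
        tools=[weather_tool],
        tool_choice=ToolChoice.AUTO,
    )
    # If the model decided to call the tool, the calls are in msg.tool_calls
    print(msg.tool_calls)

asyncio.run(main())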

Memory

The Memory class inherits from BaseModel and encapsulates the memory-related logic. BaseModel is the base class provided by the Pydantic library, offering data validation, field management, serialization, and more.

Code

The code as a whole is very simple:

class Memory(BaseModel):
    messages: List[Message] = Field(default_factory=list)  # Stores the conversation messages
    max_messages: int = Field(default=100)  # Maximum number of messages to keep

    def add_message(self, message: Message) -> None:
        """Add a message to memory"""
        self.messages.append(message)
        # Optional: Implement message limit
        # If the stored messages exceed the limit, keep only the most recent ones
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages :]

    def add_messages(self, messages: List[Message]) -> None:
        """Add multiple messages to memory"""
        self.messages.extend(messages)
        # Optional: Implement message limit
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages :]

    def clear(self) -> None:
        """Clear all messages"""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Get n most recent messages"""
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert messages to list of dicts"""
        return [msg.to_dict() for msg in self.messages]

The memory-trimming logic is as simple as it gets: it just checks whether the number of stored messages exceeds the limit and, if so, keeps only the most recent max_messages entries (100 by default).
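
A quick demonstration of the trimming behavior (a sketch; the app.schema import path is an assumption):

from app.schema import Memory, Message

memory = Memory(max_messages=3)
for i in range(5):
    memory.add_message(Message.user_message(f"message {i}"))

# Only the 3 most recent messages survive: "message 2" through "message 4"
print([m.content for m in memory.get_recent_messages(3)])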

Message

In OpenManus, messages exchanged with the model are usually of type Message. The Message class inherits from BaseModel and encapsulates the message-related logic.

Key Attributes

  1. role: the role, one of system, user, assistant, or tool.
  2. content: the message content, a string.

Key Methods

to_dict

Converts the current message object into a dict with the following shape:

{
    "role": str,
    "content": Optional[str],
    "tool_calls": Optional[List[dict]],
    "name": Optional[str],
    "tool_call_id": Optional[str],
    "base64_image": Optional[str]
}
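
For example, a tool message would serialize like this (None-valued fields are omitted, as the to_dict source below shows):

{
    "role": "tool",
    "content": "Execution result",
    "name": "calculator",
    "tool_call_id": "call_123"
}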

Creating Messages

Message provides 5 class methods for creating different kinds of message objects. They differ only in which fields are set when instantiating the Message; each returns a new Message object:

  1. user_message: creates a user-role message.
  2. system_message: creates a system-role message.
  3. assistant_message: creates an assistant (AI) role message.
  4. tool_message: creates a tool-role message.
  5. from_tool_calls: creates an assistant message from raw tool calls.

Example:

# Create different kinds of messages
user_msg = Message.user_message("Hello, please help me analyze this image", base64_image="...")
system_msg = Message.system_message("You are a helpful assistant")
assistant_msg = Message.assistant_message("Sure, let me analyze it for you")
tool_msg = Message.tool_message("Execution result", "calculator", "call_123")

Source

class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str
    type: str = "function"
    function: Function

class Message(BaseModel):
    """Represents a chat message in the conversation"""

    role: ROLE_TYPE = Field(...)  # type: ignore
    content: Optional[str] = Field(default=None)
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """支持 Message + list 或 Message + Message 的操作"""
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """支持 list + Message 的操作"""
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert message to dictionary format"""
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message"""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a tool message"""
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ):
        """Create ToolCallsMessage from raw tool calls.

        Args:
            tool_calls: Raw tool calls from LLM
            content: Optional message content
            base64_image: Optional base64 encoded image
        """
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
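
The __add__ and __radd__ overloads above make it convenient to splice a single Message into a message list; a small sketch:

history = [Message.system_message("You are a helpful assistant")]
# list + Message goes through __radd__; Message + list goes through __add__
history = history + Message.user_message("Hi there")
turn = Message.assistant_message("Hello!") + [Message.user_message("Please continue")]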

ToolCollection

The ToolCollection class maintains a set of tools and provides an execute method for invoking a given tool by name.

Key Attributes

  1. tools: a tuple of tools; every element is a BaseTool (see below).
  2. tool_map: a dict mapping tool names to tools.

Key Methods

  1. execute: executes the specified tool.
  2. execute_all: executes all tools in the collection in order.
  3. add_tool: adds a single tool to the collection; a tool whose name already exists is skipped.
  4. add_tools: adds multiple tools to the collection; tools whose names already exist are skipped.
  5. to_params: calls each tool's to_param method in turn to produce tool descriptions in the format required by OpenAI Function Call.

Source

The ToolCollection source is very simple:

class ToolCollection:
    """A collection of defined tools."""

    class Config:
        arbitrary_types_allowed = True

    # *tools means the method accepts any number of positional arguments, which are collected into a tuple
    def __init__(self, *tools: BaseTool):
        self.tools = tools
        self.tool_map = {tool.name: tool for tool in tools}

    # __iter__ makes instances of ToolCollection iterable
    def __iter__(self):
        # iter is a Python built-in that returns an iterator over an object
        return iter(self.tools)

    def to_params(self) -> List[Dict[str, Any]]:
        return [tool.to_param() for tool in self.tools]

    async def execute(
        # parameters after * must be passed as keyword arguments
        self, *, name: str, tool_input: Dict[str, Any] = None
    ) -> ToolResult:
        """执行指定工具"""
        tool = self.tool_map.get(name)
        if not tool:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            # Unpack the dict with ** so its key-value pairs become keyword arguments to the tool
            result = await tool(**tool_input)
            return result
        except ToolError as e:
            return ToolFailure(error=e.message)

    async def execute_all(self) -> List[ToolResult]:
        """按顺序执行集合中的所有工具"""
        results = []
        for tool in self.tools:
            try:
                # Note: no arguments are passed
                result = await tool()
                results.append(result)
            except ToolError as e:
                results.append(ToolFailure(error=e.message))
        return results

    def get_tool(self, name: str) -> BaseTool:
        return self.tool_map.get(name)

    def add_tool(self, tool: BaseTool):
        """向集合中添加单个工具。如果已存在同名工具,则将跳过该工具并记录警告。"""
        if tool.name in self.tool_map:
            logger.warning(f"Tool {tool.name} already exists in collection, skipping")
            return self

        # Tuples are immutable once created, so this rebinds the attribute
        # Equivalent to self.tools = self.tools + (tool,)
        self.tools += (tool,)
        self.tool_map[tool.name] = tool
        return self

    def add_tools(self, *tools: BaseTool):
        """Add multiple tools to the collection.

        If any tool has a name conflict with an existing tool, it will be skipped and a warning will be logged.
        """
        for tool in tools:
            self.add_tool(tool)
        return self
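
A short usage sketch tying ToolCollection together with the Terminate tool described in the next section (the app.tool import path is an assumption based on the repository layout):

import asyncio

from app.tool import Terminate, ToolCollection

async def main():
    tools = ToolCollection(Terminate())
    # Tool descriptions in OpenAI Function Call format, ready for ask_tool(tools=...)
    print(tools.to_params())
    # Invoke a tool by name; tool_input is unpacked into keyword arguments
    result = await tools.execute(name="terminate", tool_input={"status": "success"})
    print(result)

asyncio.run(main())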

Tools

BaseTool

In OpenManus, every tool is a subclass of BaseTool. BaseTool inherits from both BaseModel and ABC.

BaseModel was introduced earlier. ABC (Abstract Base Class) is the abstract base class facility from the Python standard library, providing interface definition, prevention of direct instantiation, and similar capabilities.

Key Attributes

  1. name: the tool's name, overridden by subclasses.
  2. description: the tool's description, overridden by subclasses.
  3. parameters: a dict describing the tool's parameters in the format required by OpenAI Function Call, overridden by subclasses.

Key Methods

  1. execute: the abstract execution method; subclasses must implement it.
  2. to_param: generates a tool description for the current tool in the format required by OpenAI Function Call.

Source

class BaseTool(ABC, BaseModel):
    name: str
    description: str
    parameters: Optional[dict] = None

    class Config:
        arbitrary_types_allowed = True

    async def __call__(self, **kwargs) -> Any:
        """Execute the tool with given parameters."""
        return await self.execute(**kwargs)

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Execute the tool with given parameters."""

    def to_param(self) -> Dict:
        """Convert tool to function call format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

Tool Implementations

Two representative tools used by the Manus class illustrate how tools are defined in OpenManus.

Terminate

This is a rather special tool, used to terminate the agent:

_TERMINATE_DESCRIPTION = """Terminate the interaction when the request is met OR if the assistant cannot proceed further with the task.
When you have finished all the tasks, call this tool to end the work."""
"""当请求得到满足或助手无法继续执行任务时,终止交互。当你完成所有任务后,调用此工具结束工作。"""

class Terminate(BaseTool):
    name: str = "terminate"
    description: str = _TERMINATE_DESCRIPTION
    parameters: dict = {
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "description": "The finish status of the interaction.",
                "enum": ["success", "failure"],
            }
        },
        "required": ["status"],
    }

    async def execute(self, status: str) -> str:
        """Finish the current execution"""
        return f"The interaction has been completed with status: {status}"

It doesn't actually perform any logic; it simply returns a string. When this tool is invoked, it signals that the agent's work is done.
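
For reference, Terminate().to_param() produces the following Function Call description (this follows directly from the BaseTool.to_param source above):

{
    "type": "function",
    "function": {
        "name": "terminate",
        "description": _TERMINATE_DESCRIPTION,
        "parameters": {
            "type": "object",
            "properties": {
                "status": {
                    "type": "string",
                    "description": "The finish status of the interaction.",
                    "enum": ["success", "failure"],
                }
            },
            "required": ["status"],
        },
    },
}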

PythonExecute

A tool for executing Python code, with a timeout and safety restrictions.

_run_code

An internal method defined on PythonExecute that safely executes Python code and captures its output.

The idea is to hijack the standard output stream into a string buffer and run the code via the built-in exec in a restricted global namespace:

def _run_code(self, code: str, result_dict: dict, safe_globals: dict) -> None:
    """Safely execute Python code and capture its output"""

    # Save the current stdout so it can be restored later
    original_stdout = sys.stdout
    try:
        # Create an in-memory string buffer to capture output produced while the code runs
        output_buffer = StringIO()
        # Redirect stdout to the buffer so that every print in the code is captured in output_buffer
        sys.stdout = output_buffer
        # Execute the given Python code string in the designated safe global namespace
        exec(code, safe_globals, safe_globals)
        # Read everything captured in the buffer and store it in the result dict
        result_dict["observation"] = output_buffer.getvalue()
        # Mark the execution as successful
        result_dict["success"] = True
    except Exception as e:
        # If an exception occurs during execution, store the error message in the result dict
        result_dict["observation"] = str(e)
        # Mark the execution as failed
        result_dict["success"] = False
    finally:
        # Restore the original stdout whether execution succeeded or failed
        sys.stdout = original_stdout
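
The capture trick itself is plain Python and easy to try in isolation; a minimal standalone sketch (not OpenManus code):

from contextlib import redirect_stdout
from io import StringIO

buffer = StringIO()
with redirect_stdout(buffer):
    exec("print(1 + 1)", {"__builtins__": __builtins__})
print(repr(buffer.getvalue()))  # '2\n'
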
execute

The tool's actual execution method. The basic approach is to run the _run_code method in a child process:

async def execute(
    self,
    code: str,
    timeout: int = 5,
) -> Dict:
    """
    Executes the provided Python code with a timeout.

    Args:
        code (str): The Python code to execute.
        timeout (int): Execution timeout in seconds.

    Returns:
        Dict: Contains 'output' with execution output or error message and 'success' status.
    """

    # Manager enables inter-process communication
    with multiprocessing.Manager() as manager:
        result = manager.dict({"observation": "", "success": False})
        # __builtins__ is a special built-in module/namespace containing all of the interpreter's built-in functions, exceptions, and constants
        if isinstance(__builtins__, dict):
            safe_globals = {"__builtins__": __builtins__}
        else:
            safe_globals = {"__builtins__": __builtins__.__dict__.copy()}
        # Start a child process to run the code
        proc = multiprocessing.Process(
            target=self._run_code, args=(code, result, safe_globals)
        )
        proc.start()
        proc.join(timeout)

        # timeout process
        if proc.is_alive():
            proc.terminate()
            proc.join(1)
            return {
                "observation": f"Execution timeout after {timeout} seconds",
                "success": False,
            }
        return dict(result)
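
A quick usage sketch (assuming the tool lives at app.tool.python_execute, per the repository layout):

import asyncio

from app.tool.python_execute import PythonExecute

async def main():
    tool = PythonExecute()
    result = await tool.execute(code="print('hello from the sandbox')", timeout=5)
    # e.g. {'observation': 'hello from the sandbox\n', 'success': True}
    print(result)

# The __main__ guard matters here: the tool spawns a child process
if __name__ == "__main__":
    asyncio.run(main())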
