正常情况下一般一个工具注册器就包含了所有的工具包括mcp工具,那为什么还要引入这个工具路由呢,其实有一些工具在特定的场景下是没有必要使用的,如果你使用了就会多消耗token但是还没有什么用,下面来看代码。

package com.code.codeplus.config;

import com.code.codeplus.agent.tools.CommonTools;
import com.code.codeplus.agent.tools.RagTools;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 统一工具回调提供者配置
 * 整合 CommonTools、RagTools、MCP Tools(自动注入)
 */
@Configuration
@RequiredArgsConstructor
public class ToolCallbackProviderConfig {

    private final CommonTools commonTools;
    private final RagTools ragTools;
    
    // MCP 工具(由 spring-ai-starter-mcp-client-webflux 自动注入,可能为空)
    @Autowired(required = false)
    private List<ToolCallbackProvider> mcpToolCallbackProviders;

    /**
     * MCP 工具列表
     */
    @Bean("mcpTools")
    public List<ToolCallback> mcpTools() {
        if (mcpToolCallbackProviders == null || mcpToolCallbackProviders.isEmpty()) {
            return List.of();
        }
        List<ToolCallback> mcpTools = new ArrayList<>();
        for (ToolCallbackProvider provider : mcpToolCallbackProviders) {
            FunctionCallback[] callbacks = provider.getToolCallbacks();
            for (FunctionCallback cb : callbacks) {
                if (cb instanceof ToolCallback tc) {
                    mcpTools.add(tc);
                }
            }
        }
        return mcpTools;
    }

    /**
     * 统一工具回调提供者
     * 包含所有工具:CommonTools + RagTools + MCP Tools
     */
    @Bean
    @Primary
    public ToolCallbackProvider toolCallbackProvider() {
        return new ToolCallbackProvider() {
            @Override
            public FunctionCallback[] getToolCallbacks() {
                List<FunctionCallback> allTools = new ArrayList<>();
                
                // 添加 CommonTools
                allTools.addAll(Arrays.asList(ToolCallbacks.from(commonTools)));
                
                // 添加 RagTools
                allTools.addAll(Arrays.asList(ToolCallbacks.from(ragTools)));
                
                // 添加 MCP Tools
                allTools.addAll(mcpTools());
                
                return allTools.toArray(new FunctionCallback[0]);
            }
        };
    }
}

这种是通过

// MCP 工具(由 spring-ai-starter-mcp-client-webflux 自动注入,可能为空)这种方法注册的mcp但是这种不推荐但是比较方便,这就是把mcp工具给注册进来,下面来看业务代码。
package com.code.codeplus.agent;

import com.code.codeplus.agent.tools.CommonTools;
import com.code.codeplus.agent.tools.RagTools;
import com.code.codeplus.service.HotKeyCaffeineService;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 工具路由器
 * 缓存命中返回全部工具,未命中返回除RAG外的所有工具
 */
@Component
@RequiredArgsConstructor
public class ToolRouter {
    
    private final HotKeyCaffeineService hotKeyCaffeineService;
    private final RagTools ragTools;
    private final CommonTools commonTools;
    private final List<ToolCallback> mcpTools; // MCP 工具注入
    private final ToolCallbackProvider toolCallbackProvider; // 全部工具提供者
    
    /**
     * 根据题目缓存状态返回可用工具
     * 缓存命中 → 返回全部工具(Common + RAG + MCP)
     * 缓存未命中 → 返回除RAG外的所有工具(Common + MCP)
     */
    public FunctionCallback[] getAvailableTools(Integer problemId) {
        String cached = problemId != null ? hotKeyCaffeineService.getQuestionText(problemId) : null;
        boolean cacheHit = cached != null && !cached.startsWith("题号格式错误") && !cached.startsWith("未找到");
        
        List<FunctionCallback> tools = new ArrayList<>();
        tools.addAll(Arrays.asList(ToolCallbacks.from(commonTools)));
        tools.addAll(mcpTools); // MCP 工具始终包含
        
        if (cacheHit) {
            tools.addAll(Arrays.asList(ToolCallbacks.from(ragTools)));
        }
        return tools.toArray(new FunctionCallback[0]);
    }
    
    /**
     * 获取全部工具(从 ToolCallbackProvider)
     */
    public FunctionCallback[] getAllTools() {
        return toolCallbackProvider.getToolCallbacks();
    }
}

就是把刚才那个

public List<ToolCallback> mcpTools给注入进来。

这里就是根据缓存是否命中然后选择是否要返回RAG工具。

下面是业务代码:

import cn.hutool.core.util.StrUtil;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.support.ToolCallbacks;
import lombok.extern.slf4j.Slf4j;
import javax.annotation.Resource;
import java.util.List;

/**
 * 核心思考方法:修正ToolCallback→FunctionCallback类型转换,保证工具全量可用
 */
@Slf4j
public class YourThinkClass { // 替换为你的实际类名
    // 原有成员变量保留
    private List<org.springframework.ai.chat.messages.Message> messageList;
    private Object chatOptions;
    private ChatResponse toolCallChatResponse;
    private String systemPrompt;
    // 注入核心依赖
    @Resource
    private ToolRouter toolRouter;
    @Resource
    private HotKeyCaffeineService hotKeyCaffeineService;
    @Resource
    private YourChatClient chatClient; // 替换为你的实际ChatClient类型

    public boolean think() {
        // 1. 获取当前题目ID(需根据你的业务实现,示例为占位)
        Integer problemId = getCurrentProblemId();
        FunctionCallback[] finalAvailableTools = null; // 修正:改为FunctionCallback[](大模型标准入参)
        boolean isCacheHit = false;
        String cacheData = null;

        // 2. 缓存命中判断 + 动态工具筛选(保留原有业务逻辑)
        if (problemId != null) {
            cacheData = hotKeyCaffeineService.getQuestionText(problemId);
            // 缓存判定规则与ToolRouter完全一致,保证业务统一
            isCacheHit = StrUtil.isNotBlank(cacheData)
                    && !cacheData.startsWith("题号格式错误")
                    && !cacheData.startsWith("未找到");

            // 3. 核心:获取Router工具并完成【ToolCallback→FunctionCallback】标准化转换
            ToolCallback[] routerTools = isCacheHit
                    ? toolRouter.getAvailableTools(problemId) // 命中→无RAG工具
                    : toolRouter.getAllTools(); // 未命中→全量工具
            // 关键修正:用Spring AI原生工具转换类,将ToolCallback[]转为可执行的FunctionCallback[]
            finalAvailableTools = ToolCallbacks.from(routerTools);
        } else {
            // 无题目ID:兜底用全量工具并完成类型转换
            ToolCallback[] allRouterTools = toolRouter.getAllTools();
            finalAvailableTools = ToolCallbacks.from(allRouterTools);
        }

        // 4. 拼接提示词 + 缓存数据注入上下文(保留原有逻辑)
        List<org.springframework.ai.chat.messages.Message> userMessageList = getMessageList();
        // 系统提示词(原有逻辑,修正为SystemMessage更规范)
        if (StrUtil.isNotBlank(getSystemPrompt())) {
            userMessageList.add(new SystemMessage(getSystemPrompt()));
        }
        // 缓存命中:将数据注入上下文,明确告知大模型无需调用RAG
        if (isCacheHit && StrUtil.isNotBlank(cacheData)) {
            String cacheHintPrompt = "【缓存命中专属数据】:" + cacheData + "\n"
                    + "请严格基于上述缓存数据处理用户请求,禁止调用RAG相关工具,直接返回处理结果";
            userMessageList.add(new UserMessage(cacheHintPrompt));
        }

        // 5. 调用大模型:传入【可执行、类型匹配】的finalAvailableTools
        Prompt prompt = new Prompt(userMessageList, this.chatOptions);
        try {
            ChatResponse chatResponse = chatClient
                    .prompt(prompt)
                    .system(getSystemPrompt())
                    .tools(finalAvailableTools) // 现在是大模型可识别的FunctionCallback[]
                    .call()
                    .chatResponse();

            // 记录响应,供后续Act阶段使用(原有逻辑)
            this.toolCallChatResponse = chatResponse;
            return true; // 思考成功
        } catch (Exception e) {
            log.error("大模型思考阶段调用失败,题目ID:{},缓存命中状态:{}", problemId, isCacheHit, e);
            return false; // 思考失败
        }
    }

    // ########## 以下为你的原有辅助方法,按需保留/实现 ##########
    private Integer getCurrentProblemId() {
        // 示例:从业务上下文/请求/成员变量中获取题目ID,根据实际场景实现
        // return this.problemId; // 推荐:将problemId作为类成员变量,业务初始化时赋值
        return null;
    }

    public List<org.springframework.ai.chat.messages.Message> getMessageList() {
        return this.messageList;
    }

    public String getSystemPrompt() {
        return this.systemPrompt;
    }

    // 其他get/set方法...
}

 if (problemId != null) {
            cacheData = hotKeyCaffeineService.getQuestionText(problemId);
            // 缓存判定规则与ToolRouter完全一致,保证业务统一
            isCacheHit = StrUtil.isNotBlank(cacheData)
                    && !cacheData.startsWith("题号格式错误")
                    && !cacheData.startsWith("未找到");

            // 3. 核心:获取Router工具并完成【ToolCallback→FunctionCallback】标准化转换
            ToolCallback[] routerTools = isCacheHit
                    ? toolRouter.getAvailableTools(problemId) // 命中→无RAG工具
                    : toolRouter.getAllTools(); // 未命中→全量工具
            // 关键修正:用Spring AI原生工具转换类,将ToolCallback[]转为可执行的FunctionCallback[]
            finalAvailableTools = ToolCallbacks.from(routerTools);

根据问题id去判断是不是热key,如果是的话,  cacheData = hotKeyCaffeineService.getQuestionText(problemId);的值就不为空,自然而然 ToolCallback[] routerTools = isCacheHit
                    ? toolRouter.getAvailableTools(problemId) // 命中→无RAG工具
                    : toolRouter.getAllTools(); // 未命中→全量工具

当它不为空的时候返回无rag工具的所有工具,然后在

 ChatResponse chatResponse = chatClient
                    .prompt(prompt)
                    .system(getSystemPrompt())
                    .tools(finalAvailableTools) // 现在是大模型可识别的FunctionCallback[]
                    .call()
                    .chatResponse();
加到这个里面。

当缓存未命中的时候就是

  // 无题目ID:兜底用全量工具并完成类型转换
            ToolCallback[] allRouterTools = toolRouter.getAllTools();
            finalAvailableTools = ToolCallbacks.from(allRouterTools);

把所有工具都给赋值。

// 4. 拼接提示词 + 缓存数据注入上下文(保留原有逻辑)
        List<org.springframework.ai.chat.messages.Message> userMessageList = getMessageList();
        // 系统提示词(原有逻辑,修正为SystemMessage更规范)
        if (StrUtil.isNotBlank(getSystemPrompt())) {
            userMessageList.add(new SystemMessage(getSystemPrompt()));
        }
        // 缓存命中:将数据注入上下文,明确告知大模型无需调用RAG
        if (isCacheHit && StrUtil.isNotBlank(cacheData)) {
            String cacheHintPrompt = "【缓存命中专属数据】:" + cacheData + "\n"
                    + "请严格基于上述缓存数据处理用户请求,禁止调用RAG相关工具,直接返回处理结果";
            userMessageList.add(new UserMessage(cacheHintPrompt));
        }

这是当缓存命中的时候给添加提示词(缓存)以及系统提示词。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐