API key authentication

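An API key is an identifier used for user authentication. The client sends the key with each API call to prove its identity, and the key is usually long-lived.
An API key must be unique and random. It must be unique because each key maps to exactly one user, and random so that it cannot be guessed by brute force.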
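Here is a minimal sketch of generating such a key in Java. It assumes keys are built from 32 bytes of `SecureRandom` output and encoded as URL-safe Base64, so they can travel in headers and query strings. The `ak_` prefix and the `ApiKeyGenerator` name are hypothetical. Randomness makes collisions practically impossible, but uniqueness should still be enforced by a unique index on the key column in the database.

```java
import java.security.SecureRandom;
import java.util.Base64;

// Hypothetical key generator: 32 bytes from a CSPRNG, URL-safe Base64 without padding.
// The "ak_" prefix is an assumed convention that makes keys easy to spot in configs and logs.
final class ApiKeyGenerator {
    private static final SecureRandom RANDOM = new SecureRandom();
    private static final int KEY_BYTES = 32; // 256 bits of randomness: infeasible to brute-force

    private ApiKeyGenerator() {}

    static String generate() {
        byte[] bytes = new byte[KEY_BYTES];
        RANDOM.nextBytes(bytes);
        // 32 bytes encode to 43 Base64 characters when padding is dropped
        return "ak_" + Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
    }
}
```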

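Implementation:
Store the API keys in the database, along with each key's relationship to a user and that user's permissions.
Provide endpoints for resetting an API key and for verifying a key.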
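The two endpoints can be sketched with an in-memory map standing in for the database table. This sketch rests on assumptions: `ApiKeyService`, `resetKey` and `verify` are hypothetical names. It also stores only a SHA-256 hash of each key, so a leaked table does not expose usable keys. That is a common practice added here, not something the design above specifies.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical in-memory stand-in for the table linking a key to its user and permissions.
final class ApiKeyService {
    /** What a key resolves to: the owning user and that user's permissions. */
    static final class KeyRecord {
        final String userId;
        final Set<String> permissions;

        KeyRecord(String userId, Set<String> permissions) {
            this.userId = userId;
            this.permissions = Set.copyOf(permissions);
        }
    }

    private final SecureRandom random = new SecureRandom();
    private final Map<String, KeyRecord> byHash = new ConcurrentHashMap<>();  // key hash -> owner
    private final Map<String, String> hashByUser = new ConcurrentHashMap<>(); // userId -> current key hash

    /** Reset endpoint: issues a new key for the user and revokes the previous one. */
    synchronized String resetKey(String userId, Set<String> permissions) {
        byte[] bytes = new byte[32];
        random.nextBytes(bytes);
        String key = Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
        String hash = sha256(key);
        String previous = hashByUser.put(userId, hash);
        if (previous != null) {
            byHash.remove(previous); // the old key stops working immediately
        }
        byHash.put(hash, new KeyRecord(userId, permissions));
        return key; // handed to the user once; only the hash is stored
    }

    /** Verify endpoint: empty when the key is unknown or has been reset. */
    Optional<KeyRecord> verify(String key) {
        if (key == null || key.isEmpty()) {
            return Optional.empty();
        }
        return Optional.ofNullable(byHash.get(sha256(key)));
    }

    private static String sha256(String value) {
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256").digest(value.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 is always available on the JVM", e);
        }
    }
}
```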

Propagating the API key across microservices

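The project is built from microservices, so after the key reaches the mcp service, every call the mcp service makes to another microservice must carry it as well. Normally we would keep the key in a thread-local variable on the request thread. However, the MCP SDK runs tool calls on a new thread, so that value is lost. Instead, we maintain a sessionId → key mapping ourselves and use the sessionId to pass the correct key to downstream services.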
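The problem can be reproduced with nothing but the JDK. The sketch below (class and method names are hypothetical) sets a `ThreadLocal` on the calling thread and reads it from a worker thread, which sees `null`. What survives the thread switch is an identifier passed along with the task itself, and that is the role the sessionId plays here.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo: a thread-local set on the request thread is invisible to the thread that runs the tool.
final class ThreadHandoffDemo {
    static final ThreadLocal<String> CURRENT_KEY = new ThreadLocal<>();

    /** Sets the key on the calling thread, then reads it from a worker thread. */
    static String readThreadLocalFromWorker() {
        CURRENT_KEY.set("key-123"); // set on the "request" thread
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<String> tool = CURRENT_KEY::get; // executed on the pool's own thread
            return pool.submit(tool).get();           // null: each thread has its own value
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            CURRENT_KEY.remove();
            pool.shutdown();
        }
    }

    /** Hands the sessionId to the task explicitly, which works no matter which thread runs it. */
    static String passSessionIdWithCall(String sessionId) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<String> tool = () -> "tool ran for " + sessionId;
            return pool.submit(tool).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```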

Recording the sessionId and key:

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        response.setContentType("text/event-stream; charset=utf-8");
        HttpServletRequest httpRequest = (HttpServletRequest) request;

        // Store the sessionId -> key mapping and pass the sessionId on to the tool-call thread.
        String sessionId = httpRequest.getParameter(SESSION_ID_PARAM);
        String key = httpRequest.getHeader(KEY_HEADER);
        if (sessionId != null && key != null) {
            SessionKeyMap.setKey(sessionId, key);
            log.debug("Set key for sessionId: {}", sessionId);
            HttpServletRequest wrappedRequest = new BodyRequestWrapper(httpRequest, sessionId);
            httpRequest = wrappedRequest;
        }

        String uri = httpRequest.getRequestURI();
        if (!uri.equals("/sse")) {
            chain.doFilter(httpRequest, response);
            return;
        }

        // Wrap the SSE response so the key is appended as a suffix parameter to the endpoint URL.
        HttpServletResponse httpRes = (HttpServletResponse)response;
        PrintWriter sseWriter = new SSEWriter(httpRes.getWriter(), key);
        HttpServletResponseWrapper sseResponse = new HttpServletResponseWrapper(httpRes) {
            @Override
            public PrintWriter getWriter() {
                return sseWriter;
            }
        };
        chain.doFilter(httpRequest, sseResponse);
    }
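The filter relies on `SessionKeyMap`, which is not shown. Below is a minimal sketch of what it could look like, assuming a process-local static map. The `removeKey` method is a hypothetical addition. A `ConcurrentHashMap` is used because request threads write to the map while tool threads read from it. Keeping the map in memory is workable because the SDK's sessions also live in the instance that holds the SSE connection, so `/mcp/message` requests for a session have to reach that same instance anyway.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical implementation of the SessionKeyMap helper used by the filter.
final class SessionKeyMap {
    private static final Map<String, String> KEYS = new ConcurrentHashMap<>();

    private SessionKeyMap() {}

    /** Called by the filter for each request that carries both a sessionId and a key. */
    static void setKey(String sessionId, String key) {
        KEYS.put(sessionId, key); // ConcurrentHashMap rejects nulls; the filter checks for them first
    }

    /** Called on the tool thread (via the sessionId in the MDC) before calling downstream services. */
    static String getKey(String sessionId) {
        return sessionId == null ? null : KEYS.get(sessionId);
    }

    /** Should run when the SSE connection closes, or entries pile up for every session ever opened. */
    static void removeKey(String sessionId) {
        if (sessionId != null) {
            KEYS.remove(sessionId);
        }
    }
}
```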


BodyRequestWrapper


/**
 * @author twei
 */
public class BodyRequestWrapper extends HttpServletRequestWrapper {
    private final byte[] body;

    public BodyRequestWrapper(HttpServletRequest request, String mcpSessionId) throws IOException {
        super(request);
        // 1. Read the original body
        String oldBody = request.getReader().lines()
                .reduce("", (accumulator, actual) -> accumulator + actual);

        // 2. Parse it into a Map
        ObjectMapper mapper = new ObjectMapper();
        Map<String, Object> jsonMap;
        if (oldBody != null && !oldBody.isEmpty()) {
            jsonMap = mapper.readValue(oldBody, Map.class);
        } else {
            jsonMap = new java.util.HashMap<>();
        }
        // ======= Key logic: only tools/call is rewritten; initialize and other methods are skipped =======
        Object methodObj = jsonMap.get("method");
        if (methodObj instanceof String && !((String) methodObj).equals("tools/call")) {
            // Not a tool call (e.g. initialize): keep the original body and leave params untouched
            this.body = oldBody.getBytes(StandardCharsets.UTF_8);
            return;
        }

        // 3. The important part: put mcpSessionId into params.arguments
        Object paramsObj = jsonMap.get("params");
        if (paramsObj instanceof Map) {
            Map<String, Object> paramsMap = (Map<String, Object>) paramsObj;
            Object argumentsObj = paramsMap.get("arguments");
            if (argumentsObj instanceof Map) {
                ((Map<String, Object>) argumentsObj).put("mcpSessionId", mcpSessionId);
            }
        }

        // 4. Serialize the modified message again
        String newBody = mapper.writeValueAsString(jsonMap);
        body = newBody.getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public ServletInputStream getInputStream() {
        ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override public boolean isFinished() { return bais.available() == 0; }
            @Override public boolean isReady() { return true; }
            @Override public void setReadListener(ReadListener listener) {}
            @Override public int read() { return bais.read(); }
        };
    }

    @Override
    public BufferedReader getReader() {
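        // Decode explicitly as UTF-8: the body was encoded as UTF-8 above, and the platform default may differ
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));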
    }
}
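To make the rewrite easier to see, here is roughly the same logic applied to an already-parsed JSON-RPC message, using plain `Map`s instead of Jackson. A `tools/call` request looks like `{"method":"tools/call","params":{"name":...,"arguments":{...}}}`, and the session id is added inside `arguments`. `SessionIdInjector` is a hypothetical name. Unlike the wrapper, this sketch also skips messages that have no `method`.

```java
import java.util.Map;

// Hypothetical Jackson-free version of the BodyRequestWrapper rewrite, operating on a parsed message.
final class SessionIdInjector {
    private SessionIdInjector() {}

    /** Adds mcpSessionId to params.arguments of a tools/call request; returns false if nothing changed. */
    @SuppressWarnings("unchecked")
    static boolean inject(Map<String, Object> rpc, String sessionId) {
        if (!"tools/call".equals(rpc.get("method"))) {
            return false; // initialize, tools/list, notifications, ... pass through untouched
        }
        Object params = rpc.get("params");
        if (!(params instanceof Map)) {
            return false;
        }
        Object arguments = ((Map<String, Object>) params).get("arguments");
        if (!(arguments instanceof Map)) {
            return false; // no arguments object: the sessionId cannot be attached
        }
        ((Map<String, Object>) arguments).put("mcpSessionId", sessionId);
        return true;
    }
}
```

Both versions have the same gap. A `tools/call` request that omits `arguments` (allowed for tools without parameters) gets no `mcpSessionId`, so the key is not forwarded for such a tool. Creating an empty `arguments` map in that case is one way to close the gap.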

Once inside the tool execution thread, put the sessionId into the MDC thread-local context:

  private void registerTool(McpSyncServer mcpSyncServer, Object bean, Method method, McpTool annotation) {

        String toolName = method.getName();
        String description = annotation.value().isEmpty() ? extractMethodDescription(method) : annotation.value();
        String jsonSchema = generateInputSchema(method);

        McpSchema.Tool tool = new McpSchema.Tool.Builder()
                .name(toolName)
                .description(description)
                .inputSchema(jsonSchema)
                .build();

        BiFunction<McpSyncServerExchange, Map<String, Object>, McpSchema.CallToolResult> call = (exchange,
                arguments) -> {
            try {
                String mcpSessionId = (String) arguments.get("mcpSessionId");
                MDC.put("McpSessionId",mcpSessionId);
                log.info("call " + bean + " " + method);
                // Convert arguments to match method parameter types
                Object[] params = new Object[method.getParameterCount()];
                Parameter[] parameters = method.getParameters();
                for (int i = 0; i < parameters.length; i++) {
                    String paramName = parameters[i].getName();
                    Class<?> paramType = parameters[i].getType();
                    Object argValue = arguments.get(paramName);

                    if (argValue != null) {
                        if (paramType == String.class) {
                            params[i] = argValue.toString();
                        } else if (paramType == Integer.class || paramType == int.class) {
                            params[i] = Integer.parseInt(argValue.toString());
                        } else if (paramType == Long.class || paramType == long.class) {
                            params[i] = Long.parseLong(argValue.toString());
                        } else if (paramType == Double.class || paramType == double.class) {
                            params[i] = Double.parseDouble(argValue.toString());
                        } else if (paramType == Boolean.class || paramType == boolean.class) {
                            params[i] = Boolean.parseBoolean(argValue.toString());
                        } else if (paramType == LocalDate.class) {
                            params[i] = LocalDate.parse(argValue.toString());
                        } else if (paramType == LocalDateTime.class) {
                            params[i] = LocalDateTime.parse(argValue.toString());
                        } else {
                            throw new IllegalArgumentException("Unsupported mcptool parameter type; only basic types are supported: [" + paramType.getSimpleName() + "]");
                        }
                    }
                }
                // 检查必填参数
                StringBuilder missingFields = new StringBuilder();
                for (int i = 0; i < parameters.length; i++) {
                    Required required = parameters[i].getAnnotation(Required.class);
                    if (required != null && params[i] == null) {
                        if (missingFields.length() > 0) {
                            missingFields.append(", ");
                        }
                        missingFields.append(parameters[i].getName());
                    }
                }
                if (missingFields.length() > 0) {
                    // Report missing required fields as an error result (isError = true)
                    return new McpSchema.CallToolResult("Missing required field(s): " + missingFields, true);
                }

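                Object result = method.invoke(bean, params);
                log.info("call " + bean + " " + method + " " + Arrays.toString(params) + " result: " + result);
                // String.valueOf avoids an NPE for void methods or null results
                return new McpSchema.CallToolResult(String.valueOf(result), false);
            } catch (InvocationTargetException e) {
                // Unwrap, otherwise the client receives null instead of the tool's own error message
                return new McpSchema.CallToolResult(String.valueOf(e.getCause().getMessage()), true);
            } catch (Exception e) {
                return new McpSchema.CallToolResult(String.valueOf(e.getMessage()), true);
            } finally {
                // Tool threads may be reused: clear the sessionId so a later call cannot pick up this session's key
                MDC.remove("McpSessionId");
            }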
        };

        mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(tool, call));
    }
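Tool calls may run on pooled threads that get reused, and a value left in a thread-local then leaks into the next call on that thread. The sketch below uses a plain `ThreadLocal` as a stand-in for SLF4J's `MDC` (Logback's MDC adapter is backed by a thread-local), so it runs on the JDK alone. `ContextCleanupDemo` is a hypothetical name. Clearing the value in a `finally` block prevents a later call from picking up another session's sessionId, and with it another user's key.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo: a ThreadLocal stands in for MDC to show stale context on a reused pool thread.
final class ContextCleanupDemo {
    static final ThreadLocal<String> SESSION = new ThreadLocal<>();

    /** Runs two "tool calls" on one reused thread; the second sets nothing and reports what it sees. */
    static String leftoverSeenBySecondCall(boolean clearInFinally) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<String> first = () -> {
                SESSION.set("session-A");
                try {
                    return "work for " + SESSION.get();
                } finally {
                    if (clearInFinally) {
                        SESSION.remove(); // the equivalent of MDC.remove(...)
                    }
                }
            };
            Callable<String> second = SESSION::get; // a call that never set its own sessionId
            pool.submit(first).get();
            return pool.submit(second).get(); // same worker thread as the first call
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```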

When sending a Feign request downstream, look up the key using the sessionId stored in the thread-local MDC, and send the key in a request header:

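import feign.RequestInterceptor;
import feign.RequestTemplate;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
public class FeignInterceptorConfiguration implements RequestInterceptor {

    @Value("${spring.application.name}")
    private String applicationName;

    @Override
    public void apply(RequestTemplate requestTemplate) {
        // Only the mcp service forwards the key it received from the MCP client
        if ("mcp".equals(applicationName)) {
            // sessionId that the tool handler put into the MDC on this thread
            String sessionId = MDC.get("McpSessionId");
            String key = sessionId == null ? null : SessionKeyMap.getKey(sessionId);
            if (key != null) {
                requestTemplate.header("Key", key);
            }
        }
    }
}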

Making the MCP SDK work with API keys

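For how the key is made to work with the SDK, see the doFilter above. When the client connects to the /sse endpoint, the key is appended to the end of the message URL that the server sends back. As a result, the client's later requests to /mcp/message also pass authentication.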
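The URL rewrite itself is small. Here is a sketch of it. It assumes the query parameter is named `key` (the name `SSEWriter` actually uses is not shown), and `EndpointRewriter` is a hypothetical helper.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

// Hypothetical helper: append the API key to the message endpoint announced over SSE.
final class EndpointRewriter {
    private EndpointRewriter() {}

    static String appendKey(String endpoint, String key) {
        if (key == null || key.isEmpty()) {
            return endpoint; // nothing to add for a connection without a key
        }
        String separator = endpoint.contains("?") ? "&" : "?";
        // URL-encode the key so characters such as '+', '&' or spaces cannot break the query string
        return endpoint + separator + "key=" + URLEncoder.encode(key, StandardCharsets.UTF_8);
    }
}
```

Putting the key in the URL lets the SDK's unmodified client work. However, query strings tend to end up in access logs and proxy logs, so those logs should be treated as sensitive.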

MCP Java SDK

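HttpServletSseServerTransportProvider is the core class in the SDK for handling requests. doGet handles the SSE connection, doPost handles /mcp/message and similar endpoints, sendEvent sends a single SSE event, and sessionFactory manages each session.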
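For reference, an event written by something like `sendEvent` follows the `text/event-stream` wire format: an `event:` line, one `data:` line per line of payload, and a blank line that ends the event. Under the MCP HTTP+SSE transport, the first event on a new connection is `endpoint`, whose data is the message URL containing the sessionId. This is presumably the line `SSEWriter` rewrites. `SseEvents.format` is a hypothetical helper, not an SDK method.

```java
// Hypothetical formatter for a single server-sent event in text/event-stream format.
final class SseEvents {
    private SseEvents() {}

    static String format(String eventType, String data) {
        StringBuilder sb = new StringBuilder();
        sb.append("event: ").append(eventType).append('\n');
        // Each line of the payload needs its own "data:" prefix
        for (String line : data.split("\n", -1)) {
            sb.append("data: ").append(line).append('\n');
        }
        sb.append('\n'); // a blank line terminates the event
        return sb.toString();
    }
}
```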

@WebServlet(asyncSupported = true)
public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider