Spring AI Advisor Source Code Analysis
What Advisors Are For
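The Spring AI Advisors API provides a flexible and powerful way to intercept, modify, and enhance the AI-driven interactions in your Spring application. Examples include chat memory (MessageChatMemoryAdvisor), knowledge-base question answering (QuestionAnswerAdvisor), and log output (SimpleLoggerAdvisor).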
Overall Structure of Advisors
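- Class diagram
- Flow diagram:
  - The Spring AI framework creates a ChatClientRequest from the user's Prompt, together with an advisor context object.
  - Each advisor in the chain processes the request and may modify it. This step corresponds to ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain).
  - The final advisor, provided by the framework, sends the request to the ChatModel.
  - The ChatModel's response is passed back through the AdvisorChain and converted into a ChatClientResponse, which also carries the shared advisor context instance.
  - Each advisor can process or modify the response. This step corresponds to ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain).
  - Finally, the chat completion is extracted from the ChatClientResponse and returned to the client.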
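The following small, framework-free Java sketch models the before/after flow described above. Every type in it (Request, Response, Advisor, Chain, TracingAdvisor, FakeModelAdvisor) is hypothetical and only imitates the shape of Spring AI's chain. Each advisor may modify the request, hands it to the rest of the chain, and may modify the response on the way back. The last advisor plays the role of the framework-provided model-call advisor and never calls the chain again.

```java
import java.util.ArrayList;
import java.util.List;

final class AdvisorChainDemo {
    // Builds a chain of two tracing advisors plus a terminal "model" advisor and runs it.
    static Response run(String prompt) {
        List<Advisor> advisors = List.of(
                new TracingAdvisor("memory"),
                new TracingAdvisor("logger"),
                new FakeModelAdvisor());
        return new Chain(advisors).next(new Request(prompt, new ArrayList<>()));
    }

    public static void main(String[] args) {
        Response response = run("hi");
        // prints: echo: hi [memory.before, logger.before, model, logger.after, memory.after]
        System.out.println(response.text() + " " + response.trace());
    }
}

// Hypothetical stand-ins for ChatClientRequest / ChatClientResponse (not Spring AI types).
record Request(String prompt, List<String> trace) {}

record Response(String text, List<String> trace) {}

// Hypothetical advisor contract: receives the request and the rest of the chain.
interface Advisor {
    Response advise(Request request, Chain chain);
}

// Simplified chain: next() hands the request to the advisor at the current position.
final class Chain {
    private final List<Advisor> advisors;
    private final int index;

    Chain(List<Advisor> advisors) {
        this(advisors, 0);
    }

    private Chain(List<Advisor> advisors, int index) {
        this.advisors = advisors;
        this.index = index;
    }

    Response next(Request request) {
        if (index >= advisors.size()) {
            throw new IllegalStateException("No advisors left to execute");
        }
        return advisors.get(index).advise(request, new Chain(advisors, index + 1));
    }
}

// Records its steps, mimicking before() -> rest of chain -> after().
final class TracingAdvisor implements Advisor {
    private final String name;

    TracingAdvisor(String name) {
        this.name = name;
    }

    @Override
    public Response advise(Request request, Chain chain) {
        request.trace().add(name + ".before");   // "before": could modify the request here
        Response response = chain.next(request); // continue down the chain
        response.trace().add(name + ".after");   // "after": could modify the response here
        return response;
    }
}

// Terminal advisor standing in for the model-call advisor: it never calls chain.next().
final class FakeModelAdvisor implements Advisor {
    @Override
    public Response advise(Request request, Chain chain) {
        request.trace().add("model");
        return new Response("echo: " + request.prompt(), request.trace());
    }
}
```

Each advisor wraps the rest of the chain. As a result, the response passes through the advisors in the reverse order of the request, as the printed trace shows.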
Walking Through the Source Code
Let's start with a simple example:
public KnowledgeController(ChatModel dashScopeChatModel) {
    this.chatMemory = MessageWindowChatMemory.builder().build();
    // Default advisors: applied to every request made through this ChatClient
    this.chatClient = ChatClient.builder(dashScopeChatModel)
            .defaultAdvisors(
                    PromptChatMemoryAdvisor.builder(chatMemory).build(),
                    SimpleLoggerAdvisor.builder().build()
            )
            .build();
}
@GetMapping(path = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> streamChat(HttpServletResponse response,
                               @RequestParam("userInput") String userInput,
                               @RequestParam("chatId") String chatId) {
    Flux<String> answerFlux = chatClient.prompt().user(userInput)
            // Per-request advisor param: the conversation ID read by the memory advisor
            .advisors(a -> a.param(CONVERSATION_ID, chatId))
            // Per-request advisor (buildQaAdvisor() is a helper not shown in this excerpt)
            .advisors(buildQaAdvisor())
            .stream()
            .content();
    return answerFlux;
}
Phase 1: Configuration collection (RequestSpec stage)
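- First, the role of DefaultChatClientRequestSpec: it hides all of the complex parameter-assembly logic. It acts as a "configuration collector" and an "execution relay". It gathers everything you pass through the fluent calls (advisors, messages, tools, system, user, and so on) and ultimately triggers the AI model call.
- Calling advisors(buildQaAdvisor()) stores the added Advisor in the advisors field of DefaultChatClientRequestSpec.
- Calling advisors(a -> a.param(CONVERSATION_ID, chatId)) stores the added param in the advisorParams field of DefaultChatClientRequestSpec.
- At this point the advisors are simply stored in a List. No "chain" has been formed yet.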
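The collecting behavior of Phase 1 can be sketched as follows. RequestSpecSketch and ParamSpec are hypothetical classes, not DefaultChatClientRequestSpec, and advisors are represented by plain strings. The sketch only shows that both advisors(...) overloads store data (one in a List, one in a Map) and build no chain.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

// Hypothetical "configuration collector": stores what the fluent calls pass in, executes nothing.
final class RequestSpecSketch {
    final List<String> advisors = new ArrayList<>();      // advisor names stand in for Advisor objects
    final Map<String, Object> advisorParams = new HashMap<>();
    String user;

    RequestSpecSketch user(String text) {
        this.user = text;
        return this;
    }

    // Like advisors(Advisor...): only appends to the list, no chain is built.
    RequestSpecSketch advisors(String... names) {
        advisors.addAll(List.of(names));
        return this;
    }

    // Like advisors(a -> a.param(k, v)): the lambda fills a spec whose params are copied into the map.
    RequestSpecSketch advisors(Consumer<ParamSpec> consumer) {
        ParamSpec spec = new ParamSpec();
        consumer.accept(spec);
        advisorParams.putAll(spec.params);
        return this;
    }

    public static void main(String[] args) {
        RequestSpecSketch spec = new RequestSpecSketch()
                .user("hello")
                .advisors(a -> a.param("conversationId", "chat-1"))
                .advisors("qaAdvisor");
        System.out.println(spec.advisors + " " + spec.advisorParams);
    }
}

// Hypothetical counterpart of the spec object passed to the advisors(...) lambda.
final class ParamSpec {
    final Map<String, Object> params = new HashMap<>();

    ParamSpec param(String key, Object value) {
        params.put(key, value);
        return this;
    }
}
```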
Phase 2: Building the AdvisorChain (triggered by call() or stream())
//org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec#call
@Override
public CallResponseSpec call() {
    BaseAdvisorChain advisorChain = buildAdvisorChain();
    // For DefaultChatClientUtils.toChatClientRequest, see Focus point 1
    return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
            this.observationRegistry, this.chatClientObservationConvention);
}
@Override
public StreamResponseSpec stream() {
    BaseAdvisorChain advisorChain = buildAdvisorChain();
    // For DefaultChatClientUtils.toChatClientRequest, see Focus point 1
    return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
            this.observationRegistry, this.chatClientObservationConvention);
}
private BaseAdvisorChain buildAdvisorChain() {
    // Add the model-call advisors at the bottom of the stack; they act as the last advisor in the chain
    this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build());
    this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build());
    return DefaultAroundAdvisorChain.builder(this.observationRegistry)
            .observationConvention(this.advisorObservationConvention)
            .pushAll(this.advisors) // re-sorts (reOrder) the advisors by their Ordered value
            .build();
}
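The reOrder step in pushAll sorts advisors by their Ordered value. A lower value means higher precedence, so that advisor's before() runs earlier and its after() runs later. The sketch below is a hypothetical, simplified model of this ordering: it uses plain records and List.sort instead of Spring's OrderComparator. It also assumes that the terminal model-call advisor carries the lowest precedence, which this excerpt does not show.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class ChainOrderSketch {
    // Same value as Spring's Ordered.LOWEST_PRECEDENCE (Integer.MAX_VALUE).
    static final int LOWEST_PRECEDENCE = Integer.MAX_VALUE;

    // Appends a terminal "modelCall" advisor, then sorts ascending by order value.
    // List.sort is stable, so advisors with equal order keep their insertion order.
    static List<String> buildChain(List<OrderedAdvisor> userAdvisors) {
        List<OrderedAdvisor> all = new ArrayList<>(userAdvisors);
        // Assumption: the terminal advisor uses the lowest precedence so it sorts last.
        all.add(new OrderedAdvisor("modelCall", LOWEST_PRECEDENCE));
        all.sort(Comparator.comparingInt(OrderedAdvisor::order));
        List<String> names = new ArrayList<>();
        for (OrderedAdvisor advisor : all) {
            names.add(advisor.name());
        }
        return names;
    }

    public static void main(String[] args) {
        List<String> chain = buildChain(List.of(
                new OrderedAdvisor("logger", 0),
                new OrderedAdvisor("memory", -100),
                new OrderedAdvisor("qa", 0)));
        System.out.println(chain); // [memory, logger, qa, modelCall]
    }
}

// Hypothetical advisor descriptor: a name plus an order value, like Ordered.getOrder().
record OrderedAdvisor(String name, int order) {}
```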
Phase 3: Executing the AdvisorChain (triggered by content() or chatResponse())
//org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec#content
public Flux<String> content() {
    // Run the advisor chain, then map each streamed ChatClientResponse to its text
    return doGetObservableFluxChatResponse(this.request)
            .mapNotNull(ChatClientResponse::chatResponse)
            .map(r -> Optional.ofNullable(r.getResult())
                    .map(Generation::getOutput)
                    .map(AbstractMessage::getText)
                    .orElse(""))
            .filter(StringUtils::hasLength);
}
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
    return Flux.deferContextual(contextView -> {
        // ... (omitted: creation and start of the chat client `observation`; this excerpt is simplified) ...
        // Enter the advisor chain, which ends with ChatModelStreamAdvisor
        Flux<ChatClientResponse> chatClientResponse = this.advisorChain.nextStream(chatClientRequest)
                .doOnError(observation::error)
                .doFinally(s -> observation.stop());
        return chatClientResponse;
    });
}
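//org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain#nextStream
@Override
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
    return Flux.deferContextual(contextView -> {
        // ... (omitted: take the next StreamAdvisor from the chain as `advisor`
        //      and create and start its `observation`) ...
        // @formatter:off
        // Invoke the advisor lazily, passing this chain so it can call nextStream again
        Flux<ChatClientResponse> chatClientResponse = Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
                .doOnError(observation::error)
                .doFinally(s -> observation.stop())
                .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
        // @formatter:on
        return chatClientResponse;
    });
}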
//org.springframework.ai.chat.client.advisor.api.BaseAdvisor#adviseStream
@Override
default Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
        StreamAdvisorChain streamAdvisorChain) {
    Flux<ChatClientResponse> chatClientResponseFlux = Mono.just(chatClientRequest)
            .publishOn(getScheduler())
            // Pre-processing: let this advisor modify the request
            .map(request -> this.before(request, streamAdvisorChain))
            // Execute: pass the request on to the next advisor in the chain
            .flatMapMany(streamAdvisorChain::nextStream);
    return chatClientResponseFlux.map(response -> {
        if (AdvisorUtils.onFinishReason().test(response)) {
            // Post-processing: applied only to the chunk that carries a finish reason
            response = after(response, streamAdvisorChain);
        }
        return response;
    }).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
}
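The per-chunk behavior of adviseStream is easier to see without Reactor. The following hypothetical sketch replaces Flux with a List of Chunk records and replaces AdvisorUtils.onFinishReason() with a simple non-empty finishReason check. It approximates the logic shown above and is not Spring AI code.

```java
import java.util.List;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

final class StreamAdviceSketch {
    // Simplified, list-based version of the adviseStream logic:
    // before() runs once on the request, the rest of the chain produces chunks,
    // and after() runs only on chunks that carry a non-empty finish reason.
    static List<Chunk> adviseStream(String request,
                                    UnaryOperator<String> before,
                                    Function<String, List<Chunk>> next,
                                    UnaryOperator<Chunk> after) {
        String processed = before.apply(request);
        return next.apply(processed).stream()
                .map(chunk -> hasFinishReason(chunk) ? after.apply(chunk) : chunk)
                .collect(Collectors.toList());
    }

    static boolean hasFinishReason(Chunk chunk) {
        return chunk.finishReason() != null && !chunk.finishReason().isEmpty();
    }

    public static void main(String[] args) {
        List<Chunk> out = adviseStream("question",
                request -> request + " +memory",
                request -> List.of(new Chunk("Hel", null), new Chunk("lo", null), new Chunk("!", "STOP")),
                chunk -> new Chunk(chunk.text() + " [after]", chunk.finishReason()));
        out.forEach(c -> System.out.println(c.text()));
    }
}

// Hypothetical streamed chunk: text plus a finish reason (null while the stream is still running).
record Chunk(String text, String finishReason) {}
```

With this default logic, after() sees only the single chunk that carries the finish reason, not an aggregate of the whole streamed answer.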
Focus point 1: Building the ChatClientRequest
static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) {
    // ... (code omitted; processedMessages and processedChatOptions are built here) ...
    // Note the context: the advisor params are copied into it and travel with the request
    return ChatClientRequest.builder()
            .prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build())
            .context(new ConcurrentHashMap<>(inputRequest.getAdvisorParams()))
            .build();
}
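The following small sketch shows the effect of new ConcurrentHashMap<>(inputRequest.getAdvisorParams()). ContextSketch and its conversationId helper are hypothetical. The helper only approximates what getConversationId(...) in the memory advisor does with the context; that method's body is not shown in this article.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class ContextSketch {
    // Mirrors new ConcurrentHashMap<>(inputRequest.getAdvisorParams()):
    // the request context starts as a thread-safe copy of the advisor params.
    static Map<String, Object> toContext(Map<String, Object> advisorParams) {
        return new ConcurrentHashMap<>(advisorParams);
    }

    // Hypothetical lookup, similar in spirit to getConversationId(context, defaultId):
    // use the value from the context if present, otherwise fall back to the default.
    static String conversationId(Map<String, Object> context, String key, String defaultId) {
        Object value = context.get(key);
        return value != null ? value.toString() : defaultId;
    }

    public static void main(String[] args) {
        Map<String, Object> advisorParams = new HashMap<>();
        advisorParams.put("conversationId", "chat-42");

        Map<String, Object> context = toContext(advisorParams);
        context.put("addedByAdvisor", true); // advisors can write to the context...
        System.out.println(advisorParams.containsKey("addedByAdvisor")); // ...without touching the params: false
        System.out.println(conversationId(context, "conversationId", "default")); // chat-42
    }
}
```

Copying into a ConcurrentHashMap has one consequence: it rejects null keys and values. A param registered with a null value would therefore cause a NullPointerException at this point instead of being carried along.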
Concrete Advisors
Chat memory: PromptChatMemoryAdvisor
- Before the model call: before
@Override
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
    // Get the conversation ID from the context
    String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
    // 1. Retrieve the chat history of the current conversation.
    List<Message> memoryMessages = this.chatMemory.get(conversationId);
    // 2. Render the memory messages as a single string.
    String memory = memoryMessages.stream()
            .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
            .map(m -> m.getMessageType() + ":" + m.getText())
            .collect(Collectors.joining(System.lineSeparator()));
    // 3. Augment the system message.
    SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
    String augmentedSystemText = this.systemPromptTemplate
            .render(Map.of("instructions", systemMessage.getText(), "memory", memory));
    // 4. Create a new request containing the augmented system message.
    ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
            .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
            .build();
    // 5. Only now, after the system message has been built, add the new user message
    //    from the current prompt to the conversation memory.
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    this.chatMemory.add(conversationId, userMessage);
    return processedChatClientRequest;
}
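Steps 2 and 3 of before() can be reproduced in isolation. In the sketch below, Message, MessageType and the template text are hypothetical stand-ins, not Spring AI's classes or its default system prompt template. render() is a naive string substitution rather than PromptTemplate. Only the filtering and the "TYPE:text" formatting mirror the code above.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class MemoryPromptSketch {
    // Step 2 of before(): keep only USER/ASSISTANT messages, format each as "TYPE:text", one per line.
    static String renderMemory(List<Message> memory) {
        return memory.stream()
                .filter(m -> m.type() == MessageType.USER || m.type() == MessageType.ASSISTANT)
                .map(m -> m.type() + ":" + m.text())
                .collect(Collectors.joining(System.lineSeparator()));
    }

    // Step 3 stand-in: naive "{name}" substitution, NOT Spring AI's PromptTemplate.
    static String render(String template, Map<String, String> variables) {
        String result = template;
        for (Map.Entry<String, String> entry : variables.entrySet()) {
            result = result.replace("{" + entry.getKey() + "}", entry.getValue());
        }
        return result;
    }

    public static void main(String[] args) {
        List<Message> history = List.of(
                new Message(MessageType.SYSTEM, "ignored"),
                new Message(MessageType.USER, "What is RAG?"),
                new Message(MessageType.ASSISTANT, "Retrieval-augmented generation."));
        String memory = renderMemory(history);
        // Hypothetical template text, not Spring AI's default template.
        String template = "{instructions}\n\nConversation memory:\n{memory}";
        System.out.println(render(template, Map.of("instructions", "Answer briefly.", "memory", memory)));
    }
}

// Hypothetical message model mirroring the names used above (not Spring AI's classes).
enum MessageType { USER, ASSISTANT, SYSTEM }

record Message(MessageType type, String text) {}
```

The system message itself is filtered out of the memory block. It reaches the model through the {instructions} slot instead.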
- After the model call: after
@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
    List<Message> assistantMessages = new ArrayList<>();
    // Extract the assistant messages from the chat client response.
    assistantMessages = Optional.ofNullable(chatClientResponse)
            .map(ChatClientResponse::chatResponse)
            .filter(response -> response.getResults() != null && !response.getResults().isEmpty())
            .map(response -> response.getResults()
                    .stream()
                    .map(g -> (Message) g.getOutput())
                    .collect(Collectors.toList()))
            .orElse(List.of());
    // Append them to the memory of the same conversation
    if (!assistantMessages.isEmpty()) {
        this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
                assistantMessages);
    }
    return chatClientResponse;
}
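Taken together, before() writes the user message and after() writes the assistant messages, so each exchange adds both sides of the conversation to memory. The hypothetical WindowMemorySketch below illustrates this write pattern with a simple sliding window. It is not MessageWindowChatMemory, whose eviction rules may differ.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sliding-window memory keyed by conversation ID (a simplification, not MessageWindowChatMemory).
final class WindowMemorySketch {
    private final Map<String, List<String>> store = new HashMap<>();
    private final int maxMessages;

    WindowMemorySketch(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    // Appends messages, then evicts the oldest ones beyond the window.
    synchronized void add(String conversationId, List<String> messages) {
        List<String> list = store.computeIfAbsent(conversationId, id -> new ArrayList<>());
        list.addAll(messages);
        while (list.size() > maxMessages) {
            list.remove(0);
        }
    }

    // Returns an immutable snapshot; unknown conversations yield an empty list.
    synchronized List<String> get(String conversationId) {
        return List.copyOf(store.getOrDefault(conversationId, List.of()));
    }

    public static void main(String[] args) {
        WindowMemorySketch memory = new WindowMemorySketch(3);
        // Each exchange: before() stores the user message, after() stores the assistant reply.
        memory.add("chat-1", List.of("USER:q1"));
        memory.add("chat-1", List.of("ASSISTANT:a1"));
        memory.add("chat-1", List.of("USER:q2"));
        memory.add("chat-1", List.of("ASSISTANT:a2"));
        System.out.println(memory.get("chat-1")); // [ASSISTANT:a1, USER:q2, ASSISTANT:a2]
    }
}
```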