Spring AI Advisor Source Code Analysis

What Advisors Do

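The Spring AI Advisors API provides a flexible and powerful way to intercept, modify, and enhance AI-driven interactions in your Spring applications. Examples include chat memory (MessageChatMemoryAdvisor), knowledge-base question answering (QuestionAnswerAdvisor), and log output (SimpleLoggerAdvisor).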

Overall Structure of Advisors

  • Class diagram
Advisor API classes
  • Flow diagram
Advisors API flow
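  1. The Spring AI framework creates a ChatClientRequest from the user's prompt, together with the advisor context object.
  2. Each Advisor in the chain processes the request and may modify it. This corresponds to ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain);
  3. The final advisor, which is supplied by the framework, sends the request to the ChatModel.
  4. The chat model's response is then passed back through the AdvisorChain and converted into a ChatClientResponse, which also carries the shared advisor context instance.
  5. Each advisor can process or modify the response. This corresponds to ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain);
  6. The final ChatClientResponse is returned to the client by extracting the ChatCompletion.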
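
The six steps above can be sketched without any Spring dependency. The following is a minimal model, not the library's implementation: SketchRequest, SketchResponse, SketchAdvisor and AdvisorFlowSketch are hypothetical names invented for illustration, and the "model call" is a simple echo.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-ins for ChatClientRequest / ChatClientResponse.
record SketchRequest(String prompt, Map<String, Object> context) {}
record SketchResponse(String text, Map<String, Object> context) {}

// Simplified advisor: before() may rewrite the request, after() may rewrite the response.
interface SketchAdvisor {
    default SketchRequest before(SketchRequest request) { return request; }
    default SketchResponse after(SketchResponse response) { return response; }
}

final class AdvisorFlowSketch {
    private final Deque<SketchAdvisor> remaining;

    AdvisorFlowSketch(List<SketchAdvisor> advisors) {
        this.remaining = new ArrayDeque<>(advisors);
    }

    SketchResponse next(SketchRequest request) {
        if (remaining.isEmpty()) {
            // Step 3: stand-in for the framework-supplied advisor that calls the model.
            return new SketchResponse("echo: " + request.prompt(), request.context());
        }
        SketchAdvisor advisor = remaining.pop();
        SketchRequest processed = advisor.before(request); // step 2
        SketchResponse response = next(processed);         // rest of the chain, steps 3-4
        return advisor.after(response);                    // step 5
    }

    // Builds an advisor that tags the request on the way in and the response on the way out.
    static SketchAdvisor tagging(String name) {
        return new SketchAdvisor() {
            @Override
            public SketchRequest before(SketchRequest r) {
                return new SketchRequest(r.prompt() + " [" + name + "]", r.context());
            }
            @Override
            public SketchResponse after(SketchResponse r) {
                return new SketchResponse(r.text() + " <" + name + ">", r.context());
            }
        };
    }

    public static void main(String[] args) {
        Map<String, Object> context = new ConcurrentHashMap<>();
        SketchResponse response = new AdvisorFlowSketch(List.of(tagging("A"), tagging("B")))
            .next(new SketchRequest("hi", context));
        System.out.println(response.text()); // prints "echo: hi [A] [B] <B> <A>"
    }
}
```

Note how the before hooks run in chain order (A, then B) while the after hooks run in reverse (B, then A), which is the usual shape of an around-style interceptor chain.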

Walking Through the Source Code

Let's start with a simple example

public KnowledgeController(ChatModel dashScopeChatModel) {
  // Message-window chat memory with the builder's default settings
  this.chatMemory = MessageWindowChatMemory.builder().build();
  this.chatClient = ChatClient.builder(dashScopeChatModel)
    // Default advisors apply to every request made through this client
    .defaultAdvisors(
      PromptChatMemoryAdvisor.builder(chatMemory).build(),
      SimpleLoggerAdvisor.builder().build()
    )
    .build();
}

@GetMapping(path = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> streamChat(HttpServletResponse response,
                               @RequestParam("userInput") String userInput,
                               @RequestParam("chatId") String chatId) {
  // CONVERSATION_ID is assumed to be a static import of ChatMemory.CONVERSATION_ID;
  // buildQaAdvisor() is a helper that is not shown in this article.
  return chatClient.prompt().user(userInput)
    .advisors(a -> a.param(CONVERSATION_ID, chatId)) // per-request advisor parameter
    .advisors(buildQaAdvisor())                      // per-request advisor
    .stream()
    .content();
}

Phase 1: Collecting configuration (the RequestSpec phase)

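  • First, what DefaultChatClientRequestSpec does: it hides all of the complex parameter-assembly logic. Acting as both a "configuration collector" and an "execution relay", it assembles the parameters you supply through chained calls (advisors, messages, tools, system, user, and so on) and ultimately triggers the call to the AI model.
  • When we call advisors(buildQaAdvisor()), each Advisor added this way is stored in the advisors field of DefaultChatClientRequestSpec.
  • When we call advisors(a -> a.param(CONVERSATION_ID, chatId)), each param added this way is stored in the advisorParams field of DefaultChatClientRequestSpec.
  • At this point the Advisors are simply stored in a List; they have not yet been formed into a "chain".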
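
The collecting behaviour described above can be sketched as a small fluent builder. This is an illustrative model under stated assumptions: RequestSpecSketch and AdvisorSpecSketch are hypothetical names, advisors are represented by plain strings, and the key "conversationId" is made up. Only the idea of two overloaded advisors(...) methods feeding a list and a map is taken from the text.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

// Hypothetical stand-in for the object handed to the lambda in advisors(a -> a.param(...)).
final class AdvisorSpecSketch {
    final Map<String, Object> params = new HashMap<>();

    AdvisorSpecSketch param(String key, Object value) {
        params.put(key, value);
        return this;
    }
}

// Hypothetical, much-reduced request spec: it only collects configuration, nothing runs yet.
final class RequestSpecSketch {
    private final List<String> advisors = new ArrayList<>();
    private final Map<String, Object> advisorParams = new HashMap<>();
    private String userText;

    RequestSpecSketch user(String text) {
        this.userText = text;
        return this;
    }

    // Overload 1: advisors given directly are appended to the list.
    RequestSpecSketch advisors(String... advisorNames) {
        advisors.addAll(List.of(advisorNames));
        return this;
    }

    // Overload 2: a customizer contributes advisor parameters.
    RequestSpecSketch advisors(Consumer<AdvisorSpecSketch> customizer) {
        AdvisorSpecSketch spec = new AdvisorSpecSketch();
        customizer.accept(spec);
        advisorParams.putAll(spec.params);
        return this;
    }

    List<String> getAdvisors() { return advisors; }
    Map<String, Object> getAdvisorParams() { return advisorParams; }
    String getUserText() { return userText; }

    public static void main(String[] args) {
        RequestSpecSketch spec = new RequestSpecSketch()
            .user("hello")
            .advisors(a -> a.param("conversationId", "42"))
            .advisors("qaAdvisor");
        // Still just a list and a map; no chain has been built.
        System.out.println(spec.getAdvisors() + " " + spec.getAdvisorParams());
    }
}
```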

Phase 2: Building the AdvisorChain (triggered by call() or stream())

//org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec#call
@Override
public CallResponseSpec call() {
  BaseAdvisorChain advisorChain = buildAdvisorChain();
  // For DefaultChatClientUtils.toChatClientRequest, see Focus point 1
  return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
                                     this.observationRegistry, this.chatClientObservationConvention);
}

@Override
public StreamResponseSpec stream() {
  BaseAdvisorChain advisorChain = buildAdvisorChain();
  // For DefaultChatClientUtils.toChatClientRequest, see Focus point 1
  return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
                                       this.observationRegistry, this.chatClientObservationConvention);
}

private BaseAdvisorChain buildAdvisorChain() {
  // Add the model-calling advisors at the bottom of the stack; they act as the last advisors in the chain
  this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build());
  this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build());
  return DefaultAroundAdvisorChain.builder(this.observationRegistry)
    .observationConvention(this.advisorObservationConvention)
    .pushAll(this.advisors)  // pushAll re-sorts the advisors by their Ordered value (reOrder)
    .build();
}
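
To make the effect of "append the model-calling advisor, then re-sort by order" concrete, here is a minimal sketch. It assumes that the terminal advisor reports the lowest precedence (the largest order value) and that lower order values run first, as with Spring's Ordered contract. OrderedAdvisorSketch, ChainOrderSketch and the order values are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical advisor descriptor: a name plus an order value (lower value runs earlier).
record OrderedAdvisorSketch(String name, int order) {}

final class ChainOrderSketch {
    // Assumption: the terminal model-calling advisor uses the lowest precedence.
    static final int LOWEST_PRECEDENCE = Integer.MAX_VALUE;

    static List<String> buildChain(List<OrderedAdvisorSketch> userAdvisors) {
        List<OrderedAdvisorSketch> all = new ArrayList<>(userAdvisors);
        // Appended last, and it also sorts last because of its order value.
        all.add(new OrderedAdvisorSketch("modelCall", LOWEST_PRECEDENCE));
        // List.sort is stable, so advisors with equal order keep their insertion order.
        all.sort(Comparator.comparingInt(OrderedAdvisorSketch::order));
        return all.stream().map(OrderedAdvisorSketch::name).toList();
    }

    public static void main(String[] args) {
        List<String> chain = buildChain(List.of(
            new OrderedAdvisorSketch("logger", 0),
            new OrderedAdvisorSketch("memory", -100)));
        System.out.println(chain); // prints "[memory, logger, modelCall]"
    }
}
```

The registration order of user advisors matters only as a tie-breaker; the order value decides the position, and the model call always ends up at the tail of the chain.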

Phase 3: Executing the AdvisorChain (triggered by content() or chatResponse())

//org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec#content
public Flux<String> content() {
  return doGetObservableFluxChatResponse(this.request)
    .mapNotNull(ChatClientResponse::chatResponse)
    .map(r -> Optional.ofNullable(r.getResult())
         .map(Generation::getOutput)
         .map(AbstractMessage::getText)
         .orElse(""))
    .filter(StringUtils::hasLength);
}

private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
	return Flux.deferContextual(contextView -> {
		// Excerpt: the code that creates and starts `observation` is omitted here.
		return this.advisorChain.nextStream(chatClientRequest)
			.doOnError(observation::error)
			.doFinally(s -> observation.stop());
	});
}

//org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain#nextStream
@Override
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
	return Flux.deferContextual(contextView -> {
		// Excerpt: the code that obtains the next `advisor` in the chain and starts `observation` is omitted here.
		// @formatter:off
		return Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
				.doOnError(observation::error)
				.doFinally(s -> observation.stop())
				.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
		// @formatter:on
	});
}

//org.springframework.ai.chat.client.advisor.api.BaseAdvisor#adviseStream
	@Override
	default Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
			StreamAdvisorChain streamAdvisorChain) {
		Flux<ChatClientResponse> chatClientResponseFlux = Mono.just(chatClientRequest)
			.publishOn(getScheduler())
			// Pre-interception: before()
			.map(request -> this.before(request, streamAdvisorChain))
			// Continue down the chain
			.flatMapMany(streamAdvisorChain::nextStream);

		return chatClientResponseFlux.map(response -> {
			if (AdvisorUtils.onFinishReason().test(response)) {
				// Post-interception: after(), applied only to a response that carries a finish reason
				response = after(response, streamAdvisorChain);
			}
			return response;
		}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
	}
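
One detail in adviseStream is easy to miss: in streaming mode, after() is not applied to every chunk, only to chunks that pass the finish-reason check. The sketch below illustrates that gating with a plain List standing in for the Flux. ChunkSketch and StreamAfterSketch are hypothetical names, and treating "non-blank finish reason" as the condition is an assumption about what the predicate tests.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical stand-in for one streamed chunk; finishReason stays null until the final chunk.
record ChunkSketch(String text, String finishReason) {}

final class StreamAfterSketch {

    // Applies `after` only to chunks that carry a finish reason; all others pass through untouched.
    static List<ChunkSketch> adviseStream(List<ChunkSketch> chunks, UnaryOperator<ChunkSketch> after) {
        List<ChunkSketch> out = new ArrayList<>();
        for (ChunkSketch chunk : chunks) {
            boolean finished = chunk.finishReason() != null && !chunk.finishReason().isBlank();
            out.add(finished ? after.apply(chunk) : chunk);
        }
        return out;
    }

    public static void main(String[] args) {
        List<ChunkSketch> stream = List.of(
            new ChunkSketch("Hel", null),
            new ChunkSketch("lo", "STOP"));
        List<ChunkSketch> result = adviseStream(stream,
            c -> new ChunkSketch(c.text() + "!", c.finishReason()));
        result.forEach(c -> System.out.println(c.text())); // prints "Hel" then "lo!"
    }
}
```

A consequence worth keeping in mind: an after() hook that needs the complete answer text cannot rely on the single chunk it receives in streaming mode, since that chunk may hold only the tail of the output.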

Focus point 1: Building the ChatClientRequest

static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) {
		// Code omitted. The point of interest is the context: the advisor params are copied into it and travel with the request
		return ChatClientRequest.builder()
			.prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build())
			.context(new ConcurrentHashMap<>(inputRequest.getAdvisorParams()))
			.build();
	}
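
The new ConcurrentHashMap<>(...) call above makes a copy, so whatever advisors write into the request context does not leak back into the parameters collected on the spec. A minimal sketch of that behaviour follows; ContextCopySketch and the keys are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class ContextCopySketch {

    // Mirrors the copy step: the context starts as a snapshot of the advisor params.
    static Map<String, Object> toContext(Map<String, Object> advisorParams) {
        return new ConcurrentHashMap<>(advisorParams);
    }

    public static void main(String[] args) {
        Map<String, Object> advisorParams = new HashMap<>();
        advisorParams.put("conversationId", "42"); // hypothetical key

        Map<String, Object> context = toContext(advisorParams);
        context.put("retrievedDocs", 3); // something an advisor might add while the request flows

        System.out.println(context.get("conversationId"));              // prints "42"
        System.out.println(advisorParams.containsKey("retrievedDocs")); // prints "false"
    }
}
```

Because ConcurrentHashMap rejects null keys and null values, a null advisor parameter would fail at this copy step rather than later inside an advisor.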

Concrete Advisors

Chat memory: PromptChatMemoryAdvisor

  • Pre-interception: before
@Override
	public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
		// Read the conversation ID from the context
		String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
		// 1. Retrieve the chat history of the current conversation.
		List<Message> memoryMessages = this.chatMemory.get(conversationId);
		// 2. Render the memory messages as a single string.
		String memory = memoryMessages.stream()
			.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
			.map(m -> m.getMessageType() + ":" + m.getText())
			.collect(Collectors.joining(System.lineSeparator()));
		// 3. Augment the system message.
		SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
		String augmentedSystemText = this.systemPromptTemplate
			.render(Map.of("instructions", systemMessage.getText(), "memory", memory));

		// 4. Create a new request that carries the augmented system message.
		ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
			.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
			.build();

		// 5. Add the new user message from the current prompt to the conversation memory (after the system message has been built).
		UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
		this.chatMemory.add(conversationId, userMessage);

		return processedChatClientRequest;
	}
  • Post-interception: after
@Override
	public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
		List<Message> assistantMessages = new ArrayList<>();
		// Extract the assistant messages from the chat client response.
		assistantMessages = Optional.ofNullable(chatClientResponse)
			.map(ChatClientResponse::chatResponse)
			.filter(response -> response.getResults() != null && !response.getResults().isEmpty())
			.map(response -> response.getResults()
				.stream()
				.map(g -> (Message) g.getOutput())
				.collect(Collectors.toList()))
			.orElse(List.of());

		if (!assistantMessages.isEmpty()) {
			this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
					assistantMessages);
		}
		return chatClientResponse;
	}
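
Steps 2 and 3 of before() can be tried in isolation. The sketch below reproduces the filter, map and join pipeline on hypothetical types (RoleSketch, MessageSketch) and uses a made-up template in augmentSystemText; it is not the advisor's default template. It also joins with "\n" rather than System.lineSeparator() so the output is the same on every platform.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-ins for MessageType and Message.
enum RoleSketch { USER, ASSISTANT, SYSTEM }
record MessageSketch(RoleSketch role, String text) {}

final class MemoryPromptSketch {

    // Step 2: keep only user/assistant turns and render each as "ROLE:text".
    static String renderMemory(List<MessageSketch> history) {
        return history.stream()
            .filter(m -> m.role() == RoleSketch.USER || m.role() == RoleSketch.ASSISTANT)
            .map(m -> m.role() + ":" + m.text())
            .collect(Collectors.joining("\n"));
    }

    // Step 3: merge the original instructions with the rendered memory.
    // The layout here is an illustrative assumption, not the library's template.
    static String augmentSystemText(String instructions, String memory) {
        return instructions + "\n\nMEMORY:\n" + memory;
    }

    public static void main(String[] args) {
        List<MessageSketch> history = List.of(
            new MessageSketch(RoleSketch.USER, "hi"),
            new MessageSketch(RoleSketch.SYSTEM, "internal rules"),
            new MessageSketch(RoleSketch.ASSISTANT, "hello"));
        String memory = renderMemory(history);
        System.out.println(augmentSystemText("You are a helpful assistant.", memory));
    }
}
```

Unlike an advisor that replays history as separate messages, this approach folds the whole history into the system prompt text, so the model sees one system message plus the current user message.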

References

Spring AI Advisors API
