前言

RetrievalAugmentor是检索增强器,其目的是为了更好的检索向量数据库,有助于我们更好的封装用户信息(UserMessage)有一些必要的知识需要学习,如检索器等

过程

现在假设我们的AI聊天助手接收到了用户的提问

我们要做的第一件事就是构建RetrievalAugmentor

DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder().queryRouter(queryRouter).build();

注意这个queryRouter是我们提前构建好的,里面包含了我们的自己定义的检索器,这里不再赘述

现在我们有这么一段代码:

Metadata metadata = Metadata.from(textUserMessage, "memoryId", chatMemory.messages());
AugmentationRequest augmentationRequest = new AugmentationRequest(textUserMessage, metadata);
augmentationResult = retrievalAugmentor.augment(augmentationRequest);

我们构造了一个request,这个request包含了用户提问的内容以及元数据

接下来我们进入augment里看看是什么样的流程

增强流程

源码里,会根据用户的提问信息和我们刚刚注入的metadata封装成Query

Query originalQuery = Query.from(queryText, augmentationRequest.metadata());

后续的步骤是这样的:

Collection<Query> queries = queryTransformer.transform(originalQuery);

Map<Query, Collection<List<Content>>> queryToContents = process(queries);

List<Content> contents = contentAggregator.aggregate(queryToContents);

ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);

return AugmentationResult.builder()
    .chatMessage(augmentedChatMessage)
    .contents(contents)
    .build();

将query对象转换成集合类-->执行查询-->对查询结果聚合-->将结果和用户消息一起注入到模板

一共就这么四步

拆解开来

转换query为集合类

第一步就是将我们的query对象转换为集合的形式,默认是一个singleList,这一步没什么特别的,下面开始是重要部分

process方法执行查询

process里,就是进行增强检索的处理,通过我们前面注入的queryRouter拿到里面的检索器ContentRetriever,再调用相对应的检索方法,如果有多个检索器,会逐个的异步执行以下操作,最终聚合结果。检索器里主要是有几个参数,如maxResults、minScore、filter

(个人理解:queryRouter里包含了多个条件,他就叫做contentRetriever,我们在执行查询的时候会调用这里面的contentRetriever构造条件,这里的条件包括查询时的和查询后的,类似于MySQL中where条件和聚合结果,如filter就是where条件,maxResults和minScore就是聚合结果,类似于limit)

这里以单个EmbeddingStoreContentRetriever为例,其内部的检索步骤:

  1. 构造Embedding对象Embedding embeddedQuery = embeddingModel.embed(query.text()).content();

  2. 构造EmbeddingSearchRequest。将嵌入查询的text传入,剩下的是参数,这里内部是动态建立的过程,如果没有传入,那就会走默认的如(query)->3

    EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
            .queryEmbedding(embeddedQuery)
            .maxResults(maxResultsProvider.apply(query))
            .minScore(minScoreProvider.apply(query))
            .filter(filterProvider.apply(query))
            .build();
  3. 调用EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);拿到执行的结果,以PgVector为例子,实际上执行的就是建立postgresql连接并执行查询语句

  4. 结果的内部是一个List集合,其类型是List<EmbeddingMatch<TextSegment>>,这个TextSegment里封装的是查询结果里的text以及metadata。除此之外还有:
    EmbeddingMatch(score, embeddingId, embedding, textSegment)embedding是float[]向量

  5. 检索最终返回的是List<Content> 因此需要对这个List<EmbeddingMatch<TextSegment>>做转换,Content里有两个属性:textSegment和metadata,这里对应上了
return searchResult.matches().stream()
        .map(embeddingMatch -> Content.from(
                embeddingMatch.embedded(),//其实就是刚刚的TextSegment
                Map.of(
                        ContentMetadata.SCORE, embeddingMatch.score(),
                        ContentMetadata.EMBEDDING_ID, embeddingMatch.embeddingId()
                )
        ))
        .collect(Collectors.toList());

聚合查询结果

        process走完了,该走聚合了,用聚合器将结果聚合,其实就是整理起来,过程不赘述

inject注入模板中

        最后一步,inject注射,源码将content这个list用/n/n分隔开(即分段)放入一个叫variables的Map里,这个Map比较关键的是userMessage和contents这两个key,除此之外还有日期、时间等,最终这个key是填入了default模板里面,至此chatMessage构造完毕,返回结果

return AugmentationResult.builder()
    .chatMessage(augmentedChatMessage)
    .contents(contents)
    .build();

注意这个模板是预定义的提示词模板,通过参数替换,而形成最终的提示词。

到这里整个流程就结束了,还是比较简单的

结语

有许多非常关键的概念,需要我们提前去了解后才能深入探索,是大模型学习的过程中基础中的基础,langchain4j官方有比较完整的概念讲解,必须要提前了解打好基础,才能探索整个AI工程

不足之处欢迎指出、有错误的地方欢迎纠正。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐