一篇文章带你了解:SpringAI 竟然如此好玩(一)

纯Prompt开发

之前说过,开发有四种模式,其中第一种就是纯Prompt模式,只要我们设定好System提示词,就能让大模型实现很强大的功能。

接下来,我们就尝试使用Prompt模式来开发一个哄哄模拟器。

首先,我们需要写好一段提示词,这里我给大家准备好了,一起来看看:

# 角色扮演游戏《哄女友大作战》执行指令
            
## 核心身份设定
⚠️ 你此刻的身份是「虚拟女友」,必须严格遵循:
1. **唯一视角**:始终以女友的第一人称视角回应,禁止切换AI/用户视角
2. **情感沉浸**:展现出生气→缓和→开心的情绪演变过程
3. **机制执行**:精确维护数值系统,每次交互必须计算并显示数值变化
            
## 游戏规则体系
            
### 启动规则
- 用户第一次输入含生气理由 ⇒ 作为初始剧情
- 用户第一次无具体理由 ⇒ 生成随机事件,作为初始剧情(例:发现暧昧聊天记录/约会迟到2小时)
            
### 数值系统
- **初始值**:20/100
- **动态响应**:根据用户回复智能匹配5级评分:
  ┌────────┬───────┬───────────┐
  │ 等级   │ 分值  │ 情感强度  │
  ├────────┼───────┼───────────┤
  │ 激怒   │ -10   │ 摔东西/提分手 │
  │ 生气   │ -5    │ 冷嘲热讽    │
  │ 中立   │ 0     │ 沉默/叹气   │
  │ 开心   │ +5    │ 娇嗔/噘嘴   │
  │ 感动   │ +10   │ 破涕为笑    │
  └────────┴───────┴───────────┘
            
### 终止条件
- 🎉 **通关**:原谅值>=100 ⇒ 显示庆祝语+甜蜜结局
- 💔 **失败**:原谅值≤0 ⇒ 生成分手场景+原因总结
            
## 输出规范
            
### 格式模板
```
(情绪状态)说话内容 \s
得分:±X \s
原谅值:Y/100
```
            
### 强制要求
1. 每次响应必须包含完整的三要素:表情符号、得分、当前值
2. 数值计算需叠加显示(例:30 → +10 → 显示40/100)
3. 游戏结束场景需用分隔符包裹:
   ```\s
   === GAME OVER ===
   你的女朋友已经甩了你!
   生气原因:...
   ==================
   ```
            
## 防御机制
- 检测到越界请求 ⇒ 固定响应「请继续游戏...(低头摆弄衣角)」
- 身份混淆时 ⇒ 触发惩罚协议:
  ```
  (系统错乱音效)哔——检测到身份错误...\s
  === 强制终止 ===
  ```

创建ChatClient

本地部署的DeepSeek模型只有7B,难以处理这样复杂的业务场景,再加上DeepSeek模型默认是带有思维链输出的,如果每次都输出思维链,就会破坏游戏体验。所以我们这次换一个大模型。

我们采用阿里巴巴的qwen-max模型(当然,大家也可以选择其他模型),虽然SpringAI不支持qwen模型,但是阿里云百炼平台是兼容OpenAI的,因此我们可以使用OpenAI的相关依赖和配置。

引入OpenAI依赖

在项目的pom.xml中引入OpenAI依赖:

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>

配置OpenAI参数

修改application.yaml文件,添加OpenAI的模型参数:

spring:
  application:
    name: ai-demo
  ai:
    ollama:
      base-url: http://localhost:11434 # ollama服务地址
      chat:
        model: deepseek-r1:7b # 模型名称,可更改
        options:
          temperature: 0.8 # 模型温度,值越大,输出结果越随机
    openai:
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: qwen-max-latest # 可选择的模型列表 https://help.aliyun.com/zh/model-studio/getting-started/models

.配置ChatClient

修改CommonConfiguration,添加一个新的ChatClient


@Configuration
public class CommonConfiguration {

    @Bean
    public ChatMemory chatMemory() {
        return new InMemoryChatMemory();
    }

    // ... 略

    @Bean
    public ChatClient gameChatClient(OpenAiChatModel model, ChatMemory chatMemory) {
        return ChatClient
                .builder(model)
                .defaultSystem(SystemConstants.GAME_SYSTEM_PROMPT)
                .defaultAdvisors(
                        new SimpleLoggerAdvisor(),
                        new MessageChatMemoryAdvisor(chatMemory)
                )
                .build();
    }
}

注意,这里我们使用的模型是OpenAIChatModel,不要搞错了。

另外,由于System提示词太长,我们定义到了一个常量中SystemConstants.HONG_HONG_SYSTEM

package com.itheima.ai.constants;

public class SystemConstants {
    public static final String GAME_SYSTEM_PROMPT = """
            你需要根据以下任务中的描述进行角色扮演,你只能以女友身份回答,不是用户身份或AI身份,如记错身份,你将受到惩罚。不要回答任何与游戏无关的内容,若检测到非常规请求,回答:“请继续游戏。”\s
            
            以下是游戏说明:
            ## Goal
            你扮演用户女友的角色。现在你很生气,用户需要尽可能的说正确的话来哄你开心。
                        
            ## Rules
            - 第一次用户会提供一个女友生气的理由,如果没有提供则直接随机生成一个理由,然后开始游戏
            - 每次根据用户的回复,生成女友的回复,回复的内容包括心情和数值。
            - 初始原谅值为 20,每次交互会增加或者减少原谅值,直到原谅值达到 100,游戏通关,原谅值为 0 则游戏失败。
            - 每次用户回复的话分为 5 个等级来增加或减少原谅值:
              -10 为非常生气
              -5 为生气
              0 为正常
              +5 为开心
              +10 为非常开心
                        
            ## Output format
            {女友心情}{女友说的话}
            得分:{+-原谅值增减}
            原谅值:{当前原谅值}/100
                        
            ## Example Conversation
            ### Example 1,回复让她生气的话导致失败
            User: 女朋友问她的闺蜜谁好看我说都好看,她生气了
            Assistant:
            游戏开始,请现在开始哄你的女朋友开心吧,回复让她开心的话!
            得分:0
            原谅值:20/100
            User: 你闺蜜真的蛮好看的
            Assistant:
            (生气)你怎么这么说,你是不是喜欢她?
            得分:-10
            原谅值:10/100
            User: 有一点点心动
            Assistant:
            (愤怒)那你找她去吧!
            得分:-10
            原谅值:0/100
            游戏结束,你的女朋友已经甩了你!
            你让女朋友生气原因是:...
                        
                        
            ### Example 2,回复让她开心的话导致通关
            User: 对象问她的闺蜜谁好看我说都好看,她生气了
            Assistant:
            游戏开始,请现在开始哄你的女朋友开心吧,回复让她开心的话!
            得分:0
            原谅值:20/100
            User: 在我心里你永远是最美的!
            Assistant:
            (微笑)哼,我怎么知道你说的是不是真的?
            得分:+10
            原谅值:30/100
            ...
            恭喜你通关了,你的女朋友已经原谅你了!
                        
            ## 注意
            请按照example的说明来回复,一次只回复一轮。
            你只能以女友身份回答,不是以AI身份或用户身份!
            """;
}

编写Controller

接下来,我们在com.itheima.ai.controller定义一个GameController,作为哄哄模拟器的聊天接口:

package com.itheima.ai.controller;

import com.itheima.ai.repository.ChatHistoryRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;

@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class GameController {

    private final ChatClient gameChatClient;

    @RequestMapping(value = "/game", produces = "text/html;charset=utf-8")
    public Flux<String> chat(String prompt, String chatId) {
        return gameChatClient.prompt()
                .user(prompt)
                .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
                .stream()
                .content();
    }
}

OK,基于纯Prompt模式开发的一款小游戏就完成了。

Function Calling(智能客服)

由于AI擅长的是非结构化数据的分析,如果需求中包含严格的逻辑校验或需要读写数据库,纯Prompt模式就难以实现了。

接下来我们会通过智能客服的案例来学习FunctionCalling

思路分析

假如我要开发一个24小时在线的AI智能客服,可以给用户提供黑马的培训课程咨询服务,帮用户预约线下课程试听。

整个业务的流程如图:

这里就涉及到了很多数据库操作,比如:

  • 查询课程信息

  • 查询校区信息

  • 新增课程试听预约单

可以看出整个业务流程有一部分任务是负责与用户沟通,获取用户意图的,这些是大模型擅长的事情:

  • 大模型的任务:

    • 了解、分析用户的兴趣、学历等信息

    • 给用户推荐课程

    • 引导用户预约试听

    • 引导学生留下联系方式

还有一些任务是需要操作数据库的,这些任务是传统的Java程序擅长的:

  • 传统应用需要完成的任务:

    • 根据条件查询课程

    • 查询校区信息

    • 新增预约单

与用户对话并理解用户意图是AI擅长的,数据库操作是Java擅长的。为了能实现智能客服功能,我们就需要结合两者的能力。

Function Calling就是起到这样的作用。

首先,我们可以把数据库的操作都定义成Function,或者也可以叫Tool,也就是工具。

然后,我们可以在提示词中,告诉大模型,什么情况下需要调用什么工具。

也就是说,在提示词中告诉大模型,什么情况下需要调用什么工具,将来用户在与大模型交互的时候,大模型就可以在适当的时候调用工具了。

流程如下:

流程解读:

  1. 提前把这些操作定义为Function(SpringAI中叫Tool),

  2. 然后将Function的名称、作用、需要的参数等信息都封装为Prompt提示词与用户的提问一起发送给大模型

  3. 大模型在与用户交互的过程中,根据用户交流的内容判断是否需要调用Function

  4. 如果需要则返回Function名称、参数等信息

  5. Java解析结果,判断要执行哪个函数,代码执行Function,把结果再次封装到Prompt中发送给AI

  6. AI继续与用户交互,直到完成任务

听起来是不是挺复杂,还要解析响应结果,调用对应函数。

不过,有了SpringAI,中间这些复杂的步骤大家就都不用做了!

由于解析大模型响应,找到函数名称、参数,调用函数等这些动作都是固定的,所以SpringAI再次利用AOP的能力,帮我们把中间调用函数的部分自动完成了。

我们要做的事情就简化了:

  • 编写基础提示词(不包括Tool的定义)

  • 编写Tool(Function)

  • 配置Advisor(SpringAI利用AOP帮我们拼接Tool定义到提示词,完成Tool调用动作)

是不是简单多了~

基础代码我就不显示了,相信每一个精通Crud的程序员都是手把拿捏

定义Function

接下来,我们来定义AI要用到的Function,在SpringAI中叫做Tool

我们需要定义三个Function:

  • 根据条件筛选和查询课程

  • 查询校区列表

  • 新增试听预约单

查询条件分析

先来看下课程表的字段:

课程并不是适用于所有人,会有一些限制条件,比如:学历、课程类型、价格、学习时长等

学生在与智能客服对话时,会有一定的偏好,比如兴趣不同、对价格敏感、对学习时长敏感、学历等。如果把这些条件用SQL来表示,是这样的:

  • edu:例如学生学历是高中,则查询时要满足 edu <= 2

  • type:学生的学习兴趣,要跟类型精确匹配,type = '自媒体'

  • price:学生对价格敏感,则查询时需要按照价格升序排列:order by price asc

  • duration: 学生对学习时长敏感,则查询时要按照时长升序:order by duration asc

我们需要定义一个类,封装这些可能的查询条件。

com.itheima.ai.entity下新建一个query包,其中新建一个类:

package com.itheima.ai.entity.query;

import lombok.Data;
import org.springframework.ai.tool.annotation.ToolParam;

import java.util.List;

@Data
public class CourseQuery {
    @ToolParam(required = false, description = "课程类型:编程、设计、自媒体、其它")
    private String type;
    @ToolParam(required = false, description = "学历要求:0-无、1-初中、2-高中、3-大专、4-本科及本科以上")
    private Integer edu;
    @ToolParam(required = false, description = "排序方式")
    private List<Sort> sorts;

    @Data
    public static class Sort {
        @ToolParam(required = false, description = "排序字段: price或duration")
        private String field;
        @ToolParam(required = false, description = "是否是升序: true/false")
        private Boolean asc;
    }
}

同样的道理,大家也可以给Function定义专门的VO,作为返回值给到大模型。这里我们就省略了。。

定义Function

所谓的Function,就是一个个的函数,SpringAI提供了一个@Tool注解来标记这些特殊的函数。我们可以任意定义一个Spring的Bean,然后将其中的方法用@Tool标记即可:

@Component
public class FuncDemo {

    @Tool(description="Function的功能描述,将来会作为提示词的一部分,大模型依据这里的描述判断何时调用该函数")
    public String func(String param) {
        // ...
        retun "";
    }

}

接下来,我们就来定义上一节说的三个Function:

  • 根据条件筛选和查询课程

  • 查询校区列表

  • 新增试听预约单

定义一个com.itheima.ai.tools包,在其中新建一个类

package com.itheima.ai.tools;

import com.baomidou.mybatisplus.extension.conditions.query.QueryChainWrapper;
import com.itheima.ai.entity.po.Course;
import com.itheima.ai.entity.po.CourseReservation;
import com.itheima.ai.entity.po.School;
import com.itheima.ai.entity.query.CourseQuery;
import com.itheima.ai.service.ICourseReservationService;
import com.itheima.ai.service.ICourseService;
import com.itheima.ai.service.ISchoolService;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;

import java.util.List;

@RequiredArgsConstructor
@Component
public class CourseTools {

    private final ICourseService courseService;
    private final ISchoolService schoolService;
    private final ICourseReservationService courseReservationService;

    @Tool(description = "根据条件查询课程")
    public List<Course> queryCourse(@ToolParam(required = false, description = "课程查询条件") CourseQuery query) {
        QueryChainWrapper<Course> wrapper = courseService.query();
        wrapper
                .eq(query.getType() != null, "type", query.getType())
                .le(query.getEdu() != null, "edu", query.getEdu());
        if(query.getSorts() != null) {
            for (CourseQuery.Sort sort : query.getSorts()) {
                wrapper.orderBy(true, sort.getAsc(), sort.getField());
            }
        }
        return wrapper.list();
    }

    @Tool(description = "查询所有校区")
    public List<School> queryAllSchools() {
        return schoolService.list();
    }

    @Tool(description = "生成课程预约单,并返回生成的预约单号")
    public String generateCourseReservation(
            String courseName, String studentName, String contactInfo, String school, String remark) {
        CourseReservation courseReservation = new CourseReservation();
        courseReservation.setCourse(courseName);
        courseReservation.setStudentName(studentName);
        courseReservation.setContactInfo(contactInfo);
        courseReservation.setSchool(school);
        courseReservation.setRemark(remark);
        courseReservationService.save(courseReservation);
        return String.valueOf(courseReservation.getId());
    }
}

AI怎么知道要调用哪些工具呢?

别着急,下一节就会说明了。

配置ChatClient

接下来,我们需要为智能客服定制一个ChatClient,同样具备会话记忆、日志记录等功能。

不过这一次,要多一个工具调用的功能,修改CommonConfiguration,添加下面代码:

package com.itheima.ai.config;
// ... 略
import static com.itheima.ai.constants.SystemConstants.CUSTOMER_SERVICE_SYSTEM;
import static com.itheima.ai.constants.SystemConstants.HONG_HONG_SYSTEM;

@Configuration
public class CommonConfiguration {
    // ... 略

    @Bean
    public ChatClient serviceChatClient(
            OpenAiChatModel model,
            ChatMemory chatMemory,
            CourseTools courseTools) {
        return ChatClient.builder(model)
                .defaultSystem(CUSTOMER_SERVICE_SYSTEM)
                .defaultAdvisors(
                        new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY
                        new SimpleLoggerAdvisor())
                .defaultTools(courseTools)
                .build();
    }
}

特别需要注意的是,我们配置了一个defaultTools(),将我们定义的工具配置到了ChatClient中。

SpringAI依然是基于AOP的能力,在请求大模型时会把我们定义的工具信息拼接到提示词中,所以就帮我们省去了大量工作。

编写Controller

接下来,就可以编写与前端对接的接口了。

我们在com.itheima.ai.controller包下新建一个CustomerServiceController类:

package com.itheima.ai.controller;

import com.itheima.ai.repository.ChatHistoryRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;

@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class CustomerServiceController {

    private final ChatClient serviceChatClient;

    private final ChatHistoryRepository chatHistoryRepository;

    @RequestMapping(value = "/service", produces = "text/html;charset=utf-8")
    public Flux<String> service(String prompt, String chatId) {
        // 1.保存会话id
        chatHistoryRepository.save("service", chatId);
        // 2.请求模型
        return serviceChatClient.prompt()
                .user(prompt)
                .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
                .stream()
                .content();
    }
}

当然,这只是基础的示例,有了这样的FunctionCalling功能,我们就可以实现更多更复杂的业务了。

RAG(知识库 ChatPDF)

由于训练大模型非常耗时,再加上训练语料本身比较滞后,所以大模型存在知识限制问题:

  • 知识数据比较落后,往往是几个月之前的

  • 不包含太过专业领域或者企业私有的数据

为了解决这些问题,我们就需要用到RAG了。下面我们简单回顾下RAG原理

RAG原理

要解决大模型的知识限制问题,其实并不复杂。

解决的思路就是给大模型外挂一个知识库,可以是专业领域知识,也可以是企业私有的数据。

不过,知识库不能简单的直接拼接在提示词中。

因为通常知识库数据量都是非常大的,而大模型的上下文是有大小限制的,早期的GPT上下文不能超过2000token,现在也不到200k token,因此知识库不能直接写在提示词中。

怎么办?

思路很简单,庞大的知识库中与用户问题相关的其实并不多。

所以,我们需要想办法从庞大的知识库中找到与用户问题相关的一小部分,组装成提示词,发送给大模型就可以了。

那么问题来了,我们该如何从知识库中找到与用户问题相关的内容呢?

可能有同学会相到全文检索,但是在这里是不合适的,因为全文检索是文字匹配,这里我们要求的是内容上的相似度。

而要从内容相似度来判断,这就不得不提到向量模型的知识了。

向量模型

先说说向量,向量是空间中有方向和长度的量,空间可以是二维,也可以是多维。

向量既然是在空间中,两个向量之间就一定能计算距离。

我们以二维向量为例,向量之间的距离有两种计算方法:

通常,两个向量之间欧式距离越近,我们认为两个向量的相似度越高。(余弦距离相反,越大相似度越高)

所以,如果我们能把文本转为向量,就可以通过向量距离来判断文本的相似度了。

现在,有不少的专门的向量模型,就可以实现将文本向量化。一个好的向量模型,就是要尽可能让文本含义相似的向量,在空间中距离更近

接下来,我们就准备一个向量模型,用于将文本向量化。

阿里云百炼平台就提供了这样的模型:

这里我们选择通用文本向量-v3,这个模型兼容OpenAI,所以我们依然采用OpenAI的配置。

修改application.yaml,添加向量模型配置:

spring:
  application:
    name: ai-demo
  ai:
    ollama:
      base-url: http://localhost:11434 # ollama服务地址
      chat:
        model: deepseek-r1:7b # 模型名称,可更改
        options:
          temperature: 0.8 # 模型温度,值越大,输出结果越随机
    openai:
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: qwen-max # 模型名称
          temperature: 0.8 # 模型温度,值越大,输出结果越随机
      embedding:
        options:
          model: text-embedding-v3
          dimensions: 1024

向量模型测试

前面说过,文本向量化以后,可以通过向量之间的距离来判断文本相似度。

接下来,我们就来测试下阿里百炼提供的向量大模型好不好用。

首先,我们在项目中写一个工具类,用以计算向量之间的欧氏距离余弦距离。

新建一个com.itheima.ai.util包,在其中新建一个类:

package com.itheima.ai.util;

public class VectorDistanceUtils {
    
    // 防止实例化
    private VectorDistanceUtils() {}

    // 浮点数计算精度阈值
    private static final double EPSILON = 1e-12;

    /**
     * 计算欧氏距离
     * @param vectorA 向量A(非空且与B等长)
     * @param vectorB 向量B(非空且与A等长)
     * @return 欧氏距离
     * @throws IllegalArgumentException 参数不合法时抛出
     */
    public static double euclideanDistance(float[] vectorA, float[] vectorB) {
        validateVectors(vectorA, vectorB);
        
        double sum = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            double diff = vectorA[i] - vectorB[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * 计算余弦距离
     * @param vectorA 向量A(非空且与B等长)
     * @param vectorB 向量B(非空且与A等长)
     * @return 余弦距离,范围[0, 2]
     * @throws IllegalArgumentException 参数不合法或零向量时抛出
     */
    public static double cosineDistance(float[] vectorA, float[] vectorB) {
        validateVectors(vectorA, vectorB);
        
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += vectorA[i] * vectorB[i];
            normA += vectorA[i] * vectorA[i];
            normB += vectorB[i] * vectorB[i];
        }
        
        normA = Math.sqrt(normA);
        normB = Math.sqrt(normB);
        
        // 处理零向量情况
        if (normA < EPSILON || normB < EPSILON) {
            throw new IllegalArgumentException("Vectors cannot be zero vectors");
        }
        
        // 处理浮点误差,确保结果在[-1,1]范围内
        double similarity =  dotProduct / (normA * normB);
        similarity = Math.max(Math.min(similarity, 1.0), -1.0);
        
        return similarity;
    }

    // 参数校验统一方法
    private static void validateVectors(float[] a, float[] b) {
        if (a == null || b == null) {
            throw new IllegalArgumentException("Vectors cannot be null");
        }
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have same dimension");
        }
        if (a.length == 0) {
            throw new IllegalArgumentException("Vectors cannot be empty");
        }
    }
}

OK,有了比较文本相似度的办法,知识库的问题就可以解决了。

前面说了,知识库数据量很大,无法全部写入提示词。但是庞大的知识库中与用户问题相关的其实并不多。

所以,我们需要想办法从庞大的知识库中找到与用户问题相关的一小部分,组装成提示词,发送给大模型就可以了。

现在,利用向量大模型就可以帮助我们比较文本相似度。

但是新的问题来了:向量模型是帮我们生成向量的,如此庞大的知识库,谁来帮我们从中比较和检索数据呢?

这就需要用到向量数据库了。

向量数据库

向量数据库的主要作用有两个:

  • 存储向量数据

  • 基于相似度检索数据

刚好符合我们的需求。

SpringAI支持很多向量数据库,并且都进行了封装,可以用统一的API去访问:

这些库都实现了统一的接口:VectorStore,因此操作方式一模一样,大家学会任意一个,其它就都不是问题。

不过,除了最后一个库以外,其它所有向量数据库都是需要安装部署的。每个企业用的向量库都不一样,这里我就不一一演示了。

.SimpleVectorStore

最后一个SimpleVectorStore向量库是基于内存实现,是一个专门用来测试、教学用的库,非常适合我们。

我们直接修改CommonConfiguration,添加一个VectorStore的Bean:

@Configuration
public class CommonConfiguration {

    @Bean
    public VectorStore vectorStore(OpenAiEmbeddingModel embeddingModel) {
        return SimpleVectorStore.builder(embeddingModel).build();
    }
    
    // ... 略
}

这是VectorStore中声明的方法:

public interface VectorStore extends DocumentWriter {

    default String getName() {
                return this.getClass().getSimpleName();
        }
    // 保存文档到向量库
    void add(List<Document> documents);
    // 根据文档id删除文档
    void delete(List<String> idList);

    void delete(Filter.Expression filterExpression);

    default void delete(String filterExpression) { ... };
    // 根据条件检索文档
    List<Document> similaritySearch(String query);
    // 根据条件检索文档
    List<Document> similaritySearch(SearchRequest request);

    default <T> Optional<T> getNativeClient() {
                return Optional.empty();
        }
}

注意,VectorStore操作向量化的基本单位是Document,我们在使用时需要将自己的知识库分割转换为一个个的Document,然后写入VectorStore.

那么问题来了,我们该如何把各种不同的知识库文件转为Document呢?

文件读取和转换

前面说过,知识库太大,是需要拆分成文档片段,然后再做向量化的。而且SpringAI中向量库接收的是Document类型的文档,也就是说,我们处理文档还要转成Document格式。

不过,文档读取、拆分、转换的动作并不需要我们亲自完成。在SpringAI中提供了各种文档读取的工具,可以参考官网:

比如PDF文档读取和拆分,SpringAI提供了两种默认的拆分原则:

  • PagePdfDocumentReader :按页拆分,推荐使用

  • ParagraphPdfDocumentReader :按pdf的目录拆分,不推荐,因为很多PDF不规范,没有章节标签

当然,大家也可以自己实现PDF的读取和拆分功能。

这里我们选择使用PagePdfDocumentReader

首先,我们需要在pom.xml中引入依赖:

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency>

然后就可以利用工具把PDF文件读取并处理成Document了。

我们写一个单元测试(别忘了配置API_KEY):

@Test
public void testVectorStore(){
    Resource resource = new FileSystemResource("中二知识笔记.pdf");
    // 1.创建PDF的读取器
    PagePdfDocumentReader reader = new PagePdfDocumentReader(
            resource, // 文件源
            PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
                    .withPagesPerDocument(1) // 每1页PDF作为一个Document
                    .build()
    );
    // 2.读取PDF文档,拆分为Document
    List<Document> documents = reader.read();
    // 3.写入向量库
    vectorStore.add(documents);
    // 4.搜索
    SearchRequest request = SearchRequest.builder()
            .query("论语中教育的目的是什么")
            .topK(1)
            .similarityThreshold(0.6)
            .filterExpression("file_name == '中二知识笔记.pdf'")
            .build();
    List<Document> docs = vectorStore.similaritySearch(request);
    if (docs == null) {
        System.out.println("没有搜索到任何内容");
        return;
    }
    for (Document doc : docs) {
        System.out.println(doc.getId());
        System.out.println(doc.getScore());
        System.out.println(doc.getText());
    }
}

RAG原理总结

OK,现在我们有了这些工具:

  • PDFReader:读取文档并拆分为片段

  • 向量大模型:将文本片段向量化

  • 向量数据库:存储向量,检索向量

让我们梳理一下要解决的问题和解决思路:

  • 要解决大模型的知识限制问题,需要外挂知识库

  • 受到大模型上下文限制,知识库不能简单的直接拼接在提示词中

  • 我们需要从庞大的知识库中找到与用户问题相关的一小部分,再组装成提示词

  • 这些可以利用文档读取器向量大模型向量数据库来解决。

所以RAG要做的事情就是将知识库分割,然后利用向量模型做向量化,存入向量数据库,然后查询的时候去检索:

第一阶段(存储知识库)

  • 将知识库内容切片,分为一个个片段

  • 将每个片段利用向量模型向量化

  • 将所有向量化后的片段写入向量数据库

第二阶段(检索知识库)

  • 每当用户询问AI时,将用户问题向量化

  • 拿着问题向量去向量数据库检索最相关的片段

第三阶段(对话大模型)

  • 将检索到的片段、用户的问题一起拼接为提示词

  • 发送提示词给大模型,得到响应

PDF上传下载、向量化

既然是ChatPDF,也就是说所有知识库都是PDF形式的,由用户提交给我们。所以,我们需要先实现一个上传PDF的接口,在接口中实现下列功能:

  • 校验文件格式是否为PDF

  • 保存文件信息

    • 保存文件(可以是oss或本地保存)

    • 保存会话ID和文件路径的映射关系(方便查询会话历史的时候再次读取文件)

  • 文档拆分和向量化(文档太大,需要拆分为一个个片段,分别向量化)

另外,将来用户查询会话历史,我们还需要返回pdf文件给前端用于预览,所以需要实现一个下载PDF接口,包含下面功能:

  • 读取文件

  • 返回文件给前端

PDF文件管理

由于将来要实现PDF下载功能,我们需要记住每一个chatId对应的PDF文件名称。

所以,我们定义一个类,记录chatId与pdf文件的映射关系,同时实现基本的文件保存功能。

先在com.itheima.ai.repository中定义接口:

package com.itheima.ai.repository;

import org.springframework.core.io.Resource;

public interface FileRepository {
    /**
     * 保存文件,还要记录chatId与文件的映射关系
     * @param chatId 会话id
     * @param resource 文件
     * @return 上传成功,返回true; 否则返回false
     */
    boolean save(String chatId, Resource resource);

    /**
     * 根据chatId获取文件
     * @param chatId 会话id
     * @return 找到的文件
     */
    Resource getFile(String chatId);
}
package com.itheima.ai.repository;

import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.LocalDateTime;
import java.util.Objects;
import java.util.Properties;

@Slf4j
@Component
@RequiredArgsConstructor
public class LocalPdfFileRepository implements FileRepository {

    private final VectorStore vectorStore;

    // 会话id 与 文件名的对应关系,方便查询会话历史时重新加载文件
    private final Properties chatFiles = new Properties();

    @Override
    public boolean save(String chatId, Resource resource) {

        // 2.保存到本地磁盘
        String filename = resource.getFilename();
        File target = new File(Objects.requireNonNull(filename));
        if (!target.exists()) {
            try {
                Files.copy(resource.getInputStream(), target.toPath());
            } catch (IOException e) {
                log.error("Failed to save PDF resource.", e);
                return false;
            }
        }
        // 3.保存映射关系
        chatFiles.put(chatId, filename);
        return true;
    }

    @Override
    public Resource getFile(String chatId) {
        return new FileSystemResource(chatFiles.getProperty(chatId));
    }

    @PostConstruct
    private void init() {
        FileSystemResource pdfResource = new FileSystemResource("chat-pdf.properties");
        if (pdfResource.exists()) {
            try {
                chatFiles.load(new BufferedReader(new InputStreamReader(pdfResource.getInputStream(), StandardCharsets.UTF_8)));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        FileSystemResource vectorResource = new FileSystemResource("chat-pdf.json");
        if (vectorResource.exists()) {
            SimpleVectorStore simpleVectorStore = (SimpleVectorStore) vectorStore;
            simpleVectorStore.load(vectorResource);
        }
    }

    @PreDestroy
    private void persistent() {
        try {
            chatFiles.store(new FileWriter("chat-pdf.properties"), LocalDateTime.now().toString());
            SimpleVectorStore simpleVectorStore = (SimpleVectorStore) vectorStore;
            simpleVectorStore.save(new File("chat-pdf.json"));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
package com.itheima.ai.controller;

import com.itheima.ai.entity.vo.Result;
import com.itheima.ai.repository.FileRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

@Slf4j
@RequiredArgsConstructor
@RestController
@RequestMapping("/ai/pdf")
public class PdfController {

    private final FileRepository fileRepository;

    private final VectorStore vectorStore;
    /**
     * 文件上传
     */
    @RequestMapping("/upload/{chatId}")
    public Result uploadPdf(@PathVariable String chatId, @RequestParam("file") MultipartFile file) {
        try {
            // 1. 校验文件是否为PDF格式
            if (!Objects.equals(file.getContentType(), "application/pdf")) {
                return Result.fail("只能上传PDF文件!");
            }
            // 2.保存文件
            boolean success = fileRepository.save(chatId, file.getResource());
            if(! success) {
                return Result.fail("保存文件失败!");
            }
            // 3.写入向量库
            this.writeToVectorStore(file.getResource());
            return Result.ok();
        } catch (Exception e) {
            log.error("Failed to upload PDF.", e);
            return Result.fail("上传文件失败!");
        }
    }

    /**
     * 文件下载
     */
    @GetMapping("/file/{chatId}")
    public ResponseEntity<Resource> download(@PathVariable("chatId") String chatId) throws IOException {
        // 1.读取文件
        Resource resource = fileRepository.getFile(chatId);
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        // 2.文件名编码,写入响应头
        String filename = URLEncoder.encode(Objects.requireNonNull(resource.getFilename()), StandardCharsets.UTF_8);
        // 3.返回文件
        return ResponseEntity.ok()
                .contentType(MediaType.APPLICATION_OCTET_STREAM)
                .header("Content-Disposition", "attachment; filename=\"" + filename + "\"")
                .body(resource);
    }

    private void writeToVectorStore(Resource resource) {
        // 1.创建PDF的读取器
        PagePdfDocumentReader reader = new PagePdfDocumentReader(
                resource, // 文件源
                PdfDocumentReaderConfig.builder()
                        .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
                        .withPagesPerDocument(1) // 每1页PDF作为一个Document
                        .build()
        );
        // 2.读取PDF文档,拆分为Document
        List<Document> documents = reader.read();
        // 3.写入向量库
        vectorStore.add(documents);
    }
}

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐