The Complete Guide to Conversation Persistence in Spring AI 1.0.4: From the Basics to Custom Implementations
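Preface
As AI applications spread, conversational AI has become a core feature of many products. Spring AI, the AI framework of the Spring ecosystem, supports conversation persistence through its chat memory abstractions, which were substantially reworked for the 1.0 release line. This article takes an in-depth look at conversation persistence in Spring AI 1.0.4. It covers basic configuration, persisting conversations that involve tool calls, and building custom persistence solutions.
1. Overview of Conversation Persistence in Spring AI
1.1 What Is Conversation Persistence
Conversation persistence means storing the history of interactions between a user and the AI in durable storage, so that:
- earlier conversation context can be restored
- multi-turn conversations stay continuous
- conversation history can be managed and analyzed
- conversations can be synchronized across devices
1.2 Persistence Architecture in Spring AI 1.0.4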
┌─────────────┐
│ User │
└──────┬──────┘
│
▼
┌─────────────┐ ┌─────────────┐
│ Chat │────▶│ ChatMemory │
│ Interface │ │ Service │
└──────┬──────┘ └──────┬──────┘
│ │
│ ┌───▼───┐
│ │ Store │
│ └───┬───┘
▼ │
┌─────────────┐ ┌─────▼─────┐
│ AI │ │ Database │
│ Service │ │ / Redis │
└─────────────┘ └───────────┘
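The diagram's ChatMemory box can be made concrete with a minimal, JDK-only sketch of the contract it represents. Messages are appended per conversation ID in order, and reads return a window of the most recent N. `ConversationStore` and `StoredMessage` are hypothetical stand-ins for illustration, not Spring AI types. A real implementation would sit in front of the database or Redis store shown on the right.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical message type; Spring AI's real type is org.springframework.ai.chat.messages.Message
record StoredMessage(String role, String text) {}

// Hypothetical in-memory store illustrating add / get-last-N / clear per conversation.
// Not thread-safe: it only illustrates the contract.
final class ConversationStore {
    private final Map<String, List<StoredMessage>> conversations = new HashMap<>();

    // Append messages to the end of the conversation, preserving order
    void add(String conversationId, List<StoredMessage> messages) {
        conversations.computeIfAbsent(conversationId, id -> new ArrayList<>()).addAll(messages);
    }

    // Return at most lastN messages, oldest first, so the window can be replayed as prompt history
    List<StoredMessage> get(String conversationId, int lastN) {
        List<StoredMessage> all = conversations.getOrDefault(conversationId, List.of());
        int from = Math.max(0, all.size() - lastN);
        return List.copyOf(all.subList(from, all.size()));
    }

    // Forget the whole conversation
    void clear(String conversationId) {
        conversations.remove(conversationId);
    }
}
```

A database or Redis backend keeps these same three operations and only changes where the list lives.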
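2. Basic Conversation Persistence Setup
2.1 Dependencies
Add the required dependencies to pom.xml:
<dependencies>
    <!-- Spring AI OpenAI starter; in the 1.0 GA line it is published as spring-ai-starter-model-openai -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
        <version>1.0.4</version>
    </dependency>
    <!-- Web starter for the REST API in section 6 -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- JPA persistence -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>
    <!-- H2 database (development) -->
    <dependency>
        <groupId>com.h2database</groupId>
        <artifactId>h2</artifactId>
        <scope>runtime</scope>
    </dependency>
    <!-- MySQL driver (production); Spring Boot 3 manages it as com.mysql:mysql-connector-j -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>
</dependencies>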
2.2 application.yml Configuration
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
  datasource:
    url: jdbc:h2:mem:testdb
    driver-class-name: org.h2.Driver
    username: sa
    password:
  jpa:
    hibernate:
      ddl-auto: update
    show-sql: true
    properties:
      hibernate:
        format_sql: true
  h2:
    console:
      enabled: true
2.3 Basic Entity Design
import jakarta.persistence.*;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
@Entity
@Table(name = "chat_conversation")
public class ChatConversation {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@Column(nullable = false, unique = true)
private String conversationId;
@Column(nullable = false)
private String userId;
@Column(length = 500)
private String title;
@Column(nullable = false)
private LocalDateTime createdAt;
@Column(nullable = false)
private LocalDateTime updatedAt;
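// Load messages in insertion order so that "last N messages" windows are well defined
@OneToMany(mappedBy = "conversation", cascade = CascadeType.ALL, orphanRemoval = true)
@OrderBy("timestamp ASC, id ASC")
private List<ChatMessage> messages = new ArrayList<>();
// Constructor, getters, and setters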
public ChatConversation() {
this.createdAt = LocalDateTime.now();
this.updatedAt = LocalDateTime.now();
}
public void addMessage(ChatMessage message) {
messages.add(message);
message.setConversation(this);
this.updatedAt = LocalDateTime.now();
}
// Getters and setters omitted...
}
@Entity
@Table(name = "chat_message")
public class ChatMessage {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
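// @JsonIgnore breaks the ChatMessage -> ChatConversation -> messages cycle when messages are returned as JSON
@com.fasterxml.jackson.annotation.JsonIgnore
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "conversation_id", nullable = false)
private ChatConversation conversation;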
@Enumerated(EnumType.STRING)
@Column(nullable = false)
private MessageType type;
@Column(nullable = false, columnDefinition = "TEXT")
private String content;
@Column(columnDefinition = "TEXT")
private String toolCalls;
@Column(columnDefinition = "TEXT")
private String toolResponses;
@Column(nullable = false)
private LocalDateTime timestamp;
private Integer tokens;
// Constructor, getters, and setters
public ChatMessage() {
this.timestamp = LocalDateTime.now();
}
public enum MessageType {
USER,
ASSISTANT,
SYSTEM,
TOOL
}
// Getters and setters omitted...
}
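2.4 Repository Layer
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
@Repository
public interface ChatConversationRepository extends JpaRepository<ChatConversation, Long> {
    Optional<ChatConversation> findByConversationId(String conversationId);
    List<ChatConversation> findByUserIdOrderByUpdatedAtDesc(String userId);
    // Paged variant of the query above, e.g. PageRequest.of(0, 20) for the 20 most recent conversations
    @Query("SELECT c FROM ChatConversation c WHERE c.userId = :userId " +
           "ORDER BY c.updatedAt DESC")
    List<ChatConversation> findRecentConversations(@Param("userId") String userId,
                                                   Pageable pageable);
    // Used by ChatService.deleteConversation (derived delete query; must run inside a transaction)
    void deleteByConversationId(String conversationId);
    // Used by the cleanup job in section 7.3
    List<ChatConversation> findByUpdatedAtBefore(LocalDateTime cutoff);
}
// Repository for individual messages, used by PersistentChatMemory and ToolExecutionListener
@Repository
public interface ChatMessageRepository extends JpaRepository<ChatMessage, Long> {
}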
3. Implementing the ChatMemory Interface
3.1 A Custom ChatMemory Implementation
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
@Service
public class PersistentChatMemory implements ChatMemory {
    private final ChatConversationRepository conversationRepository;
    private final ChatMessageRepository messageRepository;
    // Used to store tool calls and tool responses as JSON
    private final ObjectMapper objectMapper;
    // Maximum number of history messages returned by get(conversationId)
    private static final int MAX_MESSAGES = 50;
    public PersistentChatMemory(ChatConversationRepository conversationRepository,
                                ChatMessageRepository messageRepository,
                                ObjectMapper objectMapper) {
        this.conversationRepository = conversationRepository;
        this.messageRepository = messageRepository;
        this.objectMapper = objectMapper;
    }
    @Override
    @Transactional
    public void add(String conversationId, List<Message> messages) {
        ChatConversation conversation = conversationRepository
                .findByConversationId(conversationId)
                .orElseGet(() -> createNewConversation(conversationId));
        for (Message message : messages) {
            conversation.addMessage(convertToChatMessage(message, conversation));
        }
        conversationRepository.save(conversation);
    }
    // ChatMemory in Spring AI 1.0.x only declares get(conversationId), so this windowed
    // variant is a plain helper method without @Override
    @Transactional(readOnly = true)
    public List<Message> get(String conversationId, int lastN) {
        List<ChatMessage> stored = conversationRepository
                .findByConversationId(conversationId)
                .map(ChatConversation::getMessages)
                .orElse(List.of());
        // Keep the newest lastN messages, oldest first (assumes messages load in insertion order)
        return stored.stream()
                .skip(Math.max(0, stored.size() - lastN))
                .map(this::convertToSpringAIMessage)
                .toList();
    }
    @Override
    @Transactional(readOnly = true)
    public List<Message> get(String conversationId) {
        return get(conversationId, MAX_MESSAGES);
    }
    @Override
    @Transactional
    public void clear(String conversationId) {
        conversationRepository.findByConversationId(conversationId).ifPresent(conversation -> {
            // orphanRemoval = true deletes the removed messages
            conversation.getMessages().clear();
            conversationRepository.save(conversation);
        });
    }
    private ChatConversation createNewConversation(String conversationId) {
        ChatConversation conversation = new ChatConversation();
        conversation.setConversationId(conversationId != null ? conversationId : UUID.randomUUID().toString());
        conversation.setUserId("default_user"); // In practice, take this from the security context
        conversation.setTitle("New conversation");
        return conversation;
    }
    private ChatMessage convertToChatMessage(Message message, ChatConversation conversation) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setConversation(conversation);
        // Message text is read with getText(); it can be null or empty when an assistant
        // message only carries tool calls, while the content column is NOT NULL
        chatMessage.setContent(Objects.requireNonNullElse(message.getText(), ""));
        chatMessage.setType(mapMessageType(message.getMessageType()));
        // Tool calls are exposed by AssistantMessage.getToolCalls(), not through the metadata map
        if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) {
            chatMessage.setToolCalls(toJson(assistantMessage.getToolCalls()));
        }
        // Tool results are exposed by ToolResponseMessage.getResponses()
        if (message instanceof ToolResponseMessage toolResponseMessage) {
            chatMessage.setToolResponses(toJson(toolResponseMessage.getResponses()));
        }
        return chatMessage;
    }
    private ChatMessage.MessageType mapMessageType(MessageType type) {
        return switch (type) {
            case USER -> ChatMessage.MessageType.USER;
            case ASSISTANT -> ChatMessage.MessageType.ASSISTANT;
            case SYSTEM -> ChatMessage.MessageType.SYSTEM;
            case TOOL -> ChatMessage.MessageType.TOOL;
            default -> throw new IllegalArgumentException("Unsupported message type: " + type);
        };
    }
    private Message convertToSpringAIMessage(ChatMessage chatMessage) {
        return switch (chatMessage.getType()) {
            case USER -> new UserMessage(chatMessage.getContent());
            // Restore the tool calls that were stored as JSON (an empty list if there were none)
            case ASSISTANT -> new AssistantMessage(chatMessage.getContent(), Map.of(),
                    fromJson(Objects.requireNonNullElse(chatMessage.getToolCalls(), "[]"),
                            new TypeReference<List<AssistantMessage.ToolCall>>() {}));
            case SYSTEM -> new SystemMessage(chatMessage.getContent());
            // Restore the tool responses (id, name, responseData) that were stored as JSON
            case TOOL -> new ToolResponseMessage(
                    fromJson(Objects.requireNonNullElse(chatMessage.getToolResponses(), "[]"),
                            new TypeReference<List<ToolResponseMessage.ToolResponse>>() {}));
        };
    }
    private String toJson(Object value) {
        try {
            return objectMapper.writeValueAsString(value);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to serialize tool data", e);
        }
    }
    private <T> T fromJson(String json, TypeReference<T> type) {
        try {
            return objectMapper.readValue(json, type);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to deserialize tool data", e);
        }
    }
}
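3.2 Configuring ChatClient to Use the Custom ChatMemory
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class ChatConfig {
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder,
                                 // The concrete type is injected because section 5 adds further ChatMemory
                                 // beans, which would make an unqualified ChatMemory parameter ambiguous
                                 PersistentChatMemory chatMemory) {
        return builder
                .defaultSystem("You are a helpful AI assistant")
                // Spring AI 1.0.x creates the advisor through builder(); the conversation ID is supplied
                // per request through the ChatMemory.CONVERSATION_ID advisor parameter
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
                .build();
    }
}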
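Conceptually, a memory advisor wraps each model call. Before the call, it loads the conversation's history and prepends it to the new user message. After the call, it stores both the user message and the reply. The JDK-only sketch below models that cycle, with a plain `Function` standing in for the model. `MemoryAdvisorSketch` and its members are hypothetical names. Spring AI's actual `MessageChatMemoryAdvisor` differs in detail; for example, it reads the conversation ID from the `ChatMemory.CONVERSATION_ID` advisor parameter.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for the memory-advisor cycle; messages are plain "role: text" strings
final class MemoryAdvisorSketch {
    private final Map<String, List<String>> memory = new HashMap<>();
    private final Function<List<String>, String> model; // receives the full prompt, returns the reply
    private final int window;                             // how many history messages to replay

    MemoryAdvisorSketch(Function<List<String>, String> model, int window) {
        this.model = model;
        this.window = window;
    }

    String call(String conversationId, String userText) {
        List<String> history = memory.getOrDefault(conversationId, List.of());
        // Before the call: prompt = last `window` history messages + the new user message
        List<String> prompt = new ArrayList<>(
                history.subList(Math.max(0, history.size() - window), history.size()));
        prompt.add("user: " + userText);
        String reply = model.apply(prompt);
        // After the call: persist the user message and the reply exactly once
        List<String> stored = memory.computeIfAbsent(conversationId, id -> new ArrayList<>());
        stored.add("user: " + userText);
        stored.add("assistant: " + reply);
        return reply;
    }

    List<String> history(String conversationId) {
        return List.copyOf(memory.getOrDefault(conversationId, List.of()));
    }
}
```

Because the advisor already writes both messages, calling the memory's add method manually for the same turn would store each message twice.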
4. Persisting Conversations That Involve Tool Calls
4.1 Tool Execution Entity
@Entity
@Table(name = "tool_execution")
public class ToolExecution {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "chat_message_id", nullable = false)
private ChatMessage chatMessage;
@Column(nullable = false)
private String toolName;
@Column(columnDefinition = "TEXT")
private String arguments;
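@Column(columnDefinition = "TEXT") // Tool results can be long; the default VARCHAR(255) may be too short
private String result;
@Column
private Long executionTime; // Execution time in milliseconds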
@Enumerated(EnumType.STRING)
@Column(nullable = false)
private ExecutionStatus status;
@Column(nullable = false)
private LocalDateTime executedAt;
@Column(length = 1000)
private String errorMessage;
public enum ExecutionStatus {
SUCCESS,
FAILED,
TIMEOUT
}
public ToolExecution() {
this.executedAt = LocalDateTime.now();
}
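// Getters and setters omitted...
}
// Repository used by ToolExecutionListener (section 4.2)
@Repository
public interface ToolExecutionRepository extends JpaRepository<ToolExecution, Long> {
}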
4.2 Tool Call Listener
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
@Component
public class ToolExecutionListener {
    private final ToolExecutionRepository toolExecutionRepository;
    private final ChatMessageRepository chatMessageRepository;
    private final ObjectMapper objectMapper;
    public ToolExecutionListener(ToolExecutionRepository toolExecutionRepository,
                                 ChatMessageRepository chatMessageRepository,
                                 ObjectMapper objectMapper) {
        this.toolExecutionRepository = toolExecutionRepository;
        this.chatMessageRepository = chatMessageRepository;
        this.objectMapper = objectMapper;
    }
    // The arguments arrive as the raw JSON string produced by the model
    // (AssistantMessage.ToolCall#arguments() is a String), so they are stored as-is
    @Transactional
    public void recordToolExecution(Long chatMessageId,
                                    String toolName,
                                    String argumentsJson,
                                    Object result,
                                    long executionTime,
                                    ToolExecution.ExecutionStatus status,
                                    String errorMessage) {
        ChatMessage chatMessage = chatMessageRepository.findById(chatMessageId)
                .orElseThrow(() -> new IllegalArgumentException("Chat message not found: " + chatMessageId));
        ToolExecution execution = new ToolExecution();
        execution.setChatMessage(chatMessage);
        execution.setToolName(toolName);
        execution.setArguments(argumentsJson);
        if (result != null) {
            execution.setResult(toJsonOrString(result));
        }
        execution.setExecutionTime(executionTime);
        execution.setStatus(status);
        execution.setErrorMessage(errorMessage);
        toolExecutionRepository.save(execution);
    }
    @Transactional
    public void recordAssistantMessageWithToolCalls(ChatMessage chatMessage,
                                                    List<AssistantMessage.ToolCall> toolCalls) {
        chatMessage.setToolCalls(toJsonOrString(toolCalls));
        chatMessageRepository.save(chatMessage);
    }
    // Fall back to toString() if the value cannot be serialized as JSON
    private String toJsonOrString(Object value) {
        try {
            return objectMapper.writeValueAsString(value);
        } catch (JsonProcessingException e) {
            return value.toString();
        }
    }
}
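The listener above receives an `executionTime` from its caller. One way to obtain a real per-tool duration and status is to wrap each tool's function in a decorator that times the call and catches failures. The sketch below is JDK-only and hypothetical. `TimedTool`, `ExecutionRecord`, and the string-in/string-out tool signature are assumptions, loosely mirroring a tool callback that receives JSON input and returns a string. The collected records could then be passed on to something like `recordToolExecution`.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical record of one tool execution; mirrors the columns of the ToolExecution entity
record ExecutionRecord(String toolName, String arguments, String result,
                       long durationMillis, String status, String errorMessage) {}

// Hypothetical decorator: times each call and records SUCCESS or FAILED
final class TimedTool implements Function<String, String> {
    private final String name;
    private final Function<String, String> delegate;
    private final List<ExecutionRecord> sink;

    TimedTool(String name, Function<String, String> delegate, List<ExecutionRecord> sink) {
        this.name = name;
        this.delegate = delegate;
        this.sink = sink;
    }

    @Override
    public String apply(String argumentsJson) {
        long start = System.nanoTime();
        try {
            String result = delegate.apply(argumentsJson);
            sink.add(new ExecutionRecord(name, argumentsJson, result, elapsedMillis(start), "SUCCESS", null));
            return result;
        } catch (RuntimeException e) {
            sink.add(new ExecutionRecord(name, argumentsJson, null, elapsedMillis(start), "FAILED", e.getMessage()));
            throw e; // keep the original failure visible to the caller
        }
    }

    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000;
    }
}
```

Timing each tool separately avoids attributing the whole model round trip to every tool in the response.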
4.3 A Complete Chat Service
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.UUID;
@Service
public class ChatService {
    private final ChatClient chatClient;
    private final ChatConversationRepository conversationRepository;
    private final ToolExecutionListener toolExecutionListener;
    private final PersistentChatMemory chatMemory;
    public ChatService(ChatClient chatClient,
                       ChatConversationRepository conversationRepository,
                       ToolExecutionListener toolExecutionListener,
                       PersistentChatMemory chatMemory) {
        this.chatClient = chatClient;
        this.conversationRepository = conversationRepository;
        this.toolExecutionListener = toolExecutionListener;
        this.chatMemory = chatMemory;
    }
    @Transactional
    public ChatResponse chat(String conversationId, String userMessage, List<ToolCallback> tools) {
        // Make sure the conversation exists (with a title) before the memory advisor writes to it
        String id = conversationId != null ? conversationId : UUID.randomUUID().toString();
        if (conversationRepository.findByConversationId(id).isEmpty()) {
            ChatConversation conversation = new ChatConversation();
            conversation.setConversationId(id);
            conversation.setUserId(getCurrentUserId());
            conversation.setTitle(generateTitle(userMessage));
            conversationRepository.save(conversation);
        }
        // The MessageChatMemoryAdvisor registered in ChatConfig loads the history and saves both the
        // user message and the assistant reply, so nothing is added to chatMemory by hand here
        // (doing both would store every message twice)
        ChatClient.ChatClientRequestSpec request = chatClient.prompt()
                .user(userMessage)
                .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, id));
        if (tools != null && !tools.isEmpty()) {
            request = request.toolCallbacks(tools.toArray(new ToolCallback[0]));
        }
        long startTime = System.currentTimeMillis();
        // Call the AI
        ChatResponse response = request.call().chatResponse();
        long executionTime = System.currentTimeMillis() - startTime;
        // Record tool-call details from the response
        processAndSaveResponse(id, response, executionTime);
        return response;
    }
    private void processAndSaveResponse(String conversationId,
                                        ChatResponse response,
                                        long executionTime) {
        Generation generation = response.getResult();
        AssistantMessage assistantMessage = generation.getOutput();
        // With Spring AI's default internal tool execution, the model call runs the tools itself and the
        // final response usually carries no tool calls; this branch fires only when tool calls are
        // returned to the caller (for example, when internal tool execution is disabled)
        if (assistantMessage.hasToolCalls()) {
            // The memory advisor has just saved the assistant reply, so it is the last message
            ChatConversation conversation = conversationRepository
                    .findByConversationId(conversationId)
                    .orElseThrow();
            List<ChatMessage> messages = conversation.getMessages();
            ChatMessage lastMessage = messages.get(messages.size() - 1);
            // Record the tool calls on the message
            toolExecutionListener.recordAssistantMessageWithToolCalls(
                    lastMessage,
                    assistantMessage.getToolCalls()
            );
            // Record the details of each tool call
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                toolExecutionListener.recordToolExecution(
                        lastMessage.getId(),
                        toolCall.name(),
                        toolCall.arguments(),
                        null,          // The result is filled in later
                        executionTime, // Total model-call time, not a per-tool measurement
                        ToolExecution.ExecutionStatus.SUCCESS,
                        null
                );
            }
        }
    }
    @Transactional(readOnly = true)
    public List<ChatMessage> getConversationHistory(String conversationId) {
        ChatConversation conversation = conversationRepository
                .findByConversationId(conversationId)
                .orElseThrow(() -> new IllegalArgumentException("Conversation not found"));
        // Copy inside the transaction so the lazy collection is initialized before it is returned
        return List.copyOf(conversation.getMessages());
    }
    @Transactional(readOnly = true)
    public List<ChatConversation> getConversationsByUserId(String userId) {
        return conversationRepository.findByUserIdOrderByUpdatedAtDesc(userId);
    }
    @Transactional
    public void deleteConversation(String conversationId) {
        chatMemory.clear(conversationId);
        conversationRepository.deleteByConversationId(conversationId);
    }
    private String getCurrentUserId() {
        // Placeholder: in a real application, read the current user from the Spring Security context
        return "user_" + UUID.randomUUID().toString();
    }
    private String generateTitle(String firstMessage) {
        // Simplified; in practice the AI could generate the title
        return firstMessage.length() > 20 ?
                firstMessage.substring(0, 20) + "..." : firstMessage;
    }
}
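`generateTitle` truncates with `substring(0, 20)`, which counts UTF-16 code units. A first message that starts with emoji or other supplementary characters can therefore be cut in the middle of a surrogate pair. The following is a sketch of a code-point-aware alternative. `TitleUtil.truncate` is a hypothetical helper name, and the limit is passed in so the original value of 20 can be reused.

```java
// Hypothetical helper: truncate by Unicode code points rather than UTF-16 chars
final class TitleUtil {
    private TitleUtil() {}

    static String truncate(String text, int maxCodePoints) {
        if (text.codePointCount(0, text.length()) <= maxCodePoints) {
            return text;
        }
        // offsetByCodePoints returns the char index just after the first maxCodePoints code points
        int end = text.offsetByCodePoints(0, maxCodePoints);
        return text.substring(0, end) + "...";
    }
}
```

Calling `TitleUtil.truncate(firstMessage, 20)` keeps the original behavior for plain ASCII text and never splits a character.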
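5. Custom Persistence Implementations
5.1 Redis Persistence
For high-concurrency scenarios, Redis can serve as a cache layer in front of the database, forming a hybrid persistence strategy.
Redis configuration:
# Requires spring-boot-starter-data-redis; the Lettuce pool settings also need commons-pool2 on the classpath
spring:
  data:
    redis:
      host: localhost
      port: 6379
      password:
      database: 0
      timeout: 5000ms
      lettuce:
        pool:
          max-active: 8
          max-wait: -1ms
          max-idle: 8
          min-idle: 0
Redis ChatMemory implementation: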
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
public class RedisChatMemory implements ChatMemory {
    // Spring Boot auto-configures StringRedisTemplate; a RedisTemplate<String, Object> bean
    // would have to be declared by hand. Messages are stored as JSON strings.
    private final StringRedisTemplate redisTemplate;
    private final ObjectMapper objectMapper;
    private static final String CHAT_KEY_PREFIX = "chat:";
    private static final Duration DEFAULT_TTL = Duration.ofDays(7);
    private static final int MAX_MESSAGES = 100;
    public RedisChatMemory(StringRedisTemplate redisTemplate,
                           ObjectMapper objectMapper) {
        this.redisTemplate = redisTemplate;
        this.objectMapper = objectMapper;
    }
    @Override
    public void add(String conversationId, List<Message> messages) {
        String key = CHAT_KEY_PREFIX + conversationId;
        for (Message message : messages) {
            redisTemplate.opsForList().rightPush(key, serializeMessage(message));
        }
        // Set the expiration time
        redisTemplate.expire(key, DEFAULT_TTL);
        // Cap the list: keep only the newest MAX_MESSAGES entries (negative indexes count from the tail)
        redisTemplate.opsForList().trim(key, -MAX_MESSAGES, -1);
    }
    // Not declared by ChatMemory in Spring AI 1.0.x, so no @Override
    public List<Message> get(String conversationId, int lastN) {
        String key = CHAT_KEY_PREFIX + conversationId;
        Long size = redisTemplate.opsForList().size(key);
        if (size == null || size == 0) {
            return List.of();
        }
        long start = Math.max(0, size - lastN);
        List<String> values = redisTemplate.opsForList().range(key, start, -1);
        if (values == null) {
            return List.of();
        }
        return values.stream()
                .map(this::deserializeMessage)
                .toList();
    }
    @Override
    public List<Message> get(String conversationId) {
        return get(conversationId, MAX_MESSAGES);
    }
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(CHAT_KEY_PREFIX + conversationId);
    }
    private String serializeMessage(Message message) {
        try {
            Map<String, Object> messageMap = new HashMap<>();
            messageMap.put("type", message.getMessageType().name());
            messageMap.put("content", message.getText());
            messageMap.put("metadata", message.getMetadata());
            return objectMapper.writeValueAsString(messageMap);
        } catch (Exception e) {
            throw new RuntimeException("Failed to serialize message", e);
        }
    }
    private Message deserializeMessage(String json) {
        try {
            Map<String, Object> messageMap = objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
            String type = (String) messageMap.get("type");
            String content = (String) messageMap.get("content");
            return switch (type) {
                case "USER" -> new UserMessage(content);
                case "ASSISTANT" -> new AssistantMessage(content);
                case "SYSTEM" -> new SystemMessage(content);
                // This simplified format does not restore TOOL messages faithfully
                default -> new AssistantMessage(content);
            };
        } catch (Exception e) {
            throw new RuntimeException("Failed to deserialize message", e);
        }
    }
}
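The `add` method relies on Redis list semantics. `RPUSH` appends to the tail, and `LTRIM key start stop` keeps only the given inclusive range, where negative indexes count from the end (`-1` is the last element). Trimming to `[size - MAX_MESSAGES, -1]`, or equivalently `[-MAX_MESSAGES, -1]`, therefore keeps the newest entries. The JDK-only sketch below simulates the two commands on a list to show the effect. `CappedList` is a hypothetical class, not a Redis client API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical simulation of RPUSH + LTRIM on a single Redis list
final class CappedList {
    private final List<String> items = new ArrayList<>();

    void rpush(String value) {
        items.add(value);
    }

    // Mirrors LTRIM start stop: inclusive range, negative indexes count from the tail,
    // an out-of-range start clamps to 0, and an empty range empties the list
    void ltrim(long start, long stop) {
        int size = items.size();
        long from = start < 0 ? Math.max(0, size + start) : start;
        long to = stop < 0 ? size + stop : Math.min(stop, size - 1);
        if (from > to || from >= size) {
            items.clear();
            return;
        }
        List<String> kept = new ArrayList<>(items.subList((int) from, (int) to + 1));
        items.clear();
        items.addAll(kept);
    }

    List<String> range() {
        return List.copyOf(items);
    }
}
```

Pushing five messages with `ltrim(-3, -1)` after each push leaves only the last three. A list shorter than the cap is left intact.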
5.2 Hybrid Persistence Strategy (Redis + Database)
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@Service
public class HybridChatMemory implements ChatMemory {
    private static final Logger log = LoggerFactory.getLogger(HybridChatMemory.class);
    private final RedisChatMemory redisMemory;
    private final PersistentChatMemory dbMemory;
    public HybridChatMemory(RedisChatMemory redisMemory,
                            PersistentChatMemory dbMemory) {
        this.redisMemory = redisMemory;
        this.dbMemory = dbMemory;
    }
    @Override
    public void add(String conversationId, List<Message> messages) {
        // Write to Redis first (fast)
        redisMemory.add(conversationId, messages);
        // Then write to the database asynchronously (durable). @Async on a method of this class would
        // be ignored, because a call through `this` bypasses the Spring proxy; the transaction comes from
        // PersistentChatMemory.add, which is invoked through its proxy. runAsync uses the common
        // ForkJoinPool here; a dedicated executor is preferable in production.
        CompletableFuture.runAsync(() -> dbMemory.add(conversationId, messages))
                .exceptionally(ex -> {
                    log.error("Failed to persist messages for conversation {}", conversationId, ex);
                    return null;
                });
    }
    // Not declared by ChatMemory in Spring AI 1.0.x, so no @Override
    public List<Message> get(String conversationId, int lastN) {
        // Try Redis first
        List<Message> messages = redisMemory.get(conversationId, lastN);
        // If Redis has nothing, load from the database
        if (messages.isEmpty()) {
            messages = dbMemory.get(conversationId, lastN);
            // Backfill Redis
            if (!messages.isEmpty()) {
                redisMemory.add(conversationId, messages);
            }
        }
        return messages;
    }
    @Override
    public List<Message> get(String conversationId) {
        return get(conversationId, 50);
    }
    @Override
    public void clear(String conversationId) {
        redisMemory.clear(conversationId);
        dbMemory.clear(conversationId);
    }
}
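The hybrid strategy combines two patterns. Reads are cache-aside: try Redis, fall back to the database, then backfill the cache. Writes are write-behind: write the cache synchronously and the database asynchronously. A JDK-only sketch with two maps standing in for Redis and the database shows the control flow. `HybridStoreSketch` is hypothetical. It takes an explicit `Executor` rather than relying on `@Async`, which has no effect when a bean calls its own method through `this`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;

// Hypothetical sketch: "cache" stands in for Redis, "database" for the relational store.
// Not safe for concurrent writes to the same conversation; it only shows the control flow.
final class HybridStoreSketch {
    final Map<String, List<String>> cache = new ConcurrentHashMap<>();
    final Map<String, List<String>> database = new ConcurrentHashMap<>();
    private final Executor executor;

    HybridStoreSketch(Executor executor) {
        this.executor = executor;
    }

    // Write-behind: cache now, database asynchronously; the future lets callers observe failures
    CompletableFuture<Void> add(String conversationId, List<String> messages) {
        cache.computeIfAbsent(conversationId, id -> new ArrayList<>()).addAll(messages);
        return CompletableFuture.runAsync(
                () -> database.computeIfAbsent(conversationId, id -> new ArrayList<>()).addAll(messages),
                executor);
    }

    // Cache-aside: read the cache, fall back to the database, then backfill the cache
    List<String> get(String conversationId) {
        List<String> cached = cache.get(conversationId);
        if (cached != null && !cached.isEmpty()) {
            return List.copyOf(cached);
        }
        List<String> stored = database.getOrDefault(conversationId, List.of());
        if (!stored.isEmpty()) {
            cache.put(conversationId, new ArrayList<>(stored));
        }
        return List.copyOf(stored);
    }
}
```

Returning the future from `add` is one way to keep asynchronous database failures observable instead of silently lost.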
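5.3 Custom Storage Adapter
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Component;
import java.util.List;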
/**
* Custom storage adapter interface
*/
public interface StorageAdapter {
void store(String conversationId, List<Message> messages);
List<Message> retrieve(String conversationId, int limit);
void delete(String conversationId);
boolean exists(String conversationId);
}
/**
* Example MongoDB storage adapter
*/
@Component
public class MongoStorageAdapter implements StorageAdapter {
private final MongoTemplate mongoTemplate;
public MongoStorageAdapter(MongoTemplate mongoTemplate) {
this.mongoTemplate = mongoTemplate;
}
@Override
public void store(String conversationId, List<Message> messages) {
Query query = Query.query(Criteria.where("conversationId").is(conversationId));
ConversationDocument doc = mongoTemplate.findOne(query, ConversationDocument.class);
if (doc == null) {
doc = new ConversationDocument();
doc.setConversationId(conversationId);
}
List<MessageDocument> messageDocs = messages.stream()
.map(this::convertToDocument)
.toList();
doc.getMessages().addAll(messageDocs);
mongoTemplate.save(doc);
}
@Override
public List<Message> retrieve(String conversationId, int limit) {
Query query = Query.query(Criteria.where("conversationId").is(conversationId));
ConversationDocument doc = mongoTemplate.findOne(query, ConversationDocument.class);
if (doc == null) {
return List.of();
}
// Messages are embedded in a single document, so Query.limit() would not limit them;
// take the newest `limit` messages from the embedded list instead
List<MessageDocument> all = doc.getMessages();
return all.subList(Math.max(0, all.size() - limit), all.size()).stream()
.map(this::convertToMessage)
.toList();
}
@Override
public void delete(String conversationId) {
Query query = Query.query(Criteria.where("conversationId").is(conversationId));
mongoTemplate.remove(query, ConversationDocument.class);
}
@Override
public boolean exists(String conversationId) {
Query query = Query.query(Criteria.where("conversationId").is(conversationId));
return mongoTemplate.exists(query, ConversationDocument.class);
}
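// Conversion methods (convertToDocument, convertToMessage) and the ConversationDocument /
// MessageDocument classes are omitted; a new ConversationDocument must start with a mutable,
// non-null messages list for store() to work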
}
/**
* ChatMemory implementation backed by the adapter
*/
@Component
public class CustomStorageChatMemory implements ChatMemory {
private final StorageAdapter storageAdapter;
public CustomStorageChatMemory(StorageAdapter storageAdapter) {
this.storageAdapter = storageAdapter;
}
@Override
public void add(String conversationId, List<Message> messages) {
storageAdapter.store(conversationId, messages);
}
// Not declared by ChatMemory in Spring AI 1.0.x, so no @Override
public List<Message> get(String conversationId, int lastN) {
return storageAdapter.retrieve(conversationId, lastN);
}
@Override
public List<Message> get(String conversationId) {
return get(conversationId, 50);
}
@Override
public void clear(String conversationId) {
storageAdapter.delete(conversationId);
}
}
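Because `CustomStorageChatMemory` depends only on the `StorageAdapter` interface, a backend can be swapped without touching the memory class. An in-memory adapter is handy for unit tests. The sketch keeps the same four methods but uses a hypothetical generic message type `M`, so it runs without Spring AI on the classpath. `GenericStorageAdapter` and `InMemoryStorageAdapter` are illustrative names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Same shape as StorageAdapter above, made generic so the sketch needs no Spring AI types
interface GenericStorageAdapter<M> {
    void store(String conversationId, List<M> messages);
    List<M> retrieve(String conversationId, int limit);
    void delete(String conversationId);
    boolean exists(String conversationId);
}

// Hypothetical in-memory adapter, e.g. for unit tests of the ChatMemory wrapper
final class InMemoryStorageAdapter<M> implements GenericStorageAdapter<M> {
    private final Map<String, List<M>> data = new HashMap<>();

    @Override
    public void store(String conversationId, List<M> messages) {
        data.computeIfAbsent(conversationId, id -> new ArrayList<>()).addAll(messages);
    }

    @Override
    public List<M> retrieve(String conversationId, int limit) {
        List<M> all = data.getOrDefault(conversationId, List.of());
        // Newest `limit` messages, oldest first
        return List.copyOf(all.subList(Math.max(0, all.size() - limit), all.size()));
    }

    @Override
    public void delete(String conversationId) {
        data.remove(conversationId);
    }

    @Override
    public boolean exists(String conversationId) {
        return data.containsKey(conversationId);
    }
}
```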
6. REST API
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@RestController
@RequestMapping("/api/chat")
public class ChatController {
private final ChatService chatService;
public ChatController(ChatService chatService) {
this.chatService = chatService;
}
@PostMapping("/send")
public ResponseEntity<Map<String, Object>> sendMessage(
@RequestBody ChatRequest request) {
// ChatRequest is a record, so its accessors are conversationId(), message(), and tools().
// A missing ID is generated here so it can be returned to the client (Map.of rejects null values).
String conversationId = request.conversationId() != null
? request.conversationId()
: UUID.randomUUID().toString();
ChatResponse response = chatService.chat(
conversationId,
request.message(),
request.tools()
);
return ResponseEntity.ok(Map.of(
"conversationId", conversationId,
"response", response.getResult().getOutput().getText(),
"metadata", response.getMetadata()
));
}
@GetMapping("/history/{conversationId}")
public ResponseEntity<List<ChatMessage>> getHistory(
@PathVariable String conversationId) {
return ResponseEntity.ok(chatService.getConversationHistory(conversationId));
}
@DeleteMapping("/{conversationId}")
public ResponseEntity<Void> deleteConversation(
@PathVariable String conversationId) {
chatService.deleteConversation(conversationId);
return ResponseEntity.ok().build();
}
@GetMapping("/conversations")
public ResponseEntity<List<ChatConversation>> getUserConversations(
@RequestParam String userId) {
return ResponseEntity.ok(
chatService.getConversationsByUserId(userId)
);
}
}
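record ChatRequest(
String conversationId,
String message,
// ToolCallback is an interface, so Jackson cannot build it from a JSON request body;
// leave this field out of requests, or replace it with tool names resolved on the server
List<ToolCallback> tools
) {}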
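7. Best Practices and Caveats
7.1 Performance Optimization
- Batch operations: prefer batch inserts over row-by-row inserts. Note that Hibernate cannot batch inserts for entities that use GenerationType.IDENTITY, as the entities in this article do; batching requires a sequence-based ID generator together with hibernate.jdbc.batch_size.
- Index optimization: add indexes on frequently queried columns such as userId and updatedAt.
- Caching: use Redis to cache hot data.
- Paged loading: load message history page by page instead of all at once.
// Example paged query (in ChatMessageRepository)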
@Query("SELECT m FROM ChatMessage m WHERE m.conversation.conversationId = :conversationId " +
"ORDER BY m.timestamp ASC")
Page<ChatMessage> findByConversationId(@Param("conversationId") String conversationId,
Pageable pageable);
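The query above pages through a conversation from the oldest message forward. A chat UI usually shows the newest messages first and loads older ones on scroll. For that case, keyset (cursor) pagination on the message ID is a common alternative to offset paging. Each request fetches messages with an ID below the oldest one already shown, newest first, and the page is then reversed for display. The JDK-only sketch below applies the idea to an in-memory list. `HistoryPage`, `HistoryRow`, and `olderThan` are hypothetical names. In JPA, the same filter would roughly correspond to an `id < :cursor ... ORDER BY id DESC` query limited to the page size.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

// Hypothetical message row: id increases with insertion order (as with an IDENTITY column)
record HistoryRow(long id, String text) {}

final class HistoryPage {
    private HistoryPage() {}

    // Returns up to pageSize rows older than cursorId (null = start from the newest), in chronological order
    static List<HistoryRow> olderThan(List<HistoryRow> rows, Long cursorId, int pageSize) {
        List<HistoryRow> page = new ArrayList<>(rows.stream()
                .filter(r -> cursorId == null || r.id() < cursorId)
                .sorted(Comparator.comparingLong(HistoryRow::id).reversed()) // newest first
                .limit(pageSize)
                .toList());
        Collections.reverse(page); // oldest first for display
        return page;
    }
}
```

The ID of the first row on each page becomes the cursor for the next, older page. Unlike an offset, this cursor stays valid while new messages are being appended.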
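7.2 Data Security
- Sensitive-data filtering: remove sensitive information before it is persisted
- Data encryption: encrypt stored message content
- Access control: enforce per-user access control
import java.util.regex.Pattern;
import org.springframework.stereotype.Component;
@Component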
public class SensitiveDataFilter {
private static final Pattern SENSITIVE_PATTERN =
Pattern.compile("(password|token|secret|key)\\s*[:=]\\s*[^\\s]+", Pattern.CASE_INSENSITIVE);
public String filterSensitiveData(String content) {
return SENSITIVE_PATTERN.matcher(content).replaceAll("[REDACTED]");
}
}
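For the data-encryption point above, one option is to encrypt message content in the application before it reaches the `content` column. The sketch uses AES-GCM from the JDK's `javax.crypto` API and stores a random 12-byte IV alongside each ciphertext. `MessageContentCipher` is hypothetical and not part of Spring or Spring AI. It generates a key in memory purely for demonstration; a real deployment would load the key from a key-management system. One place such a helper could be called is a JPA `AttributeConverter` on the `content` field.

```java
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Base64;

// Hypothetical helper: AES-GCM encryption of message text; output is Base64(IV || ciphertext+tag)
final class MessageContentCipher {
    private static final int IV_BYTES = 12;  // 96-bit IV, the size recommended for GCM
    private static final int TAG_BITS = 128; // authentication tag length
    private final SecretKey key;
    private final SecureRandom random = new SecureRandom();

    MessageContentCipher(SecretKey key) {
        this.key = key;
    }

    // Demo only: real keys should come from a key-management system, not be generated per run
    static SecretKey newKey() {
        try {
            KeyGenerator generator = KeyGenerator.getInstance("AES");
            generator.init(256);
            return generator.generateKey();
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }

    String encrypt(String plaintext) {
        try {
            byte[] iv = new byte[IV_BYTES];
            random.nextBytes(iv); // fresh IV per message; never reuse an IV with the same key
            Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
            cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(TAG_BITS, iv));
            byte[] sealed = cipher.doFinal(plaintext.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(
                    ByteBuffer.allocate(IV_BYTES + sealed.length).put(iv).put(sealed).array());
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("Encryption failed", e);
        }
    }

    String decrypt(String encoded) {
        try {
            byte[] data = Base64.getDecoder().decode(encoded);
            Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
            cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(TAG_BITS, data, 0, IV_BYTES));
            byte[] plain = cipher.doFinal(data, IV_BYTES, data.length - IV_BYTES);
            return new String(plain, StandardCharsets.UTF_8);
        } catch (GeneralSecurityException e) {
            // Also thrown for a wrong key or tampered ciphertext (authentication tag mismatch)
            throw new IllegalStateException("Decryption failed", e);
        }
    }
}
```

Encrypted content can no longer be searched or filtered in SQL, so this fits best when message text is only ever read back by ID.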
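7.3 Data Cleanup Strategy
@Component
@Slf4j
public class ChatDataCleanupScheduler {
    private final ChatConversationRepository conversationRepository;
    public ChatDataCleanupScheduler(ChatConversationRepository conversationRepository) {
        this.conversationRepository = conversationRepository;
    }
    // Runs every day at 2:00 AM; scheduling must be enabled with @EnableScheduling
    @Scheduled(cron = "0 0 2 * * ?")
    @Transactional
    public void cleanupOldConversations() {
        LocalDateTime cutoffDate = LocalDateTime.now().minusDays(90);
        List<ChatConversation> oldConversations =
                conversationRepository.findByUpdatedAtBefore(cutoffDate);
        log.info("Found {} old conversations to clean up", oldConversations.size());
        // cascade = ALL on ChatConversation.messages deletes their messages as well; tool_execution rows
        // that reference those messages must be removed first, because ToolExecution has no cascade
        conversationRepository.deleteAll(oldConversations);
        log.info("Cleaned up {} old conversations", oldConversations.size());
    }
}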
7.4 Monitoring and Logging
@Aspect
@Component
@Slf4j
public class ChatPerformanceMonitor {
@Around("execution(* com.example.chat.service.ChatService.chat(..))")
public Object monitorChatPerformance(ProceedingJoinPoint joinPoint) throws Throwable {
long startTime = System.currentTimeMillis();
try {
Object result = joinPoint.proceed();
long duration = System.currentTimeMillis() - startTime;
log.info("Chat request completed in {} ms", duration);
// Report to the monitoring system
recordMetric("chat.duration", duration);
return result;
} catch (Exception e) {
long duration = System.currentTimeMillis() - startTime;
log.error("Chat request failed after {} ms", duration, e);
recordMetric("chat.errors", 1);
throw e;
}
}
private void recordMetric(String name, long value) {
// Integrate with Prometheus, Micrometer, etc.
}
}
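8. Summary
This article walked through a complete conversation persistence solution for Spring AI 1.0.4, covering:
- Basic setup: database entities, repositories, and a ChatMemory implementation
- Tool-call persistence: an extended entity, a listener, and a complete chat service
- Custom persistence: Redis, MongoDB, and other storage options
- Hybrid strategy: a Redis cache in front of database persistence
- Best practices: performance optimization, data security, cleanup, and monitoring and logging
With careful design and implementation, you can build a high-performance, reliable conversation persistence system. Such a system gives users a continuous conversational experience while also supporting data analysis and access to past conversations.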
Author: AI 技术分享者 (AI Tech Sharer)
Published: 2024
Version: Spring AI 1.0.4