Getting Started with Spring AI
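Spring AI Overview
Spring AI is a relatively new project in the Spring ecosystem that simplifies adding AI capabilities to Spring applications. It provides a unified abstraction layer over chat, embedding, and other model APIs, so largely the same application code can work with different AI model providers such as OpenAI.
Key features
- Model abstraction: a unified API that supports many AI models and providers
- Spring integration: works seamlessly with Spring Boot auto-configuration and the rest of the ecosystem
- Flexible configuration: supports multiple configuration styles (YAML, properties, environment variables) and environments
- Extensibility: easy to extend and customize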
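Environment Setup and Project Creation
1. JDK and build tool requirements
- Java 17 or later (required by Spring Boot 3)
- Maven 3.6.3+ or Gradle 7.5+ (the minimums for Spring Boot 3.2)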
2. Project initialization
Using Maven (pom.xml)
```xml
<parent>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-parent</artifactId>
    <version>3.2.0</version>
    <relativePath/>
</parent>

<!-- Spring AI BOM: keeps all Spring AI module versions aligned -->
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0-M3</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <!-- Spring AI OpenAI starter: core API plus OpenAI auto-configuration -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
    </dependency>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>

<!-- Milestone releases are not published to Maven Central -->
<repositories>
    <repository>
        <id>spring-milestones</id>
        <name>Spring Milestones</name>
        <url>https://repo.spring.io/milestone</url>
        <snapshots>
            <enabled>false</enabled>
        </snapshots>
    </repository>
</repositories>
```
Using Gradle (build.gradle)
```groovy
repositories {
    mavenCentral()
    // Spring AI milestone releases are published here
    maven { url 'https://repo.spring.io/milestone' }
}

dependencies {
    // The BOM supplies versions for all Spring AI modules
    implementation platform('org.springframework.ai:spring-ai-bom:1.0.0-M3')
    implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter'
    implementation 'org.springframework.boot:spring-boot-starter-web:3.2.0'
}
```
Basic Configuration and Core Concepts
1. Configuration files
application.yml
```yaml
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY:your-api-key-here}
      chat:
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 1000
      embedding:
        options:
          model: text-embedding-ada-002
```
application.properties
```properties
# OpenAI configuration
spring.ai.openai.api-key=${OPENAI_API_KEY:your-api-key-here}
spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.7
spring.ai.openai.chat.options.max-tokens=1000
spring.ai.openai.embedding.options.model=text-embedding-ada-002
```
2. Core concepts
ChatClient
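The main client for interacting with chat models. In 1.0.0-M3, ChatClient is a fluent API; the OpenAI starter auto-configures a ChatClient.Builder, which you inject and build from:
```java
@Autowired
private ChatClient.Builder chatClientBuilder;  // auto-configured by the OpenAI starter

public String generateResponse(String prompt) {
    ChatClient chatClient = chatClientBuilder.build();
    return chatClient.prompt()
            .user(prompt)   // the user message
            .call()         // blocking call to the model
            .content();     // text of the response
}
```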
EmbeddingModel
The client for generating text embedding vectors. It was called EmbeddingClient in Spring AI 0.8.x and was renamed EmbeddingModel in 1.0.0-M1:
```java
@Autowired
private EmbeddingModel embeddingModel;  // auto-configured by the OpenAI starter

public float[] generateEmbedding(String text) {
    // In 1.0.0-M3, embed(String) returns the vector as a float[]
    return embeddingModel.embed(text);
}
```
Basic Application: Text Generation
1. A simple text generation service
TextGenerationService.java
```java
package com.example.springai.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.stereotype.Service;

@Service
public class TextGenerationService {

    private final ChatClient chatClient;

    // Spring AI auto-configures a ChatClient.Builder; build one ChatClient and reuse it
    public TextGenerationService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Generate text using the default options from application.yml
     */
    public String generateText(String prompt) {
        return chatClient.prompt()
                .user(prompt)
                .call()
                .content();
    }

    /**
     * Generate text with per-request options that override the defaults
     */
    public String generateTextWithOptions(String prompt, Double temperature, Integer maxTokens) {
        OpenAiChatOptions options = OpenAiChatOptions.builder()
                .withTemperature(temperature)
                .withMaxTokens(maxTokens)
                .build();
        return chatClient.prompt()
                .user(prompt)
                .options(options)
                .call()
                .content();
    }
}
```
TextGenerationController.java
```java
package com.example.springai.controller;

import com.example.springai.service.TextGenerationService;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

@RestController
@RequestMapping("/api/text")
public class TextGenerationController {

    private final TextGenerationService textGenerationService;

    public TextGenerationController(TextGenerationService textGenerationService) {
        this.textGenerationService = textGenerationService;
    }

    // The raw request body is used as the prompt
    @PostMapping("/generate")
    public ResponseEntity<String> generateText(@RequestBody String prompt) {
        try {
            return ResponseEntity.ok(textGenerationService.generateText(prompt));
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body("Text generation failed: " + e.getMessage());
        }
    }

    @PostMapping("/generate-advanced")
    public ResponseEntity<String> generateAdvancedText(
            @RequestParam String prompt,
            @RequestParam(defaultValue = "0.7") Double temperature,
            @RequestParam(defaultValue = "1000") Integer maxTokens) {
        try {
            return ResponseEntity.ok(
                    textGenerationService.generateTextWithOptions(prompt, temperature, maxTokens));
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body("Advanced text generation failed: " + e.getMessage());
        }
    }
}
```
2. Usage example
Tests
```java
package com.example.springai.service;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;

// These tests call the real OpenAI API and need a valid OPENAI_API_KEY
@SpringBootTest
class TextGenerationServiceTest {

    @Autowired
    private TextGenerationService textGenerationService;

    @Test
    void testSimpleTextGeneration() {
        String result = textGenerationService.generateText("Write a poem about spring");
        System.out.println("Generated text: " + result);
        assertNotNull(result);
        assertFalse(result.isEmpty());
    }

    @Test
    void testAdvancedTextGeneration() {
        String result = textGenerationService
                .generateTextWithOptions("Explain what the Spring Framework is", 0.5, 500);
        System.out.println("Generated text: " + result);
        assertNotNull(result);
        assertFalse(result.isEmpty());
    }
}
```
Intermediate Application: A Question-Answering System
1. Building a simple Q&A system
QAService.java
```java
package com.example.springai.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

@Service
public class QAService {

    private final ChatClient chatClient;
    // Singleton bean: this one history is shared by all callers (fine for a single-user demo).
    // Methods are synchronized so concurrent requests cannot corrupt the list.
    private final List<Message> conversationHistory = new ArrayList<>();

    public QAService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Single-turn Q&A: the model sees only the current question
     */
    public synchronized String askQuestion(String question) {
        String answer = chatClient.prompt()
                .user(question)
                .call()
                .content();
        // Save the exchange to the conversation history
        conversationHistory.add(new UserMessage(question));
        conversationHistory.add(new AssistantMessage(answer));
        return answer;
    }

    /**
     * Multi-turn Q&A: earlier turns are sent as context with the new question
     */
    public synchronized String askQuestionWithContext(String question) {
        List<Message> messages = new ArrayList<>(conversationHistory);
        messages.add(new UserMessage(question));

        String answer = chatClient.prompt()
                .messages(messages)
                .options(OpenAiChatOptions.builder()
                        .withTemperature(0.7)
                        .withMaxTokens(1000)
                        .build())
                .call()
                .content();

        // Update the conversation history
        conversationHistory.add(new UserMessage(question));
        conversationHistory.add(new AssistantMessage(answer));
        return answer;
    }

    /**
     * Clear the conversation history
     */
    public synchronized void clearConversation() {
        conversationHistory.clear();
    }

    /**
     * Return a copy of the conversation history
     */
    public synchronized List<Message> getConversationHistory() {
        return new ArrayList<>(conversationHistory);
    }
}
```
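askQuestionWithContext resends the entire history on every call, so the prompt (and its token cost) grows with each turn and will eventually exceed the model's context window. A common mitigation is to keep only the most recent exchanges. The plain-Java sketch below assumes no Spring AI types; `ConversationWindow` and `Turn` are hypothetical names, and the limit is counted in messages rather than tokens for simplicity.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical message type standing in for Spring AI's UserMessage/AssistantMessage
record Turn(String role, String content) {}

// Keeps at most maxMessages recent messages, evicting whole exchanges from the front
class ConversationWindow {
    private final int maxMessages;
    private final Deque<Turn> turns = new ArrayDeque<>();

    ConversationWindow(int maxMessages) {
        if (maxMessages < 2) {
            throw new IllegalArgumentException("need room for at least one exchange");
        }
        this.maxMessages = maxMessages;
    }

    // Record one question/answer pair, then trim
    synchronized void addExchange(String question, String answer) {
        turns.addLast(new Turn("user", question));
        turns.addLast(new Turn("assistant", answer));
        // Drop user+assistant pairs so the window never starts with an assistant reply
        while (turns.size() > maxMessages) {
            turns.removeFirst();
            turns.removeFirst();
        }
    }

    // Snapshot of the window plus the new question, ready to send to the model
    synchronized List<Turn> promptFor(String question) {
        List<Turn> messages = new ArrayList<>(turns);
        messages.add(new Turn("user", question));
        return messages;
    }
}
```

In QAService, such a window could replace the unbounded ArrayList, with each Turn mapped to a UserMessage or AssistantMessage before calling chatClient.prompt().messages(...).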
QAController.java
```java
package com.example.springai.controller;

import com.example.springai.service.QAService;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.Map;

@RestController
@RequestMapping("/api/qa")
public class QAController {

    private final QAService qaService;

    public QAController(QAService qaService) {
        this.qaService = qaService;
    }

    // Request body: {"question": "..."}
    @PostMapping("/ask")
    public ResponseEntity<Map<String, String>> askQuestion(@RequestBody Map<String, String> request) {
        String question = request.get("question");
        if (question == null || question.isBlank()) {
            return ResponseEntity.badRequest().body(Map.of("error", "question is required"));
        }
        try {
            String answer = qaService.askQuestion(question);
            return ResponseEntity.ok(Map.of("question", question, "answer", answer));
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body(Map.of("error", "Q&A failed: " + e.getMessage()));
        }
    }

    @PostMapping("/ask-contextual")
    public ResponseEntity<Map<String, String>> askQuestionWithContext(@RequestBody Map<String, String> request) {
        String question = request.get("question");
        if (question == null || question.isBlank()) {
            return ResponseEntity.badRequest().body(Map.of("error", "question is required"));
        }
        try {
            String answer = qaService.askQuestionWithContext(question);
            return ResponseEntity.ok(Map.of("question", question, "answer", answer));
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body(Map.of("error", "Contextual Q&A failed: " + e.getMessage()));
        }
    }

    // Expose only role and text rather than serializing Spring AI message objects
    @GetMapping("/history")
    public ResponseEntity<List<Map<String, String>>> getConversationHistory() {
        List<Map<String, String>> history = qaService.getConversationHistory().stream()
                .map(m -> Map.of("role", m.getMessageType().name(), "content", m.getContent()))
                .toList();
        return ResponseEntity.ok(history);
    }

    @DeleteMapping("/clear")
    public ResponseEntity<Map<String, String>> clearConversation() {
        try {
            qaService.clearConversation();
            return ResponseEntity.ok(Map.of("message", "Conversation history cleared"));
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body(Map.of("error", "Clear failed: " + e.getMessage()));
        }
    }
}
```
2. Document Q&A system
DocumentQAService.java
```java
package com.example.springai.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.stereotype.Service;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class DocumentQAService {

    private final ChatClient chatClient;
    private final EmbeddingModel embeddingModel;
    // In-memory stores keyed by document ID (thread-safe maps; demo only)
    private final Map<String, float[]> documentEmbeddings = new ConcurrentHashMap<>();
    private final Map<String, String> documentContents = new ConcurrentHashMap<>();

    public DocumentQAService(ChatClient.Builder chatClientBuilder, EmbeddingModel embeddingModel) {
        this.chatClient = chatClientBuilder.build();
        this.embeddingModel = embeddingModel;
    }

    /**
     * Add a document and store its embedding
     */
    public void addDocument(String docId, String content) {
        documentContents.put(docId, content);
        documentEmbeddings.put(docId, embeddingModel.embed(content));
    }

    /**
     * Answer a question using the given document as context
     */
    public String answerQuestionFromDocuments(String question, String docId) {
        String documentContent = documentContents.get(docId);
        if (documentContent == null) {
            return "Document not found";
        }
        String prompt = String.format(
                "Answer the question based on the following document.\n\nDocument:\n%s\n\nQuestion: %s\n\nPlease give a detailed and accurate answer:",
                documentContent, question);
        return chatClient.prompt().user(prompt).call().content();
    }

    /**
     * Return the ID of the document most similar to the question, or null if none are stored
     */
    public String findMostRelevantDocument(String question) {
        if (documentEmbeddings.isEmpty()) {
            return null;
        }
        float[] questionEmbedding = embeddingModel.embed(question);
        String mostRelevantDoc = null;
        double maxSimilarity = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, float[]> entry : documentEmbeddings.entrySet()) {
            double similarity = cosineSimilarity(questionEmbedding, entry.getValue());
            if (similarity > maxSimilarity) {
                maxSimilarity = similarity;
                mostRelevantDoc = entry.getKey();
            }
        }
        return mostRelevantDoc;
    }

    /**
     * Cosine similarity; returns 0 for mismatched lengths or zero vectors
     */
    private double cosineSimilarity(float[] vectorA, float[] vectorB) {
        if (vectorA.length != vectorB.length) {
            return 0.0;
        }
        double dotProduct = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += (double) vectorA[i] * vectorB[i];
            normA += (double) vectorA[i] * vectorA[i];
            normB += (double) vectorB[i] * vectorB[i];
        }
        if (normA == 0 || normB == 0) {
            return 0.0;  // avoid division by zero
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```
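answerQuestionFromDocuments pastes the entire document into the prompt, and addDocument embeds the whole text as a single vector. For long documents this can exceed the model's context window or the embedding model's input limit. A common approach is to split documents into overlapping chunks and embed each chunk separately. The plain-Java sketch below is a character-based splitter; `TextChunker` is a hypothetical name, and real pipelines usually split on tokens or sentence boundaries instead (Spring AI ships a TokenTextSplitter for this purpose).

```java
import java.util.ArrayList;
import java.util.List;

// Splits text into fixed-size character windows; consecutive chunks share `overlap` characters
class TextChunker {
    static List<String> chunk(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("require chunkSize > 0 and 0 <= overlap < chunkSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;  // how far each window advances
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break;  // this window reached the end of the text
            }
        }
        return chunks;
    }
}
```

The overlap keeps a sentence that straddles a boundary visible in both neighbouring chunks. Each chunk would then be embedded and stored under its own ID (for example docId plus a chunk index), and retrieval would return the best-matching chunks rather than whole documents.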
Key Techniques in Detail
1. ChatClient in depth
Configuring default options
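```java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class AdvancedChatClientConfig {

    @Bean
    public ChatClient customChatClient(ChatClient.Builder chatClientBuilder) {
        return chatClientBuilder
                .defaultOptions(OpenAiChatOptions.builder()
                        .withModel("gpt-3.5-turbo")
                        .withTemperature(0.7)      // randomness: lower is more deterministic
                        .withMaxTokens(2000)       // upper bound on generated tokens
                        .withTopP(1.0)             // nucleus sampling threshold
                        .withPresencePenalty(0.0)  // > 0 nudges toward new topics
                        .withFrequencyPenalty(0.0) // > 0 discourages repetition
                        .build())
                .build();
    }
}
```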
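The temperature option controls how sharply the model favours its highest-scoring tokens. In the standard formulation, token probabilities are a softmax of the logits divided by the temperature T, so lower T concentrates probability on the top token and higher T flattens the distribution. The plain-Java sketch below illustrates that general technique with made-up logits; it is not OpenAI's internal implementation, and `TemperatureSoftmax` is a hypothetical name.

```java
import java.util.Arrays;

class TemperatureSoftmax {
    // p_i = exp(z_i / T) / sum_j exp(z_j / T)
    // Subtracting the max logit first keeps exp() from overflowing without changing the result.
    static double[] softmax(double[] logits, double temperature) {
        if (temperature <= 0) {
            // T -> 0 approaches greedy decoding; this sketch only handles T > 0
            throw new IllegalArgumentException("temperature must be > 0");
        }
        double max = Arrays.stream(logits).max().orElseThrow();
        double[] probs = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp((logits[i] - max) / temperature);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }
}
```

With logits {2, 1, 0}, T = 0.5 gives the top token roughly 87% of the probability mass, while T = 2 gives it about 51%.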
Building messages
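```java
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;

// Usage: chatClient.prompt().messages(buildSystemMessage("..."), buildUserMessage("...")).call()
public final class MessageBuilder {

    private MessageBuilder() {
    }

    /**
     * System message: sets the assistant's role and rules; usually sent first
     */
    public static Message buildSystemMessage(String content) {
        return new SystemMessage(content);
    }

    /**
     * User message: the end user's input
     */
    public static Message buildUserMessage(String content) {
        return new UserMessage(content);
    }

    /**
     * Assistant message: a previous model reply, used to replay conversation history
     */
    public static Message buildAssistantMessage(String content) {
        return new AssistantMessage(content);
    }
}
```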
2. Working with EmbeddingModel
An in-memory vector store example
```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.stereotype.Service;

@Service
public class VectorDatabaseService {

    private final EmbeddingModel embeddingModel;
    // In-memory store for demonstration; Spring AI's VectorStore abstraction targets real vector databases
    private final Map<String, VectorDocument> vectorDatabase = new ConcurrentHashMap<>();

    public VectorDatabaseService(EmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    /**
     * Store a document together with its embedding
     */
    public void storeVector(String id, String content) {
        vectorDatabase.put(id, new VectorDocument(id, content, embeddingModel.embed(content)));
    }

    /**
     * Return the topK documents most similar to the query, best first
     */
    public List<SearchResult> searchSimilar(String query, int topK) {
        float[] queryEmbedding = embeddingModel.embed(query);
        return vectorDatabase.values().stream()
                .map(doc -> new SearchResult(doc, cosineSimilarity(queryEmbedding, doc.embedding())))
                .sorted(Comparator.comparingDouble(SearchResult::similarity).reversed())
                .limit(topK)
                .toList();
    }

    private static double cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            return 0.0;
        }
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        return (normA == 0 || normB == 0) ? 0.0 : dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}

// A stored document and its embedding
record VectorDocument(String id, String content, float[] embedding) {}

// Per-query result: keeping the score here avoids mutating shared documents during concurrent searches
record SearchResult(VectorDocument document, double similarity) {}
```
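searchSimilar scores every stored document and then sorts all of them, which costs O(n log n) per query. When only the top K results are needed, a bounded min-heap brings this down to O(n log K). The plain-Java sketch below assumes the similarity scores have already been computed; `TopK` and `Scored` are hypothetical names. For large collections, a dedicated vector database with an approximate nearest-neighbour index is usually the better choice.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

record Scored(String id, double score) {}

class TopK {
    // Returns the k highest-scoring items, best first
    static List<Scored> select(Iterable<Scored> items, int k) {
        if (k <= 0) {
            return List.of();
        }
        // Min-heap ordered by score: the root is the weakest of the current top k
        PriorityQueue<Scored> heap = new PriorityQueue<>(Comparator.comparingDouble(Scored::score));
        for (Scored s : items) {
            if (heap.size() < k) {
                heap.offer(s);
            } else if (s.score() > heap.peek().score()) {
                heap.poll();   // evict the weakest candidate
                heap.offer(s);
            }
        }
        List<Scored> result = new ArrayList<>(heap);
        result.sort(Comparator.comparingDouble(Scored::score).reversed());
        return result;
    }
}
```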
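3. Error handling and retries
Spring AI's OpenAI client already retries transient failures on its own (tunable through the spring.ai.retry.* properties), so an application-level retry layered on top multiplies the total number of attempts.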
ChatServiceWithRetry.java
```java
package com.example.springai.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.retry.annotation.Backoff;
import org.springframework.retry.annotation.Retryable;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Service;

// @Retryable needs spring-retry and spring-boot-starter-aop on the classpath,
// plus @EnableRetry on a configuration class
@Service
public class ChatServiceWithRetry {

    private final ChatClient chatClient;
    private final RetryTemplate retryTemplate;

    public ChatServiceWithRetry(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
        this.retryTemplate = createRetryTemplate();
    }

    /**
     * Declarative retry: up to 3 attempts with a fixed 1 s delay
     */
    @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(delay = 1000))
    public String generateWithRetry(String prompt) {
        return chatClient.prompt().user(prompt).call().content();
    }

    /**
     * Programmatic retry using the RetryTemplate
     */
    public String generateWithManualRetry(String prompt) {
        return retryTemplate.execute(context ->
                chatClient.prompt().user(prompt).call().content());
    }

    private RetryTemplate createRetryTemplate() {
        RetryTemplate template = new RetryTemplate();

        // Retry policy: at most 3 attempts
        SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy();
        retryPolicy.setMaxAttempts(3);
        template.setRetryPolicy(retryPolicy);

        // Backoff policy: 1 s, doubling each time, capped at 10 s
        ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
        backOffPolicy.setInitialInterval(1000);
        backOffPolicy.setMultiplier(2.0);
        backOffPolicy.setMaxInterval(10000);
        template.setBackOffPolicy(backOffPolicy);

        return template;
    }
}
```
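With the RetryTemplate above, SimpleRetryPolicy allows 3 attempts and ExponentialBackOffPolicy waits between them: the first wait is the initial interval, each later wait is multiplied by the multiplier (2.0 by default in Spring Retry), and no wait exceeds the maximum interval. The plain-Java sketch below only reproduces that formula so the worst-case waiting time can be checked before picking values; it is not Spring Retry's code, and `BackoffSchedule` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

class BackoffSchedule {
    // Waits (ms) before retries 1..maxAttempts-1: initial * multiplier^(n-1), capped at maxMs
    static List<Long> delays(int maxAttempts, long initialMs, double multiplier, long maxMs) {
        List<Long> delays = new ArrayList<>();
        double next = initialMs;
        for (int retry = 1; retry < maxAttempts; retry++) {
            delays.add((long) Math.min(next, maxMs));
            next *= multiplier;
        }
        return delays;
    }
}
```

With maxAttempts = 3, an initial interval of 1 s and a 10 s cap, there are two waits of 1 s and 2 s, so a request that always fails spends about 3 s sleeping in addition to the three call durations.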
Common Issues and Best Practices
1. Common issues
Issue 1: API key configuration
```yaml
# Safe approach: read the key from an environment variable instead of committing it
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
```
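Issue 2: Model call timeouts
Spring AI's chat options have no timeout setting; timeouts belong to the underlying HTTP client. In 1.0.0-M3 the OpenAI auto-configuration builds its blocking client from Spring Boot's RestClient.Builder, so a RestClientCustomizer can set them. This also affects every other RestClient built from that builder, and streaming calls use WebClient instead.
```java
// Spring Boot 3.2 API (ClientHttpRequestFactories is deprecated in later Boot versions)
@Configuration
public class ChatClientConfig {
    @Bean
    public RestClientCustomizer timeoutCustomizer() {
        var settings = ClientHttpRequestFactorySettings.DEFAULTS
                .withConnectTimeout(Duration.ofSeconds(10))
                .withReadTimeout(Duration.ofMinutes(5)); // set the read timeout
        return builder -> builder.requestFactory(ClientHttpRequestFactories.get(settings));
    }
}
```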
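Besides a timeout on the HTTP client, a caller can bound how long it waits for a result no matter where the time is spent, and fall back to a default answer. The plain-Java sketch below uses CompletableFuture.orTimeout (Java 9+); `slowCall` and `TimeoutDemo` are hypothetical stand-ins for a model call. Note that orTimeout only stops waiting: the underlying task keeps running and the HTTP request is not cancelled.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

class TimeoutDemo {
    // Placeholder for a slow model call
    static String slowCall(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return "model answer";
    }

    // Wait at most timeoutMs; on timeout (or any other failure) return the fallback text
    static String callWithTimeout(long workMs, long timeoutMs, String fallback) {
        return CompletableFuture.supplyAsync(() -> slowCall(workMs))
                .orTimeout(timeoutMs, TimeUnit.MILLISECONDS)
                .exceptionally(ex -> fallback)
                .join();
    }
}
```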
Issue 3: Rate-limiting calls
```java
package com.example.springai.service;

import com.google.common.util.concurrent.RateLimiter;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

// Requires Guava (com.google.guava:guava). The limit applies per application instance.
@Service
public class RateLimitedChatService {

    private final ChatClient chatClient;
    private final RateLimiter rateLimiter;

    public RateLimitedChatService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
        this.rateLimiter = RateLimiter.create(10.0); // 10 requests per second
    }

    public String generateText(String prompt) {
        rateLimiter.acquire(); // blocks until a permit is available
        return chatClient.prompt().user(prompt).call().content();
    }
}
```
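RateLimiter.create(10.0) hands out permits at a steady 10 per second. The underlying idea can be shown with a token bucket: tokens refill at a fixed rate up to a capacity, and each request spends one. The plain-Java sketch below is a non-blocking variant (tryAcquire returns false instead of waiting) with an injectable clock for testing; `TokenBucket` is a hypothetical name and this is not Guava's implementation. OpenAI also limits tokens per minute, so a request-count limiter alone may not keep you within quota.

```java
import java.util.function.LongSupplier;

class TokenBucket {
    private final double capacity;       // maximum burst size
    private final double refillPerNano;  // tokens added per nanosecond
    private final LongSupplier clock;    // nanosecond time source, injectable for tests
    private double tokens;
    private long lastRefill;

    TokenBucket(double capacity, double tokensPerSecond, LongSupplier clock) {
        this.capacity = capacity;
        this.refillPerNano = tokensPerSecond / 1_000_000_000.0;
        this.clock = clock;
        this.tokens = capacity;          // start with a full bucket
        this.lastRefill = clock.getAsLong();
    }

    // Refill based on elapsed time, then take one token if available; never blocks
    synchronized boolean tryAcquire() {
        long now = clock.getAsLong();
        tokens = Math.min(capacity, tokens + (now - lastRefill) * refillPerNano);
        lastRefill = now;
        if (tokens >= 1.0) {
            tokens -= 1.0;
            return true;
        }
        return false;
    }
}
```

In production the clock would be System::nanoTime.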
2. Best practices
1. Configuration management
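```java
import java.time.Duration;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.bind.DefaultValue;

// Application settings under a custom prefix ("app.ai" is an illustrative name), kept apart from
// Spring AI's own spring.ai.openai.* properties; the API key stays in spring.ai.openai.api-key.
// Register with @EnableConfigurationProperties(OpenAIConfigProperties.class) or @ConfigurationPropertiesScan.
@ConfigurationProperties(prefix = "app.ai")
public record OpenAIConfigProperties(
        @DefaultValue("gpt-3.5-turbo") String model,
        @DefaultValue("0.7") Double temperature,
        @DefaultValue("1000") Integer maxTokens,
        @DefaultValue("2m") Duration timeout) {
}
```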
2. Exception handling
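```java
import java.util.concurrent.CompletableFuture;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.retry.annotation.Retryable;
import org.springframework.stereotype.Service;
import org.springframework.web.client.ResourceAccessException;

@Service
public class RobustChatService {

    private static final Logger logger = LoggerFactory.getLogger(RobustChatService.class);

    private final ChatClient chatClient;

    public RobustChatService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Asynchronous generation; failures are logged and surface through the returned future.
     * Runs on the common ForkJoinPool; pass a dedicated Executor for blocking I/O in production.
     */
    public CompletableFuture<String> generateTextAsync(String prompt) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                return chatClient.prompt().user(prompt).call().content();
            } catch (RuntimeException e) {
                logger.error("Text generation failed: {}", e.getMessage(), e);
                throw new RuntimeException("Text generation failed", e);
            }
        });
    }

    /**
     * Retries only network I/O errors (requires @EnableRetry); other errors propagate immediately.
     * This is a retry, not a circuit breaker.
     */
    @Retryable(retryFor = ResourceAccessException.class, maxAttempts = 3)
    public String generateWithNetworkRetry(String prompt) {
        try {
            return chatClient.prompt().user(prompt).call().content();
        } catch (ResourceAccessException e) {
            logger.warn("Network error, retrying: {}", e.getMessage());
            throw e;
        }
    }
}
```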
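Retrying alone is not a circuit breaker. A circuit breaker stops calling a failing service for a cooldown period after repeated failures, so requests fail fast instead of piling up behind a dead endpoint, then lets a single trial call through to test recovery. In production a library such as Resilience4j is the usual choice; the plain-Java sketch below only shows the state machine, and `SimpleCircuitBreaker` with its thresholds is hypothetical.

```java
import java.util.function.LongSupplier;
import java.util.function.Supplier;

class SimpleCircuitBreaker {
    enum State { CLOSED, OPEN, HALF_OPEN }

    private final int failureThreshold;  // consecutive failures before opening
    private final long openMillis;       // how long calls are rejected once open
    private final LongSupplier clock;    // millisecond time source, injectable for tests
    private State state = State.CLOSED;
    private int consecutiveFailures = 0;
    private long openedAt = 0;

    SimpleCircuitBreaker(int failureThreshold, long openMillis, LongSupplier clock) {
        this.failureThreshold = failureThreshold;
        this.openMillis = openMillis;
        this.clock = clock;
    }

    synchronized State state() {
        return state;
    }

    // Synchronized for simplicity, which serializes calls; a real implementation would not
    synchronized <T> T call(Supplier<T> action, T fallback) {
        if (state == State.OPEN) {
            if (clock.getAsLong() - openedAt < openMillis) {
                return fallback;             // fail fast while open
            }
            state = State.HALF_OPEN;         // cooldown over: allow one trial call
        }
        try {
            T result = action.get();
            state = State.CLOSED;            // success closes the circuit
            consecutiveFailures = 0;
            return result;
        } catch (RuntimeException e) {
            consecutiveFailures++;
            if (state == State.HALF_OPEN || consecutiveFailures >= failureThreshold) {
                state = State.OPEN;          // trip, or re-trip after a failed trial call
                openedAt = clock.getAsLong();
            }
            return fallback;
        }
    }
}
```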
3. Performance optimization
```java
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.stereotype.Service;

// Requires @EnableCaching, a cache provider, and a MeterRegistry
// (for example from spring-boot-starter-actuator)
@Service
public class OptimizedChatService {

    private final ChatClient chatClient;
    private final MeterRegistry meterRegistry;

    public OptimizedChatService(ChatClient.Builder chatClientBuilder, MeterRegistry meterRegistry) {
        this.chatClient = chatClientBuilder.build();
        this.meterRegistry = meterRegistry;
    }

    // Cache hits skip this method entirely, so the metrics below count only real model calls
    @Cacheable(value = "chat-responses", key = "#prompt")
    public String generateTextCached(String prompt) {
        meterRegistry.counter("chat.requests").increment();
        Timer.Sample sample = Timer.start(meterRegistry);
        try {
            return chatClient.prompt().user(prompt).call().content();
        } catch (RuntimeException e) {
            meterRegistry.counter("chat.errors").increment();
            throw new RuntimeException("Text generation failed", e);
        } finally {
            sample.stop(meterRegistry.timer("chat.duration"));
        }
    }
}
```
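@Cacheable stores results in whatever CacheManager is configured, and Spring's simple default (ConcurrentMapCacheManager) never evicts entries, so a cache of free-form prompts can grow without bound. A bounded cache with least-recently-used eviction is a common alternative. The plain-Java sketch below uses LinkedHashMap in access order; `LruCache` is a hypothetical name, and it is not thread-safe on its own (wrap it with Collections.synchronizedMap or use a caching library). Caching by exact prompt text only pays off when identical prompts recur and a repeated answer is acceptable, for example at a low temperature.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// accessOrder = true moves an entry to the end on get/put; the eldest is evicted past maxEntries
class LruCache<K, V> extends LinkedHashMap<K, V> {
    private final int maxEntries;

    LruCache(int maxEntries) {
        super(16, 0.75f, true);
        this.maxEntries = maxEntries;
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
        return size() > maxEntries;
    }
}
```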
Integration test example
```java
package com.example.springai.integration;

import com.example.springai.service.QAService;
import com.example.springai.service.TextGenerationService;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

// Calls the real OpenAI API; skipped unless OPENAI_API_KEY is set
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class SpringAIIntegrationTest {

    @Autowired
    private TextGenerationService textGenerationService;

    @Autowired
    private QAService qaService;

    @Test
    void testTextGenerationIntegration() {
        String result = textGenerationService.generateText("Briefly introduce Spring AI");
        assertNotNull(result);
        assertTrue(result.length() > 50, "expected a reasonably long answer");
    }

    @Test
    void testQASystemIntegration() {
        String answer = qaService.askQuestion("What is Spring Boot?");
        assertNotNull(answer);
        assertFalse(answer.isEmpty());
        // Model output varies, so only check for a keyword (case-insensitive)
        assertTrue(answer.toLowerCase().contains("spring"));
    }
}
```
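Integration tests like these depend on the network, cost tokens, and produce nondeterministic output. For fast, repeatable unit tests, business logic can depend on a small interface that a lambda stubs out. The plain-Java sketch below illustrates that design choice; `TextGenerator` and `PoemService` are hypothetical names, and a production implementation of the interface would delegate to ChatClient.

```java
// Hypothetical seam between business logic and the model client
@FunctionalInterface
interface TextGenerator {
    String generate(String prompt);
}

// Hypothetical service under test: wraps the user's topic in a prompt template
class PoemService {
    private final TextGenerator generator;

    PoemService(TextGenerator generator) {
        this.generator = generator;
    }

    String poemAbout(String topic) {
        return generator.generate("Write a short poem about " + topic);
    }
}
```

A test can then pass a lambda that records the prompt and returns a canned answer, checking the prompt-building logic without calling OpenAI.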
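Summary
This article covered beginner-to-intermediate use of Spring AI:
- Basic configuration: environment and project setup
- Core concepts: ChatClient and EmbeddingModel (called EmbeddingClient before 1.0.0-M1)
- Practical applications: text generation and question answering
- Key techniques: message construction, vector similarity, error handling
- Best practices: performance optimization, caching, rate limiting, and more
With these topics in hand, you should be able to integrate Spring AI's core features into a Spring application. Next, you can explore more advanced topics such as model fine-tuning and multimodal processing.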
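Tips for going further
⚠️ Common pitfalls:
- API limits: OpenAI enforces RPM (requests per minute) and TPM (tokens per minute) limits; configure retries and rate limiting accordingly.
- Initialization: do setup that depends on injected beans (such as building a ChatClient) through constructor injection or in a @PostConstruct method, not in field initializers, where injected fields are still null.
- Resource cleanup: close WebSocket and other long-lived connections promptly to avoid memory leaks.
- Error handling: AI service output is nondeterministic and calls can fail, so thorough exception handling is essential.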
This article is based on Spring AI 1.0.0-M3; refer to the latest official documentation for changes in newer releases.