

Your First Hands-On AI Project: Build an Intelligent Chatbot in 3 Steps (Python + Spring Boot + LangChain4j)

Introduction: Why Do You Need a Complete AI Application Architecture?

As artificial intelligence sweeps across the industry, many developers want to bring large language model (LLM) capabilities into their own applications. Building a production-grade AI application from scratch, however, raises real challenges: How do you design a scalable architecture? How do you keep the service stable? How do you coordinate components written in different languages? This article walks through a complete intelligent-chatbot project, taking you from zero to a working grasp of the full "data - model - service" pipeline.

The chatbot project in this article shows how to combine Python's AI ecosystem, Java's enterprise frameworks, and modern deployment practices into an application with both strong AI capabilities and production-grade stability. Whether you are a Java developer looking to integrate AI features or a Python developer looking to serve models, you will find practical solutions here.

I. Architecture Design: A Layered, Decoupled Intelligent System

1.1 Architecture Overview

A modern AI application needs a clear layered architecture to stay maintainable and scalable. Our chatbot uses a four-layer design (frontend, Java interface, Python AI core, infrastructure):

  • Frontend layer: Web UI, Mobile App
  • Java interface layer (Spring Boot REST API): business logic, session management, access control, rate limiting and circuit breaking
  • Python AI core layer (FastAPI service): prompt engineering, LangChain processing, LLM calls, response post-processing
  • Infrastructure layer: Redis cache, MySQL database, Docker containers, Kubernetes cluster

1.2 Technology Stack Choices and Rationale

Frontend interaction layer:

  • Vue 3 + TypeScript: reactive user interface
  • Element Plus: component library that speeds up development
  • Axios: HTTP client for API calls

Java interface layer:

  • Spring Boot 3.x: rapid RESTful API development
  • Spring Security: authentication and authorization
  • Spring Data JPA: data persistence
  • Resilience4j: circuit breaking and rate limiting
  • LangChain4j: the Java counterpart of LangChain, simplifying AI integration

Python AI core layer:

  • FastAPI: high-performance Python web framework
  • LangChain: AI application development framework
  • OpenAI / Qwen (Tongyi Qianwen) / ERNIE Bot (Wenxin Yiyan): LLM APIs
  • Pydantic: data validation and settings management

Deployment and operations:

  • Docker: containerized deployment
  • Docker Compose: local multi-service orchestration
  • Nginx: reverse proxy and load balancing
  • Prometheus + Grafana: monitoring and alerting

1.3 Project Directory Structure

smart-chatbot/
├── frontend/                 # Frontend project
│   ├── src/
│   │   ├── components/      # Components
│   │   ├── views/          # Pages
│   │   └── api/            # API calls
│   └── package.json
├── backend/                 # Java backend
│   ├── src/main/java/
│   │   ├── controller/     # Controllers
│   │   ├── service/       # Business services
│   │   ├── repository/    # Data access
│   │   ├── config/       # Configuration classes
│   │   └── dto/          # Data transfer objects
│   └── pom.xml
├── ai-service/             # Python AI service
│   ├── app/
│   │   ├── core/          # Core AI logic
│   │   ├── models/        # Data models
│   │   └── utils/         # Utility functions
│   ├── requirements.txt
│   └── main.py
├── docker-compose.yml      # Docker orchestration
└── README.md

II. Key Implementation: From API Calls to Cross-Language Coordination

2.1 The Python AI Core Layer

2.1.1 FastAPI Application Skeleton
# ai-service/main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
import logging
import uuid

from app.core.chat_processor import ChatProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Smart Chat AI Service",
    description="LangChain-based conversational processing engine",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared processor instance (holds the per-session conversation chains)
chat_processor = ChatProcessor()

# Request/response models
class ChatRequest(BaseModel):
    """Chat request model"""
    message: str = Field(..., min_length=1, max_length=2000)
    session_id: Optional[str] = Field(None, description="Session ID")
    temperature: float = Field(0.7, ge=0.0, le=1.0)
    max_tokens: int = Field(1000, ge=50, le=4000)

class ChatResponse(BaseModel):
    """Chat response model"""
    response: str
    session_id: str
    tokens_used: int
    processing_time: float
    timestamp: datetime

@app.get("/health")
async def health_check():
    """Health-check endpoint"""
    return {"status": "healthy", "timestamp": datetime.now()}

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Handle a chat request"""
    start_time = datetime.now()
    session_id = request.session_id or str(uuid.uuid4())

    try:
        # Delegate to the AI processing logic
        result = await chat_processor.process_message(request.message, session_id)

        # Compute the processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        return ChatResponse(
            response=result["response"],
            session_id=result["session_id"],
            tokens_used=result["tokens_used"],
            processing_time=processing_time,
            timestamp=datetime.now()
        )
    except Exception as e:
        logger.error(f"Chat processing failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
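Once the service is running (for example via `uvicorn main:app`), the Java layer, or any HTTP client, posts JSON matching `ChatRequest`. As a sketch of what the Pydantic model enforces, here is a hypothetical client-side helper (not part of the project) that pre-validates a payload before sending it; the bounds mirror the `Field` constraints above:

```python
def validate_chat_payload(payload: dict) -> list:
    """Return a list of violations against the ChatRequest constraints."""
    errors = []
    message = payload.get("message", "")
    if not (1 <= len(message) <= 2000):
        errors.append("message must be 1-2000 characters")
    temperature = payload.get("temperature", 0.7)
    if not (0.0 <= temperature <= 1.0):
        errors.append("temperature must be within [0.0, 1.0]")
    max_tokens = payload.get("max_tokens", 1000)
    if not (50 <= max_tokens <= 4000):
        errors.append("max_tokens must be within [50, 4000]")
    return errors
```

On the server side, FastAPI performs the same checks automatically and rejects invalid payloads with a 422 response; a pre-check like this just gives callers a faster failure.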
2.1.2 Integrating LangChain with the LLM
# ai-service/app/core/chat_processor.py
import os
from typing import Dict, Any
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate

class ChatProcessor:
    """Chat processor wrapping LangChain functionality"""
    
    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        """
        Initialize the chat processor.
        
        Args:
            model_name: model name, e.g. gpt-3.5-turbo or gpt-4
        """
        # Read the API key from the environment
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        
        # Initialize the LLM
        self.llm = ChatOpenAI(
            model_name=model_name,
            temperature=0.7,
            max_tokens=1000,
            openai_api_key=api_key
        )
        
        # One conversation chain per session
        self.conversation_chains = {}
        
        # System prompt template
        self.system_prompt = """You are an intelligent assistant with the following traits:
        1. Answers are professional, accurate, and useful
        2. When unsure about something, you say so explicitly
        3. You stay friendly and patient
        4. Multi-part questions are answered point by point
        5. You avoid heavy jargon unless the user is a specialist
        
        Conversation history so far:
        {history}
        
        User question: {input}
        
        Based on the above, give the best possible answer:"""
        
        self.prompt_template = PromptTemplate(
            input_variables=["history", "input"],
            template=self.system_prompt
        )
    
    def get_or_create_chain(self, session_id: str) -> ConversationChain:
        """Get or create the conversation chain for a session"""
        if session_id not in self.conversation_chains:
            memory = ConversationBufferMemory(
                memory_key="history",
                # The string prompt template expects plain text,
                # so keep the history as a string rather than message objects
                return_messages=False
            )
            
            chain = ConversationChain(
                llm=self.llm,
                memory=memory,
                prompt=self.prompt_template,
                verbose=False  # keep False in production
            )
            self.conversation_chains[session_id] = chain
        
        return self.conversation_chains[session_id]
    
    async def process_message(self, message: str, session_id: str) -> Dict[str, Any]:
        """Process a single message"""
        try:
            # Look up the session chain
            chain = self.get_or_create_chain(session_id)
            
            # Run the conversation chain
            response = await chain.arun(input=message)
            
            # Load the conversation history
            history = chain.memory.load_memory_variables({})["history"]
            
            # Estimate token usage (simplified)
            tokens_used = self._estimate_tokens(message + response)
            
            return {
                "response": response,
                "session_id": session_id,
                "tokens_used": tokens_used,
                "history": history
            }
        except Exception as e:
            raise RuntimeError(f"Message processing failed: {e}") from e
    
    def _estimate_tokens(self, text: str) -> int:
        """Roughly estimate token usage"""
        # Use a library such as tiktoken for exact counts in real applications
        return len(text) // 4  # crude approximation
    
    def clear_session(self, session_id: str):
        """Clear a session's history"""
        if session_id in self.conversation_chains:
            self.conversation_chains[session_id].memory.clear()
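The `_estimate_tokens` heuristic trades accuracy for simplicity. A sketch of the more accurate approach hinted at in its comment, using the third-party tiktoken library when it is installed and falling back to the same characters-divided-by-four heuristic otherwise:

```python
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Return the token count for `text`, exactly if tiktoken is available."""
    try:
        import tiktoken  # third-party: pip install tiktoken
        encoding = tiktoken.encoding_for_model(model)
        return len(encoding.encode(text))
    except ImportError:
        # Heuristic fallback: roughly 4 characters per token for English text
        return max(1, len(text) // 4)
```

Exact counts matter once you bill by token or need to stay under a model's context window; the heuristic is fine for rough dashboards.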
2.1.3 Advanced Prompt Engineering
# ai-service/app/core/prompt_engineer.py
from typing import Dict, Optional
from enum import Enum
import re

class ConversationType(Enum):
    """Conversation type enum"""
    GENERAL = "general"          # general chat
    TECHNICAL = "technical"      # technical Q&A
    CREATIVE = "creative"        # creative writing
    ANALYSIS = "analysis"        # analytical questions
    TUTORIAL = "tutorial"        # tutorial / how-to guidance

class PromptEngineer:
    """Prompt engineer: tailors prompts to improve AI output"""
    
    def __init__(self):
        # Templates per conversation type; types without a template
        # (ANALYSIS, TUTORIAL) fall back to GENERAL in build_prompt
        self.templates = {
            ConversationType.GENERAL: {
                "system": "You are a helpful assistant. Answer clearly, accurately, and in a friendly tone.",
                "user": "{question}"
            },
            ConversationType.TECHNICAL: {
                "system": """You are a technical expert. When answering technical questions:
                1. State the core conclusion first
                2. Explain the details and underlying principles
                3. Give a concrete code example where applicable
                4. Point out caveats and best practices
                5. Suggest resources for further learning""",
                "user": "Technical question: {question}"
            },
            ConversationType.CREATIVE: {
                "system": """You are a creative writer with a rich imagination.
                Your answers should be:
                1. Vivid and engaging
                2. Clearly structured
                3. Rich in descriptive detail
                4. Emotionally expressive""",
                "user": "Creative request: {question}"
            }
        }
        
        # Response formatting rules
        self.formatting_rules = {
            "code_blocks": {
                "pattern": r"```(.*?)```",
                "replacement": r"<pre><code>\1</code></pre>"
            },
            "bullet_points": {
                "pattern": r"^\s*[-*]\s+(.+)$",
                "replacement": r"<li>\1</li>"
            }
        }
    
    def detect_conversation_type(self, user_input: str) -> ConversationType:
        """Detect the conversation type (keyword lists cover both Chinese and English input)"""
        user_input_lower = user_input.lower()
        
        # Technical keywords
        technical_keywords = [
            "如何实现", "代码", "api", "配置", "错误", "bug",
            "how to", "code", "implement", "debug", "error"
        ]
        
        # Creative keywords
        creative_keywords = [
            "写一个故事", "创作", "想象", "诗歌", "小说",
            "story", "creative", "imagine", "poem"
        ]
        
        # Analysis keywords
        analysis_keywords = [
            "分析", "比较", "优缺点", "为什么", "如何选择",
            "analyze", "compare", "pros and cons", "why"
        ]
        
        for keyword in technical_keywords:
            if keyword in user_input_lower:
                return ConversationType.TECHNICAL
        
        for keyword in creative_keywords:
            if keyword in user_input_lower:
                return ConversationType.CREATIVE
        
        for keyword in analysis_keywords:
            if keyword in user_input_lower:
                return ConversationType.ANALYSIS
        
        return ConversationType.GENERAL
    
    def build_prompt(self, user_input: str, history: Optional[str] = None) -> Dict[str, str]:
        """Build an optimized prompt"""
        # Detect the conversation type
        conv_type = self.detect_conversation_type(user_input)
        
        # Fetch the matching template
        template = self.templates.get(conv_type, self.templates[ConversationType.GENERAL])
        
        # Build the system message
        system_message = template["system"]
        
        # Append history to the system message if present
        if history:
            system_message += f"\n\nPrevious conversation history: {history}"
        
        # Build the user message
        user_message = template["user"].format(question=user_input)
        
        return {
            "system": system_message,
            "user": user_message,
            "type": conv_type.value
        }
    
    def format_response(self, response: str, conv_type: ConversationType) -> str:
        """Format the AI response"""
        formatted_response = response
        
        # Apply type-specific formatting rules
        if conv_type == ConversationType.TECHNICAL:
            # Make sure code blocks are wrapped properly
            formatted_response = re.sub(
                self.formatting_rules["code_blocks"]["pattern"],
                self.formatting_rules["code_blocks"]["replacement"],
                formatted_response,
                flags=re.DOTALL
            )
        
        # Convert bullet points
        formatted_response = re.sub(
            self.formatting_rules["bullet_points"]["pattern"],
            self.formatting_rules["bullet_points"]["replacement"],
            formatted_response,
            flags=re.MULTILINE
        )
        
        return formatted_response
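Because the keyword loops run in a fixed order, precedence matters: an input containing both technical and analysis keywords (for example "为什么这段代码报错") is classified as technical, not analysis. A condensed, standalone sketch of the same first-match logic (abridged keyword lists; the real class checks more keywords):

```python
def detect_type(user_input: str) -> str:
    """First matching category wins; list order encodes precedence."""
    text = user_input.lower()
    keyword_map = [
        ("technical", ["如何实现", "代码", "api", "bug", "how to", "code", "debug", "error"]),
        ("creative", ["写一个故事", "创作", "诗歌", "story", "creative", "poem"]),
        ("analysis", ["分析", "比较", "优缺点", "analyze", "compare", "why"]),
    ]
    for conv_type, keywords in keyword_map:
        if any(keyword in text for keyword in keywords):
            return conv_type
    return "general"
```

Substring matching like this is a deliberate simplicity trade-off; a production system might instead classify with embeddings or a small classifier model.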

2.2 The Java Interface Layer

2.2.1 Spring Boot Application Configuration
// backend/src/main/java/com/smartchatbot/config/WebConfig.java
package com.smartchatbot.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.beans.factory.annotation.Value;

@Configuration
public class WebConfig implements WebMvcConfigurer {
    
    @Value("${ai.service.url}")
    private String aiServiceUrl;
    
    @Value("${ai.service.timeout:30}")
    private int aiServiceTimeout;
    
    @Value("${ai.service.key}")
    private String aiServiceKey;
    
    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/api/**")
                .allowedOrigins("http://localhost:3000")
                .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
                .allowedHeaders("*")
                .allowCredentials(true);
    }
    
    @Bean
    public RestTemplate restTemplate() {
        // Configure the HTTP client factory with timeouts
        HttpComponentsClientHttpRequestFactory factory = 
            new HttpComponentsClientHttpRequestFactory();
        
        // Connect timeout (milliseconds)
        factory.setConnectTimeout(5000);
        // Read timeout (milliseconds)
        factory.setReadTimeout(aiServiceTimeout * 1000);
        
        RestTemplate restTemplate = new RestTemplate(factory);
        
        // Request interceptor (adds the service key header; keep the key
        // in configuration rather than hard-coding it in source)
        restTemplate.getInterceptors().add((request, body, execution) -> {
            request.getHeaders().add("X-AI-Service-Key", aiServiceKey);
            return execution.execute(request, body);
        });
        
        return restTemplate;
    }
}
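The interceptor above attaches an `X-AI-Service-Key` header to every call to the Python service, so the FastAPI side should verify it. A minimal sketch of that check in plain Python (the FastAPI dependency wiring is omitted; in the real service `expected_key` would be loaded once at startup, e.g. from an environment variable):

```python
import hmac
from typing import Optional

def verify_service_key(provided_key: Optional[str], expected_key: str) -> bool:
    """Check the X-AI-Service-Key header value in constant time."""
    if provided_key is None:
        return False
    # compare_digest avoids leaking information through timing differences
    return hmac.compare_digest(provided_key, expected_key)
```

A plain `==` comparison would work functionally, but `hmac.compare_digest` is the idiomatic choice for secrets because its running time does not depend on where the strings first differ.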
2.2.2 Business Logic Layer
// backend/src/main/java/com/smartchatbot/service/ChatService.java
package com.smartchatbot.service;

import com.smartchatbot.dto.ChatRequest;
import com.smartchatbot.dto.ChatResponse;
import com.smartchatbot.entity.ChatSession;
import com.smartchatbot.entity.ChatMessage;
import com.smartchatbot.repository.ChatSessionRepository;
import com.smartchatbot.repository.ChatMessageRepository;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import io.github.resilience4j.circuitbreaker.annotation.CircuitBreaker;
import io.github.resilience4j.ratelimiter.annotation.RateLimiter;
import jakarta.transaction.Transactional; // Spring Boot 3.x uses jakarta, not javax
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

@Service
@Slf4j
public class ChatService {
    
    @Value("${ai.service.url}")
    private String aiServiceUrl;
    
    @Autowired
    private RestTemplate restTemplate;
    
    @Autowired
    private ChatSessionRepository sessionRepository;
    
    @Autowired
    private ChatMessageRepository messageRepository;
    
    @Autowired
    private ObjectMapper objectMapper;
    
    // Local session cache (cuts down on database lookups)
    private final Map<String, ChatSession> sessionCache = new ConcurrentHashMap<>();
    
    // Circuit breaker / rate limiter names
    private static final String CIRCUIT_BREAKER_NAME = "aiService";
    private static final String RATE_LIMITER_NAME = "chatRateLimiter";
    
    /**
     * Handle a chat request.
     * Note: the remote AI call runs inside the database transaction here;
     * under heavy load consider splitting the two.
     */
    @Transactional
    @CircuitBreaker(name = CIRCUIT_BREAKER_NAME, fallbackMethod = "fallbackResponse")
    @RateLimiter(name = RATE_LIMITER_NAME)
    public ChatResponse processChat(ChatRequest request, String userId) {
        long startTime = System.currentTimeMillis();
        
        try {
            // 1. Get or create the session
            ChatSession session = getOrCreateSession(request.getSessionId(), userId);
            
            // 2. Persist the user message
            ChatMessage userMessage = saveMessage(
                session, 
                request.getMessage(), 
                "user", 
                userId
            );
            
            // 3. Call the AI service
            Map<String, Object> aiRequest = new HashMap<>();
            aiRequest.put("message", request.getMessage());
            aiRequest.put("session_id", session.getSessionId());
            aiRequest.put("temperature", request.getTemperature());
            aiRequest.put("max_tokens", request.getMaxTokens());
            
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            HttpEntity<Map<String, Object>> entity = new HttpEntity<>(aiRequest, headers);
            
            log.info("Calling AI service: {}", aiServiceUrl);
            
            ResponseEntity<Map> response = restTemplate.postForEntity(
                aiServiceUrl + "/chat", 
                entity, 
                Map.class
            );
            
            if (response.getStatusCode() != HttpStatus.OK) {
                throw new RuntimeException("AI service returned an error: " + response.getStatusCode());
            }
            
            // 4. Parse the AI response
            Map<String, Object> aiResponse = response.getBody();
            String aiMessage = (String) aiResponse.get("response");
            int tokensUsed = (int) aiResponse.get("tokens_used");
            
            // 5. Persist the AI response
            ChatMessage aiResponseMessage = saveMessage(
                session, 
                aiMessage, 
                "assistant", 
                "AI"
            );
            
            // 6. Update the session
            session.setLastActiveTime(LocalDateTime.now());
            session.setTotalTokens(session.getTotalTokens() + tokensUsed);
            sessionRepository.save(session);
            
            // 7. Refresh the cache
            sessionCache.put(session.getSessionId(), session);
            
            // 8. Compute the processing time
            long processingTime = System.currentTimeMillis() - startTime;
            
            return ChatResponse.builder()
                .message(aiMessage)
                .sessionId(session.getSessionId())
                .tokensUsed(tokensUsed)
                .processingTime(processingTime)
                .messageId(aiResponseMessage.getMessageId())
                .timestamp(LocalDateTime.now())
                .build();
                
        } catch (Exception e) {
            log.error("Chat processing failed", e);
            throw new RuntimeException("Chat processing failed: " + e.getMessage(), e);
        }
    }
    
    /**
     * Get or create a session.
     */
    private ChatSession getOrCreateSession(String sessionId, String userId) {
        // Check the cache first
        if (sessionId != null && sessionCache.containsKey(sessionId)) {
            return sessionCache.get(sessionId);
        }
        
        // Cache miss: look it up in the database
        ChatSession session = null;
        if (sessionId != null) {
            session = sessionRepository.findBySessionId(sessionId);
        }
        
        // Create a new session if none exists
        if (session == null) {
            session = ChatSession.builder()
                .sessionId(UUID.randomUUID().toString())
                .userId(userId)
                .startTime(LocalDateTime.now())
                .lastActiveTime(LocalDateTime.now())
                .totalTokens(0)
                .messageCount(0)
                .build();
            
            session = sessionRepository.save(session);
        }
        
        // Add to the cache
        sessionCache.put(session.getSessionId(), session);
        
        return session;
    }
    
    /**
     * Persist a message.
     */
    private ChatMessage saveMessage(ChatSession session, String content, 
                                   String role, String sender) {
        ChatMessage message = ChatMessage.builder()
            .session(session)
            .content(content)
            .role(role)
            .sender(sender)
            .timestamp(LocalDateTime.now())
            .build();
        
        message = messageRepository.save(message);
        
        // Bump the session's message count
        session.setMessageCount(session.getMessageCount() + 1);
        
        return message;
    }
    
    /**
     * Circuit-breaker fallback method.
     */
    public ChatResponse fallbackResponse(ChatRequest request, String userId, Throwable t) {
        log.warn("AI service unavailable, serving fallback response. Cause: {}", t.getMessage());
        
        // Return a friendly degraded response
        return ChatResponse.builder()
            .message("Sorry, the AI service is temporarily unavailable. Please try again later.")
            .sessionId(request.getSessionId() != null ? 
                      request.getSessionId() : UUID.randomUUID().toString())
            .tokensUsed(0)
            .processingTime(0L)
            .messageId(UUID.randomUUID().toString())
            .timestamp(LocalDateTime.now())
            .fallback(true)
            .build();
    }
    
    /**
     * Evict a session from the cache.
     */
    public void clearSessionCache(String sessionId) {
        sessionCache.remove(sessionId);
    }
    
    /**
     * Fetch a session's chat history.
     */
    public List<ChatMessage> getChatHistory(String sessionId, String userId) {
        // Verify the user is allowed to access this session
        ChatSession session = sessionRepository.findBySessionIdAndUserId(sessionId, userId);
        if (session == null) {
            throw new RuntimeException("Session does not exist or access is denied");
        }
        
        return messageRepository.findBySessionOrderByTimestampAsc(session);
    }
}
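Resilience4j's `@CircuitBreaker` handles state transitions declaratively, so it is easy to miss what the mechanism actually does. To make it concrete, a minimal Python sketch of the core idea: count consecutive failures, and once a threshold is crossed, stop calling the backend and serve the fallback directly (real breakers add half-open probing and time windows on top of this):

```python
class SimpleCircuitBreaker:
    """Open after `threshold` consecutive failures; then serve the fallback."""

    def __init__(self, threshold: int = 3):
        self.threshold = threshold
        self.failures = 0

    @property
    def open(self) -> bool:
        return self.failures >= self.threshold

    def call(self, fn, fallback):
        if self.open:
            return fallback()          # short-circuit: do not hit the backend
        try:
            result = fn()
            self.failures = 0          # a success closes the breaker again
            return result
        except Exception:
            self.failures += 1
            return fallback()
```

The payoff is the short-circuit branch: once open, a failing AI service is no longer hammered with requests (and no longer ties up threads waiting on timeouts), which is exactly what `fallbackResponse` provides in the Java code above.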
2.2.3 Controller Layer
// backend/src/main/java/com/smartchatbot/controller/ChatController.java
package com.smartchatbot.controller;

import com.smartchatbot.dto.ChatRequest;
import com.smartchatbot.dto.ChatResponse;
import com.smartchatbot.dto.HealthStatus;
import com.smartchatbot.entity.ChatMessage;
import com.smartchatbot.service.ChatService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import jakarta.validation.Valid; // Spring Boot 3.x uses jakarta, not javax
import java.time.LocalDateTime;
import java.util.List;

@RestController
@RequestMapping("/api/chat")
@Tag(name = "Chat API", description = "Intelligent chatbot endpoints")
@Slf4j
public class ChatController {
    
    @Autowired
    private ChatService chatService;
    
    @PostMapping("/send")
    @Operation(summary = "Send a message", description = "Send a message to the AI and get its response")
    public ResponseEntity<ChatResponse> sendMessage(
            @Valid @RequestBody ChatRequest request,
            @RequestHeader("X-User-Id") String userId) {
        
        log.info("Chat request received, userId: {}, sessionId: {}", 
                userId, request.getSessionId());
        
        ChatResponse response = chatService.processChat(request, userId);
        
        return ResponseEntity.ok(response);
    }
    
    @GetMapping("/history/{sessionId}")
    @Operation(summary = "Get chat history", description = "Fetch the message history for a session")
    public ResponseEntity<List<ChatMessage>> getHistory(
            @PathVariable String sessionId,
            @RequestHeader("X-User-Id") String userId) {
        
        List<ChatMessage> history = chatService.getChatHistory(sessionId, userId);
        
        return ResponseEntity.ok(history);
    }
    
    @DeleteMapping("/session/{sessionId}")
    @Operation(summary = "Clear a session", description = "Clear the history of the given session")
    public ResponseEntity<Void> clearSession(
            @PathVariable String sessionId,
            @RequestHeader("X-User-Id") String userId) {
        
        chatService.clearSessionCache(sessionId);
        
        return ResponseEntity.ok().build();
    }
    
    @GetMapping("/health")
    @Operation(summary = "Health check", description = "Check whether the service is running")
    public ResponseEntity<HealthStatus> healthCheck() {
        HealthStatus status = HealthStatus.builder()
                .status("UP")
                .timestamp(LocalDateTime.now())
                .service("chat-service")
                .version("1.0.0")
                .build();
        
        return ResponseEntity.ok(status);
    }
}

// Data transfer objects (these live in the com.smartchatbot.dto package;
// Lombok and jakarta.validation imports omitted for brevity)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
class ChatRequest {
    @NotBlank(message = "Message content must not be blank")
    @Size(max = 2000, message = "Message content must not exceed 2000 characters")
    private String message;
    
    private String sessionId;
    
    @Min(value = 0, message = "temperature must not be less than 0")
    @Max(value = 2, message = "temperature must not be greater than 2")
    @Builder.Default
    private Double temperature = 0.7;
    
    @Min(value = 1, message = "maxTokens must not be less than 1")
    @Max(value = 4000, message = "maxTokens must not be greater than 4000")
    @Builder.Default
    private Integer maxTokens = 1000;
}

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
class ChatResponse {
    private String message;
    private String sessionId;
    private Integer tokensUsed;
    private Long processingTime;
    private String messageId;
    private LocalDateTime timestamp;
    @Builder.Default
    private Boolean fallback = false;
}

@Data
@Builder
class HealthStatus {
    private String status;
    private LocalDateTime timestamp;
    private String service;
    private String version;
}

2.3 Frontend Implementation

<!-- frontend/src/components/ChatInterface.vue -->
<template>
  <div class="chat-container">
    <!-- Chat header -->
    <div class="chat-header">
      <h2>Smart Chat Assistant</h2>
      <div class="header-actions">
        <el-button 
          @click="clearChat" 
          size="small" 
          :disabled="messages.length === 0">
          Clear Chat
        </el-button>
        <el-button 
          @click="exportChat" 
          size="small" 
          :disabled="messages.length === 0">
          Export History
        </el-button>
      </div>
    </div>
    <!-- Chat messages area -->
    <div class="chat-messages" ref="messagesContainer">
      <div 
        v-for="(message, index) in messages" 
        :key="index" 
        :class="['message', message.role]">
        <div class="message-avatar">
          <el-avatar :src="getAvatar(message.role)" />
        </div>
        <div class="message-content">
          <div class="message-header">
            <span class="sender">{{ getSenderName(message.role) }}</span>
            <span class="timestamp">{{ formatTime(message.timestamp) }}</span>
          </div>
          <div 
            class="message-text" 
            v-html="formatMessage(message.content)">
          </div>
          <div v-if="message.role === 'assistant'" class="message-actions">
            <el-button 
              @click="copyToClipboard(message.content)" 
              size="mini" 
              icon="el-icon-document-copy">
              Copy
            </el-button>
            <el-button 
              @click="regenerateResponse(index)" 
              size="mini" 
              icon="el-icon-refresh"
              :loading="regeneratingIndex === index">
              Retry
            </el-button>
          </div>
        </div>
      </div>
      
      <!-- Loading state -->
      <div v-if="loading" class="loading-indicator">
        <el-skeleton :rows="3" animated />
        <div class="typing-indicator">
          <span></span><span></span><span></span>
        </div>
      </div>
    </div>

    <!-- Input area -->
    <div class="chat-input-area">
      <div class="input-tools">
        <el-tooltip content="Adjust AI creativity" placement="top">
          <div class="temperature-control">
            <span>Creativity:</span>
            <el-slider
              v-model="temperature"
              :min="0"
              :max="1"
              :step="0.1"
              style="width: 120px; margin: 0 10px;"
            />
            <span>{{ temperature.toFixed(1) }}</span>
          </div>
        </el-tooltip>
        
        <el-tooltip content="Adjust response length" placement="top">
          <div class="max-tokens-control">
            <span>Length:</span>
            <el-input-number
              v-model="maxTokens"
              :min="100"
              :max="2000"
              :step="100"
              size="small"
            />
          </div>
        </el-tooltip>
      </div>
      
      <div class="input-container">
        <el-input
          v-model="inputMessage"
          type="textarea"
          :rows="3"
          :maxlength="2000"
          placeholder="Type your question..."
          @keydown.enter.exact.prevent="sendMessage"
          :disabled="loading"
          resize="none"
        />
        <div class="input-actions">
          <span class="char-count">
            {{ inputMessage.length }}/2000
          </span>
          <el-button 
            type="primary" 
            @click="sendMessage" 
            :loading="loading"
            :disabled="!canSend">
            Send
          </el-button>
        </div>
      </div>
      
      <div class="quick-questions">
        <el-tag
          v-for="(question, index) in quickQuestions"
          :key="index"
          type="info"
          @click="useQuickQuestion(question)"
          class="quick-question-tag"
        >
          {{ question }}
        </el-tag>
      </div>
    </div>
  </div>
</template>

<script>
import { ref, computed, nextTick, onMounted, onUnmounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import hljs from 'highlight.js'
import 'highlight.js/styles/github.css'

export default {
  name: 'ChatInterface',
  
  setup() {
    const messages = ref([])
    const inputMessage = ref('')
    const loading = ref(false)
    const temperature = ref(0.7)
    const maxTokens = ref(1000)
    const sessionId = ref(null)
    const regeneratingIndex = ref(null)
    const messagesContainer = ref(null)
    
    // Sample quick questions
    const quickQuestions = [
      'How do I learn Python programming?',
      'Explain what a microservice architecture is',
      'Write a simple JavaScript function',
      'How can I optimize database query performance?',
      'Explain the difference between machine learning and deep learning'
    ]
    
    // Computed properties
    const canSend = computed(() => {
      return inputMessage.value.trim().length > 0 && !loading.value
    })
    
    // Methods
    const sendMessage = async () => {
      if (!canSend.value) return
      
      const userMessage = inputMessage.value.trim()
      inputMessage.value = ''
      
      // Show the user message immediately
      messages.value.push({
        role: 'user',
        content: userMessage,
        timestamp: new Date()
      })
      
      // Scroll to the bottom
      scrollToBottom()
      
      // Enter the loading state
      loading.value = true
      
      try {
        // Prepare the request payload
        const requestData = {
          message: userMessage,
          sessionId: sessionId.value,
          temperature: temperature.value,
          maxTokens: maxTokens.value
        }
        
        // Send the request
        const response = await axios.post('/api/chat/send', requestData, {
          headers: {
            'X-User-Id': 'current-user-id' // in a real app, taken from the login state
          },
          timeout: 60000 // 60-second timeout
        })
        
        // Store the session ID
        if (!sessionId.value) {
          sessionId.value = response.data.sessionId
        }
        
        // Show the AI response
        messages.value.push({
          role: 'assistant',
          content: response.data.message,
          timestamp: new Date(response.data.timestamp),
          tokensUsed: response.data.tokensUsed,
          processingTime: response.data.processingTime
        })
        
        // Persist to local storage
        saveToLocalStorage()
        
      } catch (error) {
        console.error('Failed to send message:', error)
        
        ElMessage.error({
          message: error.response?.data?.message || 'Failed to send, please retry',
          duration: 3000
        })
        
        // Show an error message
        messages.value.push({
          role: 'assistant',
          content: 'Sorry, something went wrong while handling your request. Please try again later.',
          timestamp: new Date(),
          error: true
        })
      } finally {
        loading.value = false
        scrollToBottom()
      }
    }
    
    const clearChat = async () => {
      try {
        await ElMessageBox.confirm(
          'Clear the entire conversation?',
          'Confirm',
          {
            confirmButtonText: 'OK',
            cancelButtonText: 'Cancel',
            type: 'warning'
          }
        )
        
        // If a session exists, ask the backend to clear it too
        if (sessionId.value) {
          await axios.delete(`/api/chat/session/${sessionId.value}`, {
            headers: {
              'X-User-Id': 'current-user-id'
            }
          })
        }
        
        // Clear local state
        messages.value = []
        sessionId.value = null
        localStorage.removeItem('chat_messages')
        localStorage.removeItem('chat_session_id')
        
        ElMessage.success('Chat cleared')
      } catch {
        // User cancelled
      }
    }
    
    const scrollToBottom = () => {
      nextTick(() => {
        if (messagesContainer.value) {
          messagesContainer.value.scrollTop = messagesContainer.value.scrollHeight
        }
      })
    }
    
    const saveToLocalStorage = () => {
      localStorage.setItem('chat_messages', JSON.stringify(messages.value))
      if (sessionId.value) {
        localStorage.setItem('chat_session_id', sessionId.value)
      }
    }
    
    const loadFromLocalStorage = () => {
      try {
        const savedMessages = localStorage.getItem('chat_messages')
        const savedSessionId = localStorage.getItem('chat_session_id')
        
        if (savedMessages) {
          messages.value = JSON.parse(savedMessages)
        }
        
        if (savedSessionId) {
          sessionId.value = savedSessionId
        }
      } catch {
        // Corrupted local storage: drop it and start with an empty conversation
        localStorage.removeItem('chat_messages')
        localStorage.removeItem('chat_session_id')
      }
    }
    
    const escapeHtml = (text) => text
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
    
    const formatMessage = (content) => {
      // Split out fenced code blocks: highlight code, escape everything else.
      // Escaping matters because the result is rendered with v-html.
      return content
        .split(/(```\w*\n[\s\S]*?```)/g)
        .map((segment) => {
          const match = segment.match(/^```(\w+)?\n([\s\S]*?)```$/)
          if (match) {
            // Fall back to plaintext for unregistered languages (hljs.highlight throws otherwise)
            const language = match[1] && hljs.getLanguage(match[1]) ? match[1] : 'plaintext'
            const highlighted = hljs.highlight(match[2].trim(), { language }).value
            return `<pre><code class="hljs ${language}">${highlighted}</code></pre>`
          }
          // Convert newlines to <br> only outside <pre> blocks
          return escapeHtml(segment).replace(/\n/g, '<br>')
        })
        .join('')
    }
    
    // 生命周期钩子
    onMounted(() => {
      loadFromLocalStorage()
      scrollToBottom()
      
      // 添加键盘快捷键
      window.addEventListener('keydown', handleKeyDown)
    })
    
    onUnmounted(() => {
      window.removeEventListener('keydown', handleKeyDown)
    })
    
    const handleKeyDown = (e) => {
      // Ctrl+Enter 发送消息
      if (e.ctrlKey && e.key === 'Enter') {
        sendMessage()
      }
    }
    
    return {
      messages,
      inputMessage,
      loading,
      temperature,
      maxTokens,
      sessionId,
      regeneratingIndex,
      messagesContainer,
      quickQuestions,
      canSend,
      sendMessage,
      clearChat,
      scrollToBottom,
      formatMessage,
      getAvatar: (role) => {
        return role === 'user' ? '/user-avatar.png' : '/ai-avatar.png'
      },
      getSenderName: (role) => {
        return role === 'user' ? '我' : '智能助手'
      },
      formatTime: (timestamp) => {
        return new Date(timestamp).toLocaleTimeString()
      },
      useQuickQuestion: (question) => {
        inputMessage.value = question
      }
    }
  }
}
</script>

<style scoped>
.chat-container {
  display: flex;
  flex-direction: column;
  height: 100vh;
  max-width: 1200px;
  margin: 0 auto;
  background: #f5f7fa;
}

.chat-header {
  padding: 20px;
  background: white;
  border-bottom: 1px solid #e4e7ed;
  display: flex;
  justify-content: space-between;
  align-items: center;
}

.chat-messages {
  flex: 1;
  overflow-y: auto;
  padding: 20px;
  background: #fafafa;
}

.message {
  display: flex;
  margin-bottom: 20px;
  animation: fadeIn 0.3s;
}

@keyframes fadeIn {
  from { opacity: 0; transform: translateY(10px); }
  to { opacity: 1; transform: translateY(0); }
}

.message.user {
  flex-direction: row-reverse;
}

.message-avatar {
  margin: 0 12px;
}

.message-content {
  max-width: 70%;
  background: white;
  border-radius: 12px;
  padding: 16px;
  box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}

.message.user .message-content {
  background: #409eff;
  color: white;
}

.message-header {
  display: flex;
  justify-content: space-between;
  margin-bottom: 8px;
  font-size: 12px;
  opacity: 0.8;
}

.message-text {
  line-height: 1.6;
  word-break: break-word;
}

.message-text :deep(pre) {
  background: #f6f8fa;
  border-radius: 6px;
  padding: 16px;
  overflow-x: auto;
  margin: 10px 0;
}

.message-actions {
  margin-top: 12px;
  display: flex;
  gap: 8px;
}

.loading-indicator {
  padding: 20px;
}

.typing-indicator {
  display: flex;
  padding: 20px;
  justify-content: center;
}

.typing-indicator span {
  height: 8px;
  width: 8px;
  background: #409eff;
  border-radius: 50%;
  margin: 0 2px;
  animation: bounce 1.4s infinite ease-in-out both;
}

.typing-indicator span:nth-child(1) { animation-delay: -0.32s; }
.typing-indicator span:nth-child(2) { animation-delay: -0.16s; }

@keyframes bounce {
  0%, 80%, 100% { transform: scale(0); }
  40% { transform: scale(1); }
}

.chat-input-area {
  background: white;
  border-top: 1px solid #e4e7ed;
  padding: 20px;
}

.input-tools {
  display: flex;
  gap: 20px;
  margin-bottom: 16px;
  align-items: center;
}

.input-container {
  position: relative;
}

.input-actions {
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-top: 12px;
}

.char-count {
  font-size: 12px;
  color: #909399;
}

.quick-questions {
  margin-top: 16px;
  display: flex;
  flex-wrap: wrap;
  gap: 8px;
}

.quick-question-tag {
  cursor: pointer;
  transition: all 0.3s;
}

.quick-question-tag:hover {
  transform: translateY(-2px);
}
</style>

三、Deployment and Testing: From Development to Production

3.1 Containerized Deployment with Docker

3.1.1 Dockerfile Configuration
# 前端Dockerfile
# frontend/Dockerfile
FROM node:18-alpine as build

WORKDIR /app

COPY package*.json ./
# devDependencies (build tooling) are needed by `npm run build`, so install everything
RUN npm ci

COPY . .
RUN npm run build

FROM nginx:alpine

COPY --from=build /app/dist /usr/share/nginx/html
COPY nginx.conf /etc/nginx/nginx.conf

EXPOSE 80

CMD ["nginx", "-g", "daemon off;"]
# 后端Dockerfile
# backend/Dockerfile
FROM maven:3.9.4-eclipse-temurin-17 as build

WORKDIR /app

COPY pom.xml .
RUN mvn dependency:go-offline -B

COPY src ./src
RUN mvn package -DskipTests

FROM eclipse-temurin:17-jre-alpine

# curl is used by the docker-compose healthcheck and is not in the Alpine base image
RUN apk add --no-cache curl

WORKDIR /app

COPY --from=build /app/target/*.jar app.jar

# 创建非root用户
RUN addgroup -S spring && adduser -S spring -G spring
USER spring:spring

EXPOSE 8080

ENTRYPOINT ["java", "-jar", "app.jar"]
# AI服务Dockerfile
# ai-service/Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies (curl is required by the docker-compose healthcheck)
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    curl \
    && rm -rf /var/lib/apt/lists/*

# 复制依赖文件并安装
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 复制应用代码
COPY . .

# 创建非root用户
RUN useradd -m -u 1000 pythonuser
USER pythonuser

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
3.1.2 Docker Compose Orchestration
# docker-compose.yml
version: '3.8'

services:
  # MySQL数据库
  mysql:
    image: mysql:8.0
    container_name: smart-chatbot-mysql
    environment:
      MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD:-root123}
      MYSQL_DATABASE: chatbot_db
      MYSQL_USER: chatbot_user
      MYSQL_PASSWORD: ${MYSQL_PASSWORD:-chatbot123}
    volumes:
      - mysql_data:/var/lib/mysql
      - ./init.sql:/docker-entrypoint-initdb.d/init.sql
    ports:
      - "3306:3306"
    networks:
      - chatbot-network
    healthcheck:
      test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
      timeout: 20s
      retries: 10

  # Redis缓存
  redis:
    image: redis:7-alpine
    container_name: smart-chatbot-redis
    command: redis-server --appendonly yes
    volumes:
      - redis_data:/data
    ports:
      - "6379:6379"
    networks:
      - chatbot-network
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 5

  # Python AI服务
  ai-service:
    build: ./ai-service
    container_name: smart-chatbot-ai
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - REDIS_URL=redis://redis:6379/0
      - ENVIRONMENT=production
    ports:
      - "8000:8000"
    networks:
      - chatbot-network
    depends_on:
      redis:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
    restart: unless-stopped

  # Java后端服务
  backend:
    build: ./backend
    container_name: smart-chatbot-backend
    environment:
      - SPRING_PROFILES_ACTIVE=prod
      - DB_URL=jdbc:mysql://mysql:3306/chatbot_db
      - DB_USERNAME=chatbot_user
      - DB_PASSWORD=${MYSQL_PASSWORD:-chatbot123}
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - AI_SERVICE_URL=http://ai-service:8000
    ports:
      - "8080:8080"
    networks:
      - chatbot-network
    depends_on:
      mysql:
        condition: service_healthy
      redis:
        condition: service_healthy
      ai-service:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8080/api/chat/health"]
      interval: 30s
      timeout: 10s
      retries: 3
    restart: unless-stopped

  # 前端服务
  frontend:
    build: ./frontend
    container_name: smart-chatbot-frontend
    ports:
      - "3000:80"
    networks:
      - chatbot-network
    depends_on:
      - backend
    restart: unless-stopped

  # Nginx反向代理(可选)
  nginx:
    image: nginx:alpine
    container_name: smart-chatbot-nginx
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx/nginx.conf:/etc/nginx/nginx.conf
      - ./nginx/ssl:/etc/nginx/ssl
    networks:
      - chatbot-network
    depends_on:
      - frontend
      - backend
    restart: unless-stopped

  # 监控服务
  prometheus:
    image: prom/prometheus
    container_name: smart-chatbot-prometheus
    volumes:
      - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus_data:/prometheus
    ports:
      - "9090:9090"
    networks:
      - chatbot-network
    restart: unless-stopped

  grafana:
    image: grafana/grafana
    container_name: smart-chatbot-grafana
    volumes:
      - grafana_data:/var/lib/grafana
      - ./monitoring/dashboards:/etc/grafana/provisioning/dashboards
    ports:
      - "3001:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD:-admin123}
    networks:
      - chatbot-network
    restart: unless-stopped

volumes:
  mysql_data:
  redis_data:
  prometheus_data:
  grafana_data:

networks:
  chatbot-network:
    driver: bridge
3.1.3 Deployment Script
#!/bin/bash
# deploy.sh

set -e

# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

echo -e "${GREEN}开始部署智能对话机器人...${NC}"

# 检查环境变量
if [ -z "$OPENAI_API_KEY" ]; then
    echo -e "${RED}错误: OPENAI_API_KEY环境变量未设置${NC}"
    exit 1
fi

# 创建必要的目录
mkdir -p nginx/ssl
mkdir -p monitoring

# 生成SSL证书(如果需要)
if [ ! -f "nginx/ssl/cert.pem" ]; then
    echo -e "${YELLOW}生成自签名SSL证书...${NC}"
    openssl req -x509 -nodes -days 365 -newkey rsa:2048 \
        -keyout nginx/ssl/key.pem \
        -out nginx/ssl/cert.pem \
        -subj "/C=CN/ST=Beijing/L=Beijing/O=SmartChatbot/CN=localhost"
fi

# 构建并启动服务
echo -e "${GREEN}构建Docker镜像...${NC}"
docker-compose build

echo -e "${GREEN}启动服务...${NC}"
docker-compose up -d

# 等待服务启动
echo -e "${YELLOW}等待服务启动...${NC}"
sleep 30

# 检查服务状态
echo -e "${GREEN}检查服务状态:${NC}"
services=("backend" "ai-service" "frontend" "mysql" "redis")
for service in "${services[@]}"; do
    if docker-compose ps | grep -q "$service.*Up"; then
        echo -e "  ${GREEN}$service 运行正常${NC}"
    else
        echo -e "  ${RED}$service 启动失败${NC}"
    fi
done

# 显示访问信息
echo -e "\n${GREEN}部署完成!${NC}"
echo -e "服务访问地址:"
echo -e "  前端: http://localhost:3000"
echo -e "  后端API: http://localhost:8080"
echo -e "  AI服务: http://localhost:8000"
echo -e "  监控面板: http://localhost:3001 (Grafana)"
echo -e "\n默认登录信息:"
echo -e "  Grafana: admin / admin123"
echo -e "\n查看日志: docker-compose logs -f [服务名]"
echo -e "停止服务: docker-compose down"
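
The `sleep 30` in the script above is a fixed guess; services may come up faster or slower. A more robust approach is to poll the health endpoints until they answer. Here is a minimal sketch in Python (the `wait_for` helper is illustrative, not part of the project code; the `probe` parameter is injectable so the loop can be exercised without a network):

```python
import time
import urllib.request

def wait_for(url: str, timeout: float = 120, interval: float = 2, probe=None) -> bool:
    """Poll a health endpoint until it responds, or give up after `timeout` seconds."""
    # The default probe issues a real HTTP GET; tests can inject a fake one
    probe = probe or (lambda u: urllib.request.urlopen(u, timeout=3).status == 200)
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if probe(url):
                return True
        except Exception:
            pass  # service not up yet, keep polling
        time.sleep(interval)
    return False

# Example (issues real requests once the stack is running):
#   wait_for("http://localhost:8080/api/chat/health")
#   wait_for("http://localhost:8000/health")
```

The same loop could replace the fixed sleep in deploy.sh, failing fast with a clear message when a service never becomes healthy instead of reporting a half-started stack.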

3.2 Testing Strategy

# ai-service/tests/test_chat.py
import asyncio

import pytest
import pytest_asyncio
from httpx import AsyncClient

from main import app  # the FastAPI app lives in ai-service/main.py
from app.core.chat_processor import ChatProcessor
from app.core.prompt_engineer import PromptEngineer, ConversationType

@pytest_asyncio.fixture
async def client():
    # Async fixtures require pytest-asyncio's fixture decorator
    async with AsyncClient(app=app, base_url="http://test") as client:
        yield client

class TestChatProcessor:
    """聊天处理器测试"""
    
    @pytest.fixture
    def chat_processor(self):
        return ChatProcessor(model_name="gpt-3.5-turbo")
    
    @pytest.mark.asyncio
    async def test_process_message(self, chat_processor):
        """测试消息处理"""
        result = await chat_processor.process_message(
            "你好", 
            "test-session-123"
        )
        
        assert "response" in result
        assert "session_id" in result
        assert "tokens_used" in result
        assert len(result["response"]) > 0
    
    @pytest.mark.asyncio
    async def test_session_management(self, chat_processor):
        """测试会话管理"""
        session_id = "test-session"
        
        # 第一次调用
        result1 = await chat_processor.process_message(
            "我的名字是张三", 
            session_id
        )
        
        # 第二次调用应该记住上下文
        result2 = await chat_processor.process_message(
            "我刚才说我叫什么名字?", 
            session_id
        )
        
        # 验证AI记得上下文
        assert "张三" in result2["response"]

class TestPromptEngineer:
    """提示词工程测试"""
    
    @pytest.fixture
    def prompt_engineer(self):
        return PromptEngineer()
    
    def test_detect_conversation_type(self, prompt_engineer):
        """测试对话类型检测"""
        
        test_cases = [
            ("如何写一个Python函数?", ConversationType.TECHNICAL),
            ("写一个关于太空的故事", ConversationType.CREATIVE),
            ("分析一下这个需求", ConversationType.ANALYSIS),
            ("你好吗?", ConversationType.GENERAL)
        ]
        
        for input_text, expected_type in test_cases:
            detected = prompt_engineer.detect_conversation_type(input_text)
            assert detected == expected_type
    
    def test_build_prompt(self, prompt_engineer):
        """测试提示词构建"""
        
        # 测试技术问题
        result = prompt_engineer.build_prompt("如何实现快速排序?")
        
        assert "system" in result
        assert "user" in result
        assert result["type"] == ConversationType.TECHNICAL.value
        assert "快速排序" in result["user"]
        
        # 测试带历史记录
        history = "用户之前问了关于算法的问题"
        result_with_history = prompt_engineer.build_prompt(
            "还有其他排序算法吗?", 
            history
        )
        
        assert "之前的对话历史" in result_with_history["system"]

class TestAPIEndpoints:
    """API端点测试"""
    
    @pytest.mark.asyncio
    async def test_health_check(self, client):
        """测试健康检查端点"""
        response = await client.get("/health")
        
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
    
    @pytest.mark.asyncio
    async def test_chat_endpoint(self, client):
        """测试聊天端点"""
        request_data = {
            "message": "你好",
            "session_id": "test-session-api",
            "temperature": 0.7,
            "max_tokens": 100
        }
        
        response = await client.post("/chat", json=request_data)
        
        assert response.status_code == 200
        data = response.json()
        
        assert "response" in data
        assert "session_id" in data
        assert "tokens_used" in data
        assert "processing_time" in data
        assert "timestamp" in data
    
    @pytest.mark.asyncio
    async def test_chat_endpoint_validation(self, client):
        """测试聊天端点参数验证"""
        
        # 测试空消息
        request_data = {
            "message": "",
            "session_id": "test-session"
        }
        
        response = await client.post("/chat", json=request_data)
        assert response.status_code == 422
        
        # 测试过长的消息
        request_data = {
            "message": "a" * 3000,
            "session_id": "test-session"
        }
        
        response = await client.post("/chat", json=request_data)
        assert response.status_code == 422
    
    @pytest.mark.asyncio
    async def test_concurrent_requests(self, client):
        """测试并发请求"""
        request_data = {
            "message": "并发测试",
            "session_id": "concurrent-test"
        }
        
        # 发送5个并发请求
        tasks = [
            client.post("/chat", json=request_data)
            for _ in range(5)
        ]
        
        responses = await asyncio.gather(*tasks)
        
        # Verify all requests succeeded
        for response in responses:
            assert response.status_code == 200
        
        # Verify every response is non-empty; identical inputs may legitimately
        # yield identical outputs, so we do not assert the responses differ
        for response in responses:
            assert len(response.json()["response"]) > 0

if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])

3.3 Performance Testing

# ai-service/tests/performance_test.py
import asyncio
import statistics
import time
from datetime import datetime
from typing import Dict

import aiohttp
import matplotlib.pyplot as plt
import pandas as pd

class PerformanceTester:
    """性能测试器"""
    
    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.results = []
    
    async def test_single_request(self, session: aiohttp.ClientSession, 
                                 request_data: Dict) -> Dict:
        """测试单个请求"""
        start_time = time.time()
        
        try:
            async with session.post(f"{self.base_url}/chat", 
                                   json=request_data) as response:
                end_time = time.time()
                
                if response.status == 200:
                    data = await response.json()
                    return {
                        "success": True,
                        "response_time": end_time - start_time,
                        "tokens_used": data.get("tokens_used", 0),
                        "processing_time": data.get("processing_time", 0)
                    }
                else:
                    return {
                        "success": False,
                        "response_time": end_time - start_time,
                        "error": f"HTTP {response.status}"
                    }
        except Exception as e:
            return {
                "success": False,
                "response_time": time.time() - start_time,
                "error": str(e)
            }
    
    async def run_concurrent_test(self, num_requests: int, 
                                 concurrency: int) -> Dict:
        """运行并发测试"""
        connector = aiohttp.TCPConnector(limit=concurrency)
        
        async with aiohttp.ClientSession(connector=connector) as session:
            # 准备测试数据
            test_messages = [
                "你好,介绍一下你自己",
                "什么是人工智能?",
                "写一个Python函数计算斐波那契数列",
                "解释一下微服务架构",
                "如何学习机器学习?"
            ]
            
            tasks = []
            for i in range(num_requests):
                message = test_messages[i % len(test_messages)]
                request_data = {
                    "message": message,
                    "session_id": f"perf-test-{i}",
                    "temperature": 0.7,
                    "max_tokens": 100
                }
                tasks.append(self.test_single_request(session, request_data))
            
            # 运行并发测试
            start_time = time.time()
            results = await asyncio.gather(*tasks)
            total_time = time.time() - start_time
            
            # 分析结果
            successful = [r for r in results if r["success"]]
            failed = [r for r in results if not r["success"]]
            
            if successful:
                response_times = [r["response_time"] for r in successful]
                tokens_used = [r["tokens_used"] for r in successful]
                processing_times = [r["processing_time"] for r in successful]
            else:
                response_times = []
                tokens_used = []
                processing_times = []
            
            test_result = {
                "timestamp": datetime.now(),
                "num_requests": num_requests,
                "concurrency": concurrency,
                "total_time": total_time,
                "requests_per_second": len(successful) / total_time if total_time > 0 else 0,
                "success_rate": len(successful) / num_requests * 100,
                "avg_response_time": statistics.mean(response_times) if response_times else 0,
                "p95_response_time": statistics.quantiles(response_times, n=20)[18] if len(response_times) >= 20 else 0,
                "avg_tokens_used": statistics.mean(tokens_used) if tokens_used else 0,
                "avg_processing_time": statistics.mean(processing_times) if processing_times else 0,
                "success_count": len(successful),
                "failure_count": len(failed),
                "failures": [(i, r.get("error", "Unknown")) for i, r in enumerate(failed)]
            }
            
            self.results.append(test_result)
            return test_result
    
    def generate_report(self):
        """生成性能测试报告"""
        if not self.results:
            print("没有测试结果")
            return
        
        df = pd.DataFrame(self.results)
        
        print("=" * 80)
        print("性能测试报告")
        print("=" * 80)
        print(f"测试次数: {len(df)}")
        print(f"总请求数: {df['num_requests'].sum()}")
        print(f"总成功数: {df['success_count'].sum()}")
        print(f"总失败数: {df['failure_count'].sum()}")
        print(f"平均成功率: {df['success_rate'].mean():.2f}%")
        print(f"平均RPS: {df['requests_per_second'].mean():.2f}")
        print(f"平均响应时间: {df['avg_response_time'].mean():.3f}s")
        print(f"平均P95响应时间: {df['p95_response_time'].mean():.3f}s")
        
        # 生成图表
        self.generate_charts(df)
    
    def generate_charts(self, df: pd.DataFrame):
        """生成性能图表"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # 图表1: 成功率趋势
        axes[0, 0].plot(df['timestamp'], df['success_rate'], 'b-o')
        axes[0, 0].set_title('成功率趋势')
        axes[0, 0].set_xlabel('时间')
        axes[0, 0].set_ylabel('成功率 (%)')
        axes[0, 0].grid(True)
        
        # 图表2: RPS趋势
        axes[0, 1].plot(df['timestamp'], df['requests_per_second'], 'g-o')
        axes[0, 1].set_title('每秒请求数 (RPS)')
        axes[0, 1].set_xlabel('时间')
        axes[0, 1].set_ylabel('RPS')
        axes[0, 1].grid(True)
        
        # 图表3: 响应时间分布
        axes[1, 0].bar(df['timestamp'].astype(str), df['avg_response_time'])
        axes[1, 0].set_title('平均响应时间')
        axes[1, 0].set_xlabel('测试批次')
        axes[1, 0].set_ylabel('响应时间 (s)')
        axes[1, 0].tick_params(axis='x', rotation=45)
        axes[1, 0].grid(True)
        
        # 图表4: Token使用量
        axes[1, 1].scatter(df['avg_response_time'], df['avg_tokens_used'])
        axes[1, 1].set_title('响应时间 vs Token使用量')
        axes[1, 1].set_xlabel('响应时间 (s)')
        axes[1, 1].set_ylabel('Token使用量')
        axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.savefig('performance_report.png', dpi=300, bbox_inches='tight')
        print("图表已保存为 performance_report.png")

async def main():
    """主测试函数"""
    tester = PerformanceTester()
    
    # 测试不同并发级别
    test_scenarios = [
        (10, 1),   # 10个请求,并发1
        (50, 5),   # 50个请求,并发5
        (100, 10), # 100个请求,并发10
        (200, 20), # 200个请求,并发20
    ]
    
    print("开始性能测试...")
    
    for num_requests, concurrency in test_scenarios:
        print(f"\n测试场景: {num_requests}请求, {concurrency}并发")
        
        result = await tester.run_concurrent_test(num_requests, concurrency)
        
        print(f"  耗时: {result['total_time']:.2f}s")
        print(f"  RPS: {result['requests_per_second']:.2f}")
        print(f"  成功率: {result['success_rate']:.1f}%")
        print(f"  平均响应时间: {result['avg_response_time']:.3f}s")
        print(f"  P95响应时间: {result['p95_response_time']:.3f}s")
        
        # 短暂等待,避免服务过载
        await asyncio.sleep(5)
    
    # 生成报告
    tester.generate_report()

if __name__ == "__main__":
    asyncio.run(main())
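
One subtlety in the report above: `statistics.quantiles(response_times, n=20)` returns 19 cut points, so index 18 is the 95th percentile, and with fewer than 20 samples the code silently reports 0. A small illustrative helper keeps the same index arithmetic but falls back to the nearest-rank method for small samples:

```python
import statistics

def p95(samples):
    """95th percentile of a list of response times.

    statistics.quantiles(..., n=20) yields 19 cut points; the 19th
    (index 18) is P95. For fewer than 20 samples, use the nearest-rank
    method instead of reporting 0 as the test above does.
    """
    if not samples:
        raise ValueError("no samples")
    if len(samples) < 20:
        ordered = sorted(samples)
        k = max(0, round(0.95 * len(ordered)) - 1)
        return float(ordered[k])
    return statistics.quantiles(samples, n=20)[18]
```

With 100 uniformly spaced samples this lands between the 95th and 96th values, which is what you want a P95 line on a latency dashboard to show.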

四、Troubleshooting Common Issues and Optimization

4.1 API Timeout Issues

// backend/src/main/java/com/smartchatbot/config/ResilienceConfig.java
package com.smartchatbot.config;

import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
import io.github.resilience4j.retry.RetryConfig;
import io.github.resilience4j.retry.RetryRegistry;
import io.github.resilience4j.timelimiter.TimeLimiterConfig;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;

@Configuration
public class ResilienceConfig {
    
    @Bean
    public CircuitBreakerRegistry circuitBreakerRegistry() {
        CircuitBreakerConfig config = CircuitBreakerConfig.custom()
            .failureRateThreshold(50) // 失败率阈值
            .slowCallRateThreshold(100) // 慢调用率阈值
            .slowCallDurationThreshold(Duration.ofSeconds(5)) // 慢调用阈值
            .waitDurationInOpenState(Duration.ofSeconds(30)) // 断路器开启后等待时间
            .permittedNumberOfCallsInHalfOpenState(10) // 半开状态允许的调用数
            .minimumNumberOfCalls(20) // 最小调用数
            .slidingWindowType(CircuitBreakerConfig.SlidingWindowType.COUNT_BASED)
            .slidingWindowSize(50) // 滑动窗口大小
            .recordExceptions(Exception.class) // 记录哪些异常
            .ignoreExceptions() // 忽略哪些异常
            .build();
        
        return CircuitBreakerRegistry.of(config);
    }
    
    @Bean
    public RetryRegistry retryRegistry() {
        RetryConfig config = RetryConfig.custom()
            .maxAttempts(3) // 最大重试次数
            .waitDuration(Duration.ofMillis(500)) // 重试间隔
            .retryOnException(e -> e instanceof RuntimeException) // 重试条件
            .failAfterMaxAttempts(true) // 达到最大重试次数后失败
            .build();
        
        return RetryRegistry.of(config);
    }
    
    @Bean
    public TimeLimiterConfig timeLimiterConfig() {
        return TimeLimiterConfig.custom()
            .timeoutDuration(Duration.ofSeconds(30)) // 超时时间
            .cancelRunningFuture(true) // 是否取消正在执行的future
            .build();
    }
}
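
The Java bean above hands the actual state machine to Resilience4j. To make the mechanism concrete, here is a deliberately simplified Python sketch of the same CLOSED → OPEN → HALF_OPEN cycle (a toy model for illustration only: it counts consecutive failures and omits Resilience4j's sliding window and slow-call tracking):

```python
import time

class CircuitBreaker:
    """Toy CLOSED -> OPEN -> HALF_OPEN state machine (illustration only)."""

    def __init__(self, failure_threshold=5, reset_timeout=30.0, clock=time.monotonic):
        self.failure_threshold = failure_threshold  # consecutive failures before opening
        self.reset_timeout = reset_timeout          # cool-down before a trial call
        self.clock = clock                          # injectable for testing
        self.failures = 0
        self.state = "CLOSED"
        self.opened_at = None

    def call(self, fn, *args, **kwargs):
        if self.state == "OPEN":
            if self.clock() - self.opened_at >= self.reset_timeout:
                self.state = "HALF_OPEN"  # cool-down elapsed: allow one trial call
            else:
                raise RuntimeError("circuit open, failing fast")
        try:
            result = fn(*args, **kwargs)
        except Exception:
            self.failures += 1
            # A failed trial call, or too many consecutive failures, opens the circuit
            if self.state == "HALF_OPEN" or self.failures >= self.failure_threshold:
                self.state = "OPEN"
                self.opened_at = self.clock()
            raise
        # Success closes the circuit and resets the failure count
        self.failures = 0
        self.state = "CLOSED"
        return result
```

Failing fast is what protects the Spring backend when the AI service hangs: callers get an immediate error instead of piling threads onto a 30-second timeout.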

4.2 Improving Response Accuracy

# ai-service/app/core/response_validator.py
import re
import json
from typing import Dict, List, Optional, Tuple
from enum import Enum

class ResponseQuality(Enum):
    """响应质量等级"""
    EXCELLENT = "excellent"  # 优秀:完整、准确、相关
    GOOD = "good"           # 良好:基本满足要求
    FAIR = "fair"           # 一般:有缺陷但可用
    POOR = "poor"           # 差:不准确或不完整
    INVALID = "invalid"     # 无效:无法理解或空响应

class ResponseValidator:
    """响应验证器,确保AI响应质量"""
    
    def __init__(self):
        # 定义验证规则
        self.validation_rules = {
            "min_length": 10,  # 最小长度
            "max_length": 2000, # 最大长度
            "required_elements": [
                "complete_sentences",  # 完整句子
                "relevant_content",    # 相关内容
                "clear_structure"      # 清晰结构
            ],
            "forbidden_patterns": [
                r"我无法回答",           # 拒绝回答
                r"作为AI语言模型",       # 标准拒绝语
                r"很抱歉,我",           # 道歉开头
                r"这个.*超出.*范围",     # 超出范围
                r"^[\s\W]*$"           # 只有空白或标点
            ]
        }
        
        # 专业领域关键词(用于相关性检查)
        self.domain_keywords = {
            "technical": ["代码", "函数", "算法", "API", "配置", "部署"],
            "creative": ["故事", "诗歌", "创作", "想象", "情节"],
            "general": ["你好", "谢谢", "请问", "帮助", "解释"]
        }
    
    def validate_response(self, query: str, response: str, 
                         domain: str = "general") -> Dict:
        """验证响应质量"""
        
        validation_results = {
            "quality": ResponseQuality.EXCELLENT,
            "scores": {},
            "issues": [],
            "suggestions": []
        }
        
        # 1. 基本验证
        basic_checks = self._perform_basic_checks(response)
        validation_results["scores"].update(basic_checks["scores"])
        
        if not basic_checks["passed"]:
            validation_results["quality"] = ResponseQuality.INVALID
            validation_results["issues"].extend(basic_checks["issues"])
            return validation_results
        
        # 2. 内容质量验证
        content_checks = self._check_content_quality(query, response, domain)
        validation_results["scores"].update(content_checks["scores"])
        validation_results["issues"].extend(content_checks["issues"])
        validation_results["suggestions"].extend(content_checks["suggestions"])
        
        # 3. 计算总体质量评分
        overall_score = self._calculate_overall_score(validation_results["scores"])
        
        # 4. 确定质量等级
        if overall_score >= 0.9:
            validation_results["quality"] = ResponseQuality.EXCELLENT
        elif overall_score >= 0.7:
            validation_results["quality"] = ResponseQuality.GOOD
        elif overall_score >= 0.5:
            validation_results["quality"] = ResponseQuality.FAIR
        else:
            validation_results["quality"] = ResponseQuality.POOR
        
        validation_results["overall_score"] = overall_score
        
        return validation_results
    
    def _perform_basic_checks(self, response: str) -> Dict:
        """执行基本检查"""
        issues = []
        scores = {}
        
        # 检查空响应
        if not response or response.strip() == "":
            issues.append("响应为空")
            return {"passed": False, "scores": scores, "issues": issues}
        
        # 检查长度
        length = len(response.strip())
        if length < self.validation_rules["min_length"]:
            issues.append(f"响应过短({length}字符)")
            scores["length_score"] = 0.0
        elif length > self.validation_rules["max_length"]:
            issues.append(f"响应过长({length}字符)")
            scores["length_score"] = 0.5
        else:
            scores["length_score"] = 1.0
        
        # 检查禁止模式
        for pattern in self.validation_rules["forbidden_patterns"]:
            if re.search(pattern, response, re.IGNORECASE):
                issues.append(f"包含禁止模式:{pattern}")
                scores["forbidden_pattern_score"] = 0.0
                break
        else:
            scores["forbidden_pattern_score"] = 1.0
        
        passed = len(issues) == 0
        
        return {
            "passed": passed,
            "scores": scores,
            "issues": issues
        }
    
    def _check_content_quality(self, query: str, response: str, 
                              domain: str) -> Dict:
        """检查内容质量"""
        scores = {}
        issues = []
        suggestions = []
        
        # 1. 相关性检查
        relevance_score = self._calculate_relevance_score(query, response, domain)
        scores["relevance_score"] = relevance_score
        
        if relevance_score < 0.5:
            issues.append("响应与问题相关性较低")
            suggestions.append("请确保回答直接相关于用户的问题")
        
        # 2. 完整性检查
        completeness_score = self._check_completeness(response)
        scores["completeness_score"] = completeness_score
        
        if completeness_score < 0.7:
            issues.append("响应可能不够完整")
            suggestions.append("考虑提供更全面的信息")
        
        # 3. 准确性指示器(启发式检查)
        accuracy_indicators = self._check_accuracy_indicators(response)
        scores["accuracy_indicator_score"] = accuracy_indicators
        
        # 4. 结构质量
        structure_score = self._check_structure_quality(response)
        scores["structure_score"] = structure_score
        
        if structure_score < 0.6:
            issues.append("响应结构不够清晰")
            suggestions.append("考虑使用分点、段落或标题来组织内容")
        
        return {
            "scores": scores,
            "issues": issues,
            "suggestions": suggestions
        }
    
    def _calculate_relevance_score(self, query: str, response: str, 
                                  domain: str) -> float:
        """计算相关性分数"""
        # 简单实现:基于关键词重叠
        query_words = set(re.findall(r'\w+', query.lower()))
        response_words = set(re.findall(r'\w+', response.lower()))
        
        if not query_words:
            return 0.0
        
        # 计算Jaccard相似度
        intersection = query_words.intersection(response_words)
        union = query_words.union(response_words)
        
        if not union:
            return 0.0
        
        basic_similarity = len(intersection) / len(union)
        
        # 考虑领域关键词
        if domain in self.domain_keywords:
            domain_words = set(self.domain_keywords[domain])
            domain_match = domain_words.intersection(response_words)
            if domain_match:
                basic_similarity = min(1.0, basic_similarity + 0.2)
        
        return basic_similarity
    
    def _check_completeness(self, response: str) -> float:
        """检查完整性"""
        # 基于句子结构和长度判断
        sentences = re.split(r'[.!?。!?]+', response)
        valid_sentences = [s.strip() for s in sentences if len(s.strip()) > 5]
        
        if len(valid_sentences) == 0:
            return 0.0
        
        # 计算平均句子长度
        avg_sentence_length = sum(len(s) for s in valid_sentences) / len(valid_sentences)
        
        # 基于句子数量和长度评分
        sentence_count_score = min(len(valid_sentences) / 5, 1.0)
        length_score = min(avg_sentence_length / 50, 1.0)
        
        return (sentence_count_score + length_score) / 2
    
    def _check_accuracy_indicators(self, response: str) -> float:
        """检查准确性指示器"""
        score = 1.0
        
        # 不确定性表达(降低分数)
        uncertainty_patterns = [
            r"可能.*",
            r"大概.*",
            r"也许.*",
            r"不太确定",
            r"据我所知"
        ]
        
        for pattern in uncertainty_patterns:
            if re.search(pattern, response):
                score -= 0.1
        
        # 准确性表达(提高分数)
        confidence_patterns = [
            r"肯定.*",
            r"一定.*",
            r"确实.*",
            r"根据.*研究",
            r"数据表明"
        ]
        
        for pattern in confidence_patterns:
            if re.search(pattern, response):
                score = min(score + 0.05, 1.0)
        
        return max(score, 0.0)
    
    def _check_structure_quality(self, response: str) -> float:
        """检查结构质量"""
        score = 0.5  # 基础分
        
        # 检查是否有结构元素
        structure_elements = {
            "paragraphs": r'\n\s*\n',  # 空行分隔段落
            "bullets": r'[\-\*•]\s',   # 项目符号
            "numbers": r'\d+\.\s',      # 编号列表
            "headings": r'^#+\s',       # 标题(Markdown格式)
        }
        
        for element_name, pattern in structure_elements.items():
            if re.search(pattern, response, re.MULTILINE):
                score += 0.1
        
        return min(score, 1.0)
    
    def _calculate_overall_score(self, scores: Dict) -> float:
        """计算总体分数"""
        if not scores:
            return 0.0
        
        # 权重配置
        weights = {
            "length_score": 0.1,
            "forbidden_pattern_score": 0.2,
            "relevance_score": 0.3,
            "completeness_score": 0.2,
            "accuracy_indicator_score": 0.1,
            "structure_score": 0.1
        }
        
        total_score = 0.0
        total_weight = 0.0
        
        for score_name, weight in weights.items():
            if score_name in scores:
                total_score += scores[score_name] * weight
                total_weight += weight
        
        if total_weight == 0:
            return 0.0
        
        return total_score / total_weight
    
    def generate_improved_prompt(self, query: str, validation_results: Dict) -> str:
        """基于验证结果生成改进的提示词"""
        issues = validation_results.get("issues", [])
        quality = validation_results.get("quality")
        
        if quality in [ResponseQuality.EXCELLENT, ResponseQuality.GOOD]:
            return query  # 不需要改进
        
        # 构建改进提示
        improvements = []
        
        # 注意:issues 中存放的是完整描述句(如"响应与问题相关性较低"),
        # 因此这里用子串匹配;若写成 `"相关性较低" in issues` 则是列表整项匹配,永远不会命中
        if any("过短" in issue or "不够完整" in issue for issue in issues):
            improvements.append("请提供更详细和完整的回答")
        
        if any("相关性较低" in issue for issue in issues):
            improvements.append(f"请确保回答与以下问题直接相关:{query}")
        
        if any("结构不够清晰" in issue for issue in issues):
            improvements.append("请使用清晰的结构组织回答,如分点或段落")
        
        if improvements:
            enhanced_query = f"{query}\n\n请确保:{'; '.join(improvements)}"
        else:
            enhanced_query = query
        
        return enhanced_query
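上面的相关性评分本质上是词集合的 Jaccard 相似度(交集大小除以并集大小)。下面用一个独立的小例子验证其行为。这是示意性实现,与 `_calculate_relevance_score` 的核心逻辑一致,但省略了领域关键词加成:

```python
import re

def jaccard_relevance(query: str, response: str) -> float:
    """词集合 Jaccard 相似度:|交集| / |并集|,范围 [0, 1]"""
    q = set(re.findall(r'\w+', query.lower()))
    r = set(re.findall(r'\w+', response.lower()))
    if not q:
        return 0.0  # 空问题无法评估相关性
    union = q | r
    return len(q & r) / len(union) if union else 0.0
```

可以看到,回答与问题完全重合时得 1.0,完全无词重叠时得 0.0。需要注意 Python 3 中 `\w` 默认按 Unicode 匹配,连续的中文会被切成一个整段"词",生产环境建议先做中文分词再计算。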

4.3 监控与告警配置

# monitoring/prometheus.yml
global:
  scrape_interval: 15s
  evaluation_interval: 15s

alerting:
  alertmanagers:
    - static_configs:
        - targets: []

rule_files:
  - "alerts.yml"

scrape_configs:
  - job_name: 'spring-boot'
    metrics_path: '/actuator/prometheus'
    static_configs:
      - targets: ['backend:8080']
        labels:
          application: 'smart-chatbot-backend'
  
  - job_name: 'ai-service'
    static_configs:
      - targets: ['ai-service:8000']
        labels:
          application: 'smart-chatbot-ai'
  
  - job_name: 'mysql'
    static_configs:
      - targets: ['mysql:9104']
    params:
      collect[]:
        - global_status
        - innodb_metrics
        - performance_schema.file_summary_by_event_name
  
  - job_name: 'redis'
    static_configs:
      - targets: ['redis:9121']
  
  - job_name: 'node-exporter'
    static_configs:
      - targets: ['node-exporter:9100']

# monitoring/alerts.yml
groups:
  - name: smart-chatbot-alerts
    rules:
      # API响应时间告警
      - alert: HighResponseTime
        expr: rate(http_request_duration_seconds_sum[5m]) / rate(http_request_duration_seconds_count[5m]) > 3
        for: 2m
        labels:
          severity: warning
        annotations:
          summary: "高响应时间"
          description: "API平均响应时间超过3秒"
      
      # 错误率告警
      - alert: HighErrorRate
        expr: rate(http_requests_total{status=~"5.."}[5m]) / rate(http_requests_total[5m]) > 0.05
        for: 2m
        labels:
          severity: critical
        annotations:
          summary: "高错误率"
          description: "HTTP 5xx错误率超过5%"
      
      # AI服务可用性告警
      - alert: AIServiceDown
        expr: up{job="ai-service"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "AI服务不可用"
          description: "AI服务已宕机超过1分钟"
      
      # 数据库连接告警
      - alert: HighDatabaseConnections
        expr: mysql_global_status_threads_connected / mysql_global_variables_max_connections > 0.8
        for: 2m
        labels:
          severity: warning
        annotations:
          summary: "数据库连接数高"
          description: "数据库连接数超过最大连接数的80%"
      
      # 内存使用告警
      - alert: HighMemoryUsage
        expr: (node_memory_MemTotal_bytes - node_memory_MemAvailable_bytes) / node_memory_MemTotal_bytes > 0.8
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "高内存使用率"
          description: "内存使用率超过80%"
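其中 HighErrorRate 告警的 PromQL 表达式,本质上是统计时间窗口内 5xx 请求的占比是否超过 5%。下面用纯 Python 直观还原这个阈值逻辑(示意代码,状态码分布为假设数据):

```python
def error_rate(status_counts: dict) -> float:
    """计算 5xx 错误率:5xx 请求数 / 总请求数"""
    total = sum(status_counts.values())
    if total == 0:
        return 0.0
    errors = sum(c for code, c in status_counts.items() if code.startswith("5"))
    return errors / total

# 假设最近 5 分钟窗口内的状态码分布
window = {"200": 90, "404": 4, "500": 4, "503": 2}
assert error_rate(window) > 0.05  # 错误率 6%,超过阈值,触发 HighErrorRate
```

配合 `for: 2m`,只有错误率持续超标 2 分钟才会真正告警,可以过滤掉瞬时抖动。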

五、总结与展望

通过本文的完整实现,我们成功构建了一个生产级的智能对话机器人系统。这个项目展示了:

5.1 关键收获

  1. 完整的架构设计:理解了前后端分离、微服务架构在AI应用中的实际应用
  2. 跨语言协同:掌握了Python和Java协同工作的最佳实践
  3. 生产级部署:学会了使用Docker容器化部署完整的AI应用
  4. 性能优化:掌握了监控、熔断、降级等关键的生产环境技术

5.2 扩展方向

基础对话机器人可以沿以下方向扩展:

  • 多模态AI:语音交互、图像识别、文档处理
  • 个性化推荐:用户画像、行为分析
  • 知识图谱集成:智能搜索、问答系统
  • 语音交互:语音识别、语音合成

5.3 后续优化建议

  1. 模型优化

    • 实现模型缓存,减少重复计算
    • 添加模型版本管理和A/B测试
    • 支持多模型切换和负载均衡
  2. 性能优化

    • 实现响应流式传输,减少用户等待时间
    • 添加更细粒度的缓存策略
    • 实现请求批处理,提高吞吐量
  3. 功能扩展

    • 添加文件上传和处理能力
    • 实现多轮对话的场景管理
    • 添加情感分析和情绪识别
  4. 安全性增强

    • 实现内容安全审核
    • 添加用户行为分析和异常检测
    • 增强API认证和授权机制
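以上面"性能优化"中的响应流式传输为例,其核心是把模型输出按块推送给客户端,而不是等完整回答生成后再返回。下面是一个与框架无关的 SSE(Server-Sent Events)格式生成器示意;在 FastAPI 中可将这样的生成器包装进 `StreamingResponse` 返回,具体接入方式视项目而定:

```python
import json

def sse_stream(tokens):
    """把 token 序列包装为 SSE 事件流,客户端可边接收边渲染"""
    for tok in tokens:
        yield f"data: {json.dumps({'delta': tok}, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"  # 约定的结束标记

chunks = list(sse_stream(["你好", ",", "我是", "智能助手"]))
```

这样用户在首个 token 生成后就能看到内容,显著降低大模型长回答场景下的感知延迟。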

5.4 最终建议

对于想要深入AI应用开发的开发者,建议:

  1. 从简单开始:先实现核心功能,再逐步添加高级特性
  2. 重视监控:生产环境的问题往往需要通过监控数据来发现和解决
  3. 持续学习:AI技术发展迅速,需要持续关注新技术和新工具
  4. 关注成本:大模型API调用有成本,需要设计合理的计费和限流策略

通过这个项目,你不仅完成了一个可运行的AI应用,更重要的是建立了将AI技术落地到实际生产环境的信心和能力。AI应用的开发虽然有一定门槛,但通过合理的设计和持续优化,完全可以构建出稳定、高效、易用的智能系统。

项目代码地址:本文完整代码可在GitHub获取:[项目链接]

加入讨论:欢迎在评论区分享你的实现经验和遇到的问题!
