Retrieval-Augmented Generation (RAG) for Large Language Models: A Detailed Guide

Table of Contents

  1. RAG Overview
  2. How RAG Works
  3. Core RAG Components
  4. RAG Implementation Approaches
  5. RAG Optimization Strategies
  6. RAG Application Scenarios
  7. Practical Notes on LLM Development
  8. A Complete Project Case Study

1. RAG Overview

1.1 What Is RAG?

RAG (Retrieval-Augmented Generation) is an AI architecture that combines information retrieval with text generation. It first retrieves relevant information from a knowledge base, then supplies the retrieved results to a large language model as context for answer generation, improving the accuracy and reliability of responses.

1.2 Why RAG?

  1. Fewer hallucinations: grounding answers in real data lowers the chance of fabricated information
  2. Fresher knowledge: the knowledge base can be updated quickly to keep information current
  3. Greater transparency: answers are backed by verifiable data sources
  4. Lower cost: RAG is cheaper to deploy than fine-tuning
  5. Privacy protection: private data never enters model training

2. How RAG Works

2.1 Basic Workflow

User query → Vectorized retrieval → Relevant document fetch → Context construction → LLM generation → Answer output

2.2 Step by Step

The core RAG workflow consists of five key steps, each with its own function and technical implementation:

  1. Query intake

    • Receive the user's natural-language query
    • Preprocess and clean the query
  2. Query vectorization

    • Convert the query into a vector with an embedding model
    • Keep the query in the same semantic space as the document vectors
  3. Similarity retrieval

    • Run a similarity search against the vector database
    • Return the most relevant document chunks
  4. Context construction

    • Combine the retrieved documents with the original query
    • Build the complete prompt template
  5. Answer generation

    • Feed the constructed context to the LLM
    • Generate an answer grounded in the retrieval results

The following end-to-end example shows how these steps come together in a runnable system:

import numpy as np
from typing import List, Dict, Any
import time

class RAGWorkflow:
    """Main RAG workflow class: runs the full pipeline from user query to generated answer"""
    
    def __init__(self, embedding_model, vector_store, llm_service):
        """
        Initialize the RAG workflow
        
        Args:
            embedding_model: text embedding model
            vector_store: vector database instance
            llm_service: large language model service
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_service = llm_service
    
    def process_query(self, user_query: str, top_k: int = 5) -> Dict[str, Any]:
        """
        Main entry point for processing a user query
        
        Args:
            user_query: the user's question
            top_k: number of most relevant documents to retrieve
            
        Returns:
            A full response containing the answer, relevant documents, and metadata
        """
        start_time = time.time()
        
        # Step 1: query preprocessing
        processed_query = self._preprocess_query(user_query)
        
        # Step 2: query vectorization
        query_vector = self._vectorize_query(processed_query)
        
        # Step 3: similarity retrieval
        relevant_docs = self._retrieve_documents(query_vector, top_k)
        
        # Step 4: context construction
        context = self._build_context(relevant_docs, processed_query)
        
        # Step 5: answer generation
        answer = self._generate_answer(user_query, context)
        
        processing_time = time.time() - start_time
        
        return {
            'answer': answer,
            'original_query': user_query,
            'processed_query': processed_query,
            'retrieved_documents': relevant_docs,
            'context': context,
            'processing_time': processing_time,
            'top_k': top_k
        }
    
    def _preprocess_query(self, query: str) -> str:
        """
        Step 1: query preprocessing
        Clean and normalize the user's query text
        """
        # Collapse extra whitespace
        cleaned_query = ' '.join(query.split())
        
        # Lowercase (for English text)
        cleaned_query = cleaned_query.lower()
        
        # Further preprocessing (spell checking, query expansion, etc.) can be added here
        return cleaned_query
    
    def _vectorize_query(self, query: str) -> np.ndarray:
        """
        Step 2: query vectorization
        Convert the text query into a numeric vector representation
        """
        # Use a pretrained embedding model to vectorize the query
        query_vector = self.embedding_model.embed_query(query)
        
        # Ensure the vector has the expected shape
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        
        return query_vector
    
    def _retrieve_documents(self, query_vector: np.ndarray, top_k: int) -> List[Dict]:
        """
        Step 3: similarity retrieval
        Find the documents most relevant to the query in the vector database
        """
        # Run a similarity search against the vector database
        search_results = self.vector_store.search(
            query_embeddings=query_vector.tolist(),
            n_results=top_k
        )
        
        # Format the search results
        retrieved_docs = []
        for i in range(len(search_results['documents'][0])):
            doc = {
                'content': search_results['documents'][0][i],
                'metadata': search_results['metadatas'][0][i] if search_results['metadatas'] else {},
                'similarity_score': 1 - search_results['distances'][0][i],  # convert distance to similarity
                'doc_id': search_results['ids'][0][i] if search_results['ids'] else f"doc_{i}"
            }
            retrieved_docs.append(doc)
        
        return retrieved_docs
    
    def _build_context(self, documents: List[Dict], query: str, max_length: int = 2000) -> str:
        """
        Step 4: context construction
        Organize the retrieved documents into a context format suitable for LLM input
        """
        if not documents:
            return "No relevant documents were found to answer this question."
        
        context_parts = []
        current_length = 0
        
        # Sort documents by similarity
        sorted_docs = sorted(documents, key=lambda x: x['similarity_score'], reverse=True)
        
        for i, doc in enumerate(sorted_docs):
            doc_content = f"Document {i+1} (similarity: {doc['similarity_score']:.3f}):\n{doc['content']}"
            doc_length = len(doc_content)
            
            # Check whether adding this document would exceed the length limit
            if current_length + doc_length > max_length and context_parts:
                # If it would, try to truncate it instead
                remaining_space = max_length - current_length - 50
                if remaining_space > 100:
                    truncated_content = doc['content'][:remaining_space] + "..."
                    context_parts.append(f"Document {i+1} (truncated):\n{truncated_content}")
                break
            
            context_parts.append(doc_content)
            current_length += doc_length
        
        return '\n\n'.join(context_parts)
    
    def _generate_answer(self, original_query: str, context: str) -> str:
        """
        Step 5: answer generation
        Use the LLM to generate an answer grounded in the context
        """
        # Build the prompt template
        prompt = f"""Answer the user's question accurately based on the documents provided below.

Context:
{context}

User question: {original_query}

Answer the question using the context above. If the context does not contain enough information, say so explicitly. Be accurate, concise, and helpful.

Answer:"""
        
        # Call the LLM to generate the answer
        answer = self.llm_service.generate_response(prompt)
        
        return answer
    
    def batch_process_queries(self, queries: List[str], top_k: int = 5) -> List[Dict]:
        """
        Process multiple queries in a batch
        """
        results = []
        
        for query in queries:
            try:
                result = self.process_query(query, top_k)
                results.append(result)
            except Exception as e:
                # Record the error but keep processing the remaining queries
                error_result = {
                    'answer': f"Error while processing query: {str(e)}",
                    'original_query': query,
                    'error': str(e),
                    'processing_time': 0
                }
                results.append(error_result)
        
        return results
    
    def get_workflow_stats(self) -> Dict[str, Any]:
        """
        Return workflow statistics
        """
        return {
            'embedding_model': type(self.embedding_model).__name__,
            'vector_store': type(self.vector_store).__name__,
            'llm_service': type(self.llm_service).__name__,
            'components_initialized': all([
                self.embedding_model is not None,
                self.vector_store is not None,
                self.llm_service is not None
            ])
        }


class RAGWorkflowManager:
    """RAG workflow manager providing higher-level management features"""
    
    def __init__(self, workflow: RAGWorkflow):
        self.workflow = workflow
        self.query_history = []
        self.performance_metrics = {
            'total_queries': 0,
            'successful_queries': 0,
            'failed_queries': 0,
            'average_processing_time': 0,
            'total_processing_time': 0
        }
    
    def process_and_log_query(self, query: str, top_k: int = 5) -> Dict[str, Any]:
        """
        Process a query while recording history and performance metrics
        """
        try:
            result = self.workflow.process_query(query, top_k)
            
            # Record query history
            self.query_history.append({
                'timestamp': time.time(),
                'query': query,
                'success': True,
                'processing_time': result['processing_time'],
                'retrieved_docs_count': len(result.get('retrieved_documents', []))
            })
            
            # Update performance metrics
            self._update_performance_metrics(result['processing_time'], success=True)
            
            return result
            
        except Exception as e:
            # Record the failed query
            self.query_history.append({
                'timestamp': time.time(),
                'query': query,
                'success': False,
                'error': str(e)
            })
            
            # Update performance metrics
            self._update_performance_metrics(0, success=False)
            
            raise
    
    def _update_performance_metrics(self, processing_time: float, success: bool):
        """Update performance metrics"""
        self.performance_metrics['total_queries'] += 1
        self.performance_metrics['total_processing_time'] += processing_time
        
        if success:
            self.performance_metrics['successful_queries'] += 1
        else:
            self.performance_metrics['failed_queries'] += 1
        
        self.performance_metrics['average_processing_time'] = (
            self.performance_metrics['total_processing_time'] / 
            self.performance_metrics['total_queries']
        )
    
    def get_performance_report(self) -> Dict[str, Any]:
        """Return a performance report"""
        return {
            'performance_metrics': self.performance_metrics.copy(),
            'success_rate': (
                self.performance_metrics['successful_queries'] / 
                max(self.performance_metrics['total_queries'], 1)
            ),
            'recent_queries': self.query_history[-10:],  # last 10 queries
            'query_history_size': len(self.query_history)
        }
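To see how the pieces connect without real models, the sketch below wires up stub components that satisfy the interfaces `RAGWorkflow` expects (`embed_query`, `search`, `generate_response`). The stubs themselves (letter-count embeddings, an in-memory store, an echo LLM) are illustrative assumptions, not real services:

```python
import numpy as np

class StubEmbedder:
    """Deterministic letter-count embeddings (illustration only)."""
    def embed_query(self, text):
        vec = np.zeros(26)
        for ch in text.lower():
            if 'a' <= ch <= 'z':
                vec[ord(ch) - ord('a')] += 1
        norm = np.linalg.norm(vec)
        return vec / norm if norm else vec

class StubStore:
    """In-memory store returning Chroma-style nested result lists."""
    def __init__(self, docs, embedder):
        self.docs = docs
        self.vecs = np.array([embedder.embed_query(d) for d in docs])
    def search(self, query_embeddings, n_results):
        q = np.array(query_embeddings).reshape(-1)
        dists = 1.0 - self.vecs @ q  # cosine distance for unit vectors
        order = np.argsort(dists)[:n_results]
        return {'documents': [[self.docs[i] for i in order]],
                'metadatas': [[{} for _ in order]],
                'distances': [[float(dists[i]) for i in order]],
                'ids': [[f"doc_{i}" for i in order]]}

class StubLLM:
    """Echoes a canned answer instead of calling a real model."""
    def generate_response(self, prompt):
        return "Based on the context: " + prompt.splitlines()[-1]

docs = ["rag retrieves documents", "llms generate text"]
embedder = StubEmbedder()
store = StubStore(docs, embedder)

query_vec = embedder.embed_query("retrieve documents")
result_docs = store.search([query_vec.tolist()], n_results=1)
print(result_docs['documents'][0][0])  # most similar document
```

Swapping the stubs for a real embedding service, vector store, and LLM client is the only change needed to run the workflow end to end.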

3. Core RAG Components

3.1 Data Sources

Types:

  • Structured data: databases, tables, APIs
  • Semi-structured data: JSON, XML, Markdown
  • Unstructured data: PDF, Word, web pages, images

Data preprocessing implementation:

import re
import nltk
import spacy
from typing import List, Dict, Optional
from pathlib import Path
import PyPDF2
from docx import Document
import json
import pandas as pd

class DocumentProcessor:
    def __init__(self):
        # Load the NLP model
        try:
            self.nlp = spacy.load("zh_core_web_sm")
        except OSError:
            print("Please install the spaCy Chinese model: python -m spacy download zh_core_web_sm")
            self.nlp = None
        
        # Download NLTK data
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
    
    def load_documents(self, data_path: str) -> List[Dict]:
        """Load documents in multiple formats"""
        documents = []
        data_dir = Path(data_path)
        
        if not data_dir.exists():
            raise FileNotFoundError(f"Data path does not exist: {data_path}")
        
        # Supported file types
        file_patterns = {
            '.pdf': self.load_pdf,
            '.docx': self.load_docx,
            '.md': self.load_markdown,
            '.txt': self.load_text,
            '.json': self.load_json,
            '.csv': self.load_csv,
            '.html': self.load_html
        }
        
        for file_path in data_dir.rglob('*'):
            if file_path.is_file() and file_path.suffix.lower() in file_patterns:
                try:
                    loader_func = file_patterns[file_path.suffix.lower()]
                    docs = loader_func(file_path)
                    documents.extend(docs)
                    print(f"Loaded file: {file_path}")
                except Exception as e:
                    print(f"Failed to load {file_path}: {e}")
        
        return documents
    
    def load_pdf(self, file_path: Path) -> List[Dict]:
        """Load a PDF file"""
        documents = []
        
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            
            for page_num, page in enumerate(pdf_reader.pages):
                text = page.extract_text()
                if text.strip():
                    documents.append({
                        'content': text.strip(),
                        'metadata': {
                            'file_name': file_path.name,
                            'file_type': 'pdf',
                            'page_number': page_num + 1,
                            'total_pages': len(pdf_reader.pages),
                            'char_count': len(text.strip())
                        }
                    })
        
        return documents
    
    def load_docx(self, file_path: Path) -> List[Dict]:
        """Load a Word document"""
        documents = []
        doc = Document(file_path)
        
        doc_metadata = {
            'file_name': file_path.name,
            'file_type': 'docx',
            'paragraphs_count': len(doc.paragraphs)
        }
        
        for i, paragraph in enumerate(doc.paragraphs):
            if paragraph.text.strip():
                documents.append({
                    'content': paragraph.text.strip(),
                    'metadata': {
                        **doc_metadata,
                        'paragraph_index': i,
                        'char_count': len(paragraph.text)
                    }
                })
        
        return documents
    
    def load_markdown(self, file_path: Path) -> List[Dict]:
        """Load a Markdown file"""
        documents = []
        
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # Split the document by headers
        sections = self.split_markdown_by_headers(content)
        
        for section in sections:
            documents.append({
                'content': section['content'],
                'metadata': {
                    'file_name': file_path.name,
                    'file_type': 'markdown',
                    'section_title': section.get('title', 'untitled'),
                    'section_level': section.get('level', 0),
                    'char_count': len(section['content'])
                }
            })
        
        return documents
    
    def load_text(self, file_path: Path) -> List[Dict]:
        """Load a plain-text file"""
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        return [{'content': content,
                 'metadata': {'file_name': file_path.name, 'file_type': 'text',
                              'char_count': len(content)}}]
    
    def load_json(self, file_path: Path) -> List[Dict]:
        """Load a JSON file (each top-level item becomes one document)"""
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        items = data if isinstance(data, list) else [data]
        return [{'content': json.dumps(item, ensure_ascii=False),
                 'metadata': {'file_name': file_path.name, 'file_type': 'json',
                              'item_index': i}}
                for i, item in enumerate(items)]
    
    def load_csv(self, file_path: Path) -> List[Dict]:
        """Load a CSV file (each row becomes one document)"""
        df = pd.read_csv(file_path)
        return [{'content': ', '.join(f"{col}: {row[col]}" for col in df.columns),
                 'metadata': {'file_name': file_path.name, 'file_type': 'csv',
                              'row_index': int(idx)}}
                for idx, row in df.iterrows()]
    
    def load_html(self, file_path: Path) -> List[Dict]:
        """Load an HTML file (tags stripped with a simple regex)"""
        with open(file_path, 'r', encoding='utf-8') as f:
            html = f.read()
        text = ' '.join(re.sub(r'<[^>]+>', ' ', html).split())
        return [{'content': text,
                 'metadata': {'file_name': file_path.name, 'file_type': 'html',
                              'char_count': len(text)}}]
    
    def split_markdown_by_headers(self, content: str) -> List[Dict]:
        """Split Markdown content by headers"""
        sections = []
        lines = content.split('\n')
        current_section = []
        current_title = "untitled"
        current_level = 0
        
        for line in lines:
            if line.startswith('#'):
                # Save the current section
                if current_section:
                    sections.append({
                        'title': current_title,
                        'level': current_level,
                        'content': '\n'.join(current_section)
                    })
                
                # Start a new section
                current_title = line.strip()
                current_level = len(line) - len(line.lstrip('#'))
                current_section = []
            else:
                current_section.append(line)
        
        # Append the final section
        if current_section:
            sections.append({
                'title': current_title,
                'level': current_level,
                'content': '\n'.join(current_section)
            })
        
        return sections
    
    def clean_text(self, text: str) -> str:
        """Clean text"""
        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove special characters (keep Chinese, Latin letters, digits, basic punctuation)
        text = re.sub(r'[^\u4e00-\u9fa5\w\s.,!?;:()[\]{}"\'-]', '', text)
        # Trim leading and trailing whitespace
        text = text.strip()
        
        return text
    
    def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
        """Extract keywords"""
        if self.nlp is None:
            # Fall back to simple word-frequency counting
            words = re.findall(r'\b\w+\b', text.lower())
            word_freq = {}
            for word in words:
                if len(word) > 2:  # filter out short words
                    word_freq[word] = word_freq.get(word, 0) + 1
            
            return sorted(word_freq.keys(), key=lambda x: word_freq[x], 
                         reverse=True)[:max_keywords]
        
        # Use spaCy for keyword extraction
        doc = self.nlp(text)
        keywords = []
        
        # Keep nouns and proper nouns
        for token in doc:
            if (token.pos_ in ['NOUN', 'PROPN'] and 
                not token.is_stop and 
                len(token.text) > 2):
                keywords.append(token.lemma_.lower())
        
        # Deduplicate and cap the count
        return list(set(keywords))[:max_keywords]
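The header-based splitting logic is easy to exercise on its own. This standalone sketch mirrors the `split_markdown_by_headers` algorithm above on a tiny sample document:

```python
def split_markdown_by_headers(content):
    """Split Markdown into sections, one per header (same logic as above)."""
    sections, current, title, level = [], [], "untitled", 0
    for line in content.split('\n'):
        if line.startswith('#'):
            if current:  # close out the previous section
                sections.append({'title': title, 'level': level,
                                 'content': '\n'.join(current)})
            title = line.strip()
            level = len(line) - len(line.lstrip('#'))  # '#' count = header depth
            current = []
        else:
            current.append(line)
    if current:  # final section
        sections.append({'title': title, 'level': level,
                         'content': '\n'.join(current)})
    return sections

md = "# Intro\nRAG overview.\n## Details\nChunking matters."
sections = split_markdown_by_headers(md)
for s in sections:
    print(s['level'], s['title'])
```

Splitting on headers keeps each chunk topically coherent, which tends to improve retrieval precision over arbitrary character windows.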

3.2 Embedding Model

Embedding service implementation:

import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import pickle
import hashlib
import time
from dataclasses import dataclass

@dataclass
class EmbeddingConfig:
    model_name: str = "shibing624/text2vec-base-chinese"
    cache_enabled: bool = True
    cache_file: str = "embedding_cache.pkl"
    batch_size: int = 32
    max_length: int = 512

class EmbeddingService:
    def __init__(self, config: EmbeddingConfig):
        self.config = config
        self.model = SentenceTransformer(config.model_name)
        self.cache = {}
        
        # Load the cache
        if config.cache_enabled:
            self.load_cache()
        
        print(f"Loaded embedding model: {config.model_name}")
    
    def load_cache(self):
        """Load the embedding cache"""
        try:
            with open(self.config.cache_file, 'rb') as f:
                self.cache = pickle.load(f)
            print(f"Loaded cache with {len(self.cache)} vectors")
        except FileNotFoundError:
            self.cache = {}
            print("Cache file not found; a new cache will be created")
    
    def save_cache(self):
        """Persist the embedding cache"""
        try:
            with open(self.config.cache_file, 'wb') as f:
                pickle.dump(self.cache, f)
            print(f"Saved {len(self.cache)} vectors to the cache")
        except Exception as e:
            print(f"Failed to save cache: {e}")
    
    def embed_texts(self, texts: List[str], use_cache: bool = True) -> np.ndarray:
        """Vectorize a batch of texts"""
        # Preallocate so cached and freshly computed vectors land at the right indices
        embeddings = [None] * len(texts)
        uncached_texts = []
        uncached_indices = []
        
        # Check the cache
        if use_cache:
            for i, text in enumerate(texts):
                text_hash = self.get_text_hash(text)
                if text_hash in self.cache:
                    embeddings[i] = self.cache[text_hash]
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(i)
        else:
            uncached_texts = texts
            uncached_indices = list(range(len(texts)))
        
        # Compute the vectors that were not cached
        if uncached_texts:
            print(f"Computing {len(uncached_texts)} new vectors...")
            start_time = time.time()
            
            new_embeddings = self.model.encode(
                uncached_texts,
                batch_size=self.config.batch_size,
                show_progress_bar=True,
                normalize_embeddings=True
            )
            
            # Update the cache and place each vector at its original index
            for i, embedding in enumerate(new_embeddings):
                original_index = uncached_indices[i]
                if use_cache:
                    self.cache[self.get_text_hash(uncached_texts[i])] = embedding
                embeddings[original_index] = embedding
            
            # Persist the cache
            if use_cache:
                self.save_cache()
            
            elapsed_time = time.time() - start_time
            print(f"Vectorization finished in {elapsed_time:.2f}s")
        
        return np.array(embeddings)
    
    def embed_query(self, query: str) -> np.ndarray:
        """Vectorize a single query"""
        query_hash = self.get_text_hash(query)
        
        if self.config.cache_enabled and query_hash in self.cache:
            return self.cache[query_hash]
        
        embedding = self.model.encode([query], normalize_embeddings=True)[0]
        
        if self.config.cache_enabled:
            self.cache[query_hash] = embedding
        
        return embedding
    
    def get_text_hash(self, text: str) -> str:
        """Hash a text for use as a cache key"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()
    
    def compute_similarity(self, query_embedding: np.ndarray, 
                          doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute similarities"""
        # The dot product equals cosine similarity because embeddings are normalized
        similarities = np.dot(doc_embeddings, query_embedding)
        return similarities
    
    def find_most_similar(self, query: str, documents: List[str], 
                         top_k: int = 5) -> List[Dict]:
        """Find the documents most similar to a query"""
        query_embedding = self.embed_query(query)
        doc_embeddings = self.embed_texts(documents)
        similarities = self.compute_similarity(query_embedding, doc_embeddings)
        
        # Take the top-k results
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        
        results = []
        for idx in top_indices:
            results.append({
                'index': int(idx),
                'content': documents[idx],
                'similarity_score': float(similarities[idx])
            })
        
        return results
    
    def get_embedding_dimension(self) -> int:
        """Return the embedding dimensionality"""
        return self.model.get_sentence_embedding_dimension()
    
    def benchmark_performance(self, test_texts: List[str]) -> Dict:
        """Run a performance benchmark"""
        print("Starting performance benchmark...")
        
        # Single-text timing
        start_time = time.time()
        self.embed_query(test_texts[0])
        single_time = time.time() - start_time
        
        # Batch timing
        batch_sizes = [1, 8, 16, 32]
        batch_results = {}
        
        for batch_size in batch_sizes:
            batch_texts = test_texts[:batch_size]
            
            start_time = time.time()
            self.embed_texts(batch_texts, use_cache=False)
            batch_time = time.time() - start_time
            
            batch_results[batch_size] = {
                'time': batch_time,
                'throughput': batch_size / batch_time
            }
        
        return {
            'embedding_dimension': self.get_embedding_dimension(),
            'single_text_time': single_time,
            'batch_results': batch_results,
            'cache_size': len(self.cache)
        }
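Because `normalize_embeddings=True` yields unit-length vectors, the plain dot product in `compute_similarity` is exactly cosine similarity. A small numeric check with hand-made 2-D vectors (illustration only, no model required):

```python
import numpy as np

def normalize(v):
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v)

# Three unit-length "document" vectors and one "query" vector
doc_vecs = np.array([normalize(np.array([1.0, 0.0])),
                     normalize(np.array([1.0, 1.0])),
                     normalize(np.array([0.0, 1.0]))])
query = normalize(np.array([1.0, 0.2]))

# For unit vectors, dot product == cosine similarity
sims = doc_vecs @ query
top = np.argsort(sims)[-2:][::-1]  # top-2 indices, most similar first
print(top.tolist(), np.round(sims, 3).tolist())
```

The document nearest the query direction scores highest, which is the property the retrieval step relies on.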

3.3 Document Chunking Strategies

Smart chunking implementation:

import re
import time
import numpy as np
from typing import List, Dict, Tuple, Optional
import nltk
from dataclasses import dataclass

@dataclass
class ChunkingConfig:
    chunk_size: int = 1000
    chunk_overlap: int = 200
    min_chunk_size: int = 100
    max_chunk_size: int = 2000
    use_semantic_chunking: bool = True
    separators: List[str] = None

class DocumentChunker:
    def __init__(self, config: ChunkingConfig):
        self.config = config
        if config.separators is None:
            self.config.separators = ['\n\n', '\n', '。', '!', '?', ';', ';']
        
        # Download NLTK data
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
    
    def chunk_documents(self, documents: List[Dict]) -> List[Dict]:
        """Chunk a list of documents"""
        all_chunks = []
        
        for doc_idx, document in enumerate(documents):
            content = document['content']
            metadata = document.get('metadata', {})
            
            # Try the chunking strategies in order
            chunks = self.chunk_single_document(content, metadata, doc_idx)
            all_chunks.extend(chunks)
        
        print(f"Generated {len(all_chunks)} document chunks")
        return all_chunks
    
    def chunk_single_document(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Chunk a single document"""
        if self.config.use_semantic_chunking:
            # Try semantic chunking first
            chunks = self.semantic_chunking(content, metadata, doc_idx)
            if len(chunks) > 1:
                return chunks
        
        # Fall back to recursive chunking
        chunks = self.recursive_character_chunking(content, metadata, doc_idx)
        if len(chunks) > 1:
            return chunks
        
        # Finally, fall back to fixed-size chunking
        return self.fixed_size_chunking(content, metadata, doc_idx)
    
    def semantic_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Semantic chunking"""
        sentences = self.split_sentences(content)
        chunks = []
        current_chunk = []
        current_length = 0
        
        for i, sentence in enumerate(sentences):
            sentence_length = len(sentence)
            
            # If this sentence would push the chunk over the limit, start a new chunk
            if (current_length + sentence_length > self.config.chunk_size and 
                current_chunk and len(current_chunk) > 1):
                
                chunk_text = ' '.join(current_chunk)
                chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, len(chunks)))
                
                # Carry over an overlap
                overlap_sentences = self.get_overlap_sentences(current_chunk, i, sentences)
                current_chunk = overlap_sentences
                current_length = sum(len(s) for s in current_chunk)
            
            current_chunk.append(sentence)
            current_length += sentence_length
        
        # Append the final chunk
        if current_chunk:
            chunk_text = ' '.join(current_chunk)
            chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, len(chunks)))
        
        return chunks
    
    def split_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        # For English text
        if re.search(r'[a-zA-Z]', text):
            sentences = nltk.sent_tokenize(text)
        else:
            # For Chinese text, split on sentence-ending punctuation
            sentences = re.split(r'[。!?;]+', text)
        
        # Clean and filter
        sentences = [s.strip() for s in sentences if s.strip() and len(s.strip()) > 5]
        
        return sentences
    
    def get_overlap_sentences(self, current_chunk: List[str], current_index: int, 
                             all_sentences: List[str]) -> List[str]:
        """Collect the sentences to carry over as overlap"""
        overlap_sentences = []
        overlap_length = 0
        
        # Add sentences from the end until the overlap budget is reached
        for sentence in reversed(current_chunk):
            if overlap_length + len(sentence) <= self.config.chunk_overlap:
                overlap_sentences.insert(0, sentence)
                overlap_length += len(sentence)
            else:
                break
        
        return overlap_sentences
    
    def recursive_character_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Recursive character chunking"""
        return self._recursive_split(content, metadata, doc_idx, self.config.separators, 0)
    
    def _recursive_split(self, content: str, metadata: Dict, doc_idx: int, 
                         separators: List[str], separator_index: int) -> List[Dict]:
        """Recursive splitting implementation"""
        if separator_index >= len(separators):
            # No separators left: fall back to fixed-size chunking
            return self.fixed_size_chunking(content, metadata, doc_idx)
        
        separator = separators[separator_index]
        parts = content.split(separator)
        chunks = []
        current_chunk = ""
        
        for part in parts:
            test_chunk = current_chunk + (separator if current_chunk else "")
            test_chunk += part
            
            if len(test_chunk) <= self.config.chunk_size:
                current_chunk = test_chunk
            else:
                if current_chunk:
                    # Recursively split the accumulated chunk
                    sub_chunks = self._recursive_split(
                        current_chunk, metadata, doc_idx, separators, separator_index + 1
                    )
                    chunks.extend(sub_chunks)
                    current_chunk = part
                else:
                    # A single part is too long; recurse with the next separator
                    sub_chunks = self._recursive_split(
                        part, metadata, doc_idx, separators, separator_index + 1
                    )
                    chunks.extend(sub_chunks)
        
        # Handle the final chunk
        if current_chunk:
            sub_chunks = self._recursive_split(
                current_chunk, metadata, doc_idx, separators, separator_index + 1
            )
            chunks.extend(sub_chunks)
        
        return chunks
    
    def fixed_size_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Fixed-size chunking"""
        chunks = []
        content_length = len(content)
        
        start = 0
        chunk_idx = 0
        
        while start < content_length:
            end = min(start + self.config.chunk_size, content_length)
            
            # If this is not the last chunk, try to break at a sentence boundary
            if end < content_length:
                # Look ahead for a sentence boundary
                sentence_end = end
                for i in range(end, min(end + 100, content_length)):
                    if content[i] in '。!?\n':
                        sentence_end = i + 1
                        break
                
                # Use the boundary if one was found
                if sentence_end > end:
                    end = sentence_end
            
            chunk_text = content[start:end].strip()
            if chunk_text:
                chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, chunk_idx))
                chunk_idx += 1
            
            if end >= content_length:
                break
            # Step back by the overlap so adjacent chunks share context
            start = end - self.config.chunk_overlap
        
        return chunks
    
    def create_chunk(self, content: str, metadata: Dict, doc_idx: int, chunk_idx: int) -> Dict:
        """Create a document chunk"""
        chunk_metadata = metadata.copy()
        chunk_metadata.update({
            'chunk_index': chunk_idx,
            'char_count': len(content),
            'word_count': len(content.split()),
            'created_at': time.time(),
            'parent_doc_index': doc_idx
        })
        
        return {
            'content': content,
            'metadata': chunk_metadata
        }
    
    def analyze_chunks(self, chunks: List[Dict]) -> Dict:
        """Analyze chunking results"""
        if not chunks:
            return {}
        
        sizes = [len(chunk['content']) for chunk in chunks]
        
        return {
            'total_chunks': len(chunks),
            'avg_chunk_size': np.mean(sizes),
            'min_chunk_size': min(sizes),
            'max_chunk_size': max(sizes),
            'total_characters': sum(sizes),
            'size_distribution': {
                'small': sum(1 for s in sizes if s < self.config.min_chunk_size),
                'medium': sum(1 for s in sizes if self.config.min_chunk_size <= s <= self.config.chunk_size),
                'large': sum(1 for s in sizes if s > self.config.chunk_size)
            }
        }
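The overlap mechanics are easiest to see on a toy string. This standalone sketch applies the same end-minus-overlap stepping as the fixed-size strategy, with small sizes chosen purely for illustration:

```python
def fixed_size_chunks(text, chunk_size=10, overlap=3):
    """Split text into fixed-size chunks where neighbors share `overlap` chars."""
    chunks, start = [], 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        if end == len(text):
            break
        start = end - overlap  # step back so adjacent chunks overlap
    return chunks

chunks = fixed_size_chunks("abcdefghijklmnopqrstuvwxyz")
print(chunks)
```

Each chunk repeats the last three characters of its predecessor, so a sentence cut at a chunk boundary still appears intact in one of the two neighboring chunks.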

3.4 Vector Databases

Chroma vector database implementation:

import chromadb
from typing import List, Dict, Any, Optional
import uuid
import time

class ChromaVectorStore:
    def __init__(self, persist_directory: str = "./vector_db", 
                 collection_name: str = "documents"):
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        
        # Initialize the Chroma client
        self.client = chromadb.PersistentClient(path=persist_directory)
        
        # Get or create the collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        
        print(f"Connected to vector database: {persist_directory}")
        print(f"Using collection: {collection_name}")
    
    def add_documents(self, texts: List[str], embeddings: List[List[float]], 
                     metadatas: List[Dict] = None, ids: List[str] = None) -> List[str]:
        """Add documents to the vector database"""
        if len(texts) != len(embeddings):
            raise ValueError("The numbers of texts and embeddings do not match")
        
        # Generate IDs
        if ids is None:
            ids = [f"doc_{uuid.uuid4().hex[:8]}" for _ in range(len(texts))]
        
        # Generate metadata
        if metadatas is None:
            metadatas = [{} for _ in range(len(texts))]
        
        # Add in batches
        batch_size = 100
        added_ids = []
        
        for i in range(0, len(texts), batch_size):
            batch_end = min(i + batch_size, len(texts))
            
            self.collection.add(
                ids=ids[i:batch_end],
                embeddings=embeddings[i:batch_end],
                documents=texts[i:batch_end],
                metadatas=metadatas[i:batch_end]
            )
            
            added_ids.extend(ids[i:batch_end])
        
        print(f"Added {len(added_ids)} documents to the vector database")
        return added_ids
    
    def search(self, query_embeddings: List[List[float]], 
              n_results: int = 5, where: Optional[Dict] = None,
              where_document: Optional[Dict] = None) -> Dict[str, Any]:
        """Search the vector database.
        
        Results keep Chroma's nested one-list-per-query layout, which is what
        RAGWorkflow._retrieve_documents expects.
        """
        return self.collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document
        )
    
    def similarity_search(self, query_embedding: List[float], 
                        threshold: float = 0.7, max_results: int = 10) -> List[Dict]:
        """Threshold-based similarity search"""
        results = self.search([query_embedding], n_results=max_results)
        
        filtered_results = []
        for i in range(len(results['documents'][0])):
            similarity_score = 1 - results['distances'][0][i]  # convert distance to similarity
            
            if similarity_score >= threshold:
                filtered_results.append({
                    'id': results['ids'][0][i],
                    'content': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    'similarity_score': similarity_score
                })
        
        return filtered_results
    
    def hybrid_search(self, query_embedding: List[float], keyword_query: str,
                     alpha: float = 0.7, top_k: int = 10) -> List[Dict]:
        """混合搜索(向量 + 关键词)"""
        # 向量搜索
        vector_results = self.search([query_embedding], n_results=top_k * 2)
        
        # 关键词搜索(Chroma支持基本的文本搜索)
        keyword_results = self.collection.query(
            query_texts=[keyword_query],
            n_results=top_k * 2
        )
        
        # 合并和重排序结果
        merged_scores = {}
        
        # 向量搜索结果
        # 注意:self.search 返回的已是扁平列表,这里不再取 [0]
        for i, doc_id in enumerate(vector_results['ids']):
            vector_score = 1 - vector_results['distances'][i]
            merged_scores[doc_id] = alpha * vector_score
        
        # 关键词搜索结果
        for i, doc_id in enumerate(keyword_results['ids'][0]):
            keyword_score = 1 - keyword_results['distances'][0][i]
            if doc_id in merged_scores:
                merged_scores[doc_id] += (1 - alpha) * keyword_score
            else:
                merged_scores[doc_id] = (1 - alpha) * keyword_score
        
        # 获取最终文档
        final_results = []
        for doc_id, score in sorted(merged_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]:
            # 获取文档内容
            doc_content = self.get_document_by_id(doc_id)
            if doc_content:
                final_results.append({
                    'id': doc_id,
                    'content': doc_content['content'],
                    'metadata': doc_content['metadata'],
                    'hybrid_score': score
                })
        
        return final_results
    
    def get_document_by_id(self, doc_id: str) -> Optional[Dict]:
        """根据ID获取文档"""
        try:
            results = self.collection.get(ids=[doc_id])
            if results['documents']:
                return {
                    'id': doc_id,
                    'content': results['documents'][0],
                    'metadata': results['metadatas'][0] if results['metadatas'] else {}
                }
        except Exception as e:
            print(f"获取文档失败 {doc_id}: {e}")
        
        return None
    
    def update_document(self, doc_id: str, new_content: str, 
                      new_embedding: List[float], new_metadata: Dict = None):
        """更新文档"""
        if new_metadata is None:
            new_metadata = {}
        
        self.collection.update(
            ids=[doc_id],
            documents=[new_content],
            embeddings=[new_embedding],
            metadatas=[new_metadata]
        )
        print(f"已更新文档: {doc_id}")
    
    def delete_documents(self, doc_ids: List[str]):
        """删除文档"""
        self.collection.delete(ids=doc_ids)
        print(f"已删除 {len(doc_ids)} 个文档")
    
    def get_collection_stats(self) -> Dict:
        """获取集合统计信息"""
        try:
            count = self.collection.count()
            return {
                'document_count': count,
                'collection_name': self.collection_name,
                'persist_directory': self.persist_directory
            }
        except Exception as e:
            return {'error': str(e)}
    
    def clear_collection(self):
        """清空集合"""
        try:
            # 获取所有文档ID
            all_docs = self.collection.get()
            if all_docs['ids']:
                self.delete_documents(all_docs['ids'])
            print(f"已清空集合 {self.collection_name}")
        except Exception as e:
            print(f"清空集合失败: {e}")

四、RAG实现方式

4.1 基础RAG系统

基础RAG系统是整个检索增强生成架构的核心实现。它整合了文档处理、向量化、检索和生成等关键组件,提供了一个统一的接口来处理用户查询并生成基于知识的回答。以下实现包含了完整的项目结构、配置管理、知识库构建和查询处理功能:

核心RAG类实现:

import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import time

@dataclass
class RAGConfig:
    embedding_model: str = "shibing624/text2vec-base-chinese"
    llm_model: str = "deepseek-chat"
    chunk_size: int = 1000
    chunk_overlap: int = 200
    top_k: int = 5
    similarity_threshold: float = 0.7
    max_context_length: int = 4000
    use_hybrid_search: bool = False
    enable_reranking: bool = False

class BasicRAGSystem:
    def __init__(self, config: RAGConfig):
        self.config = config
        
        # 初始化日志
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        
        # 初始化组件(稍后实现)
        self.embedding_service = None
        self.vector_store = None
        self.llm_service = None
        self.document_chunker = None
        
        self.logger.info("RAG系统初始化完成")
    
    def initialize(self, embedding_service, vector_store, llm_service, document_chunker):
        """初始化RAG组件"""
        self.embedding_service = embedding_service
        self.vector_store = vector_store
        self.llm_service = llm_service
        self.document_chunker = document_chunker
        
        self.logger.info("RAG组件初始化完成")
    
    def build_knowledge_base(self, data_path: str) -> Dict[str, Any]:
        """构建知识库"""
        self.logger.info(f"开始构建知识库,数据路径: {data_path}")
        start_time = time.time()
        
        try:
            # 1. 加载文档
            doc_processor = DocumentProcessor()
            documents = doc_processor.load_documents(data_path)
            self.logger.info(f"加载了 {len(documents)} 个文档")
            
            # 2. 文档分块
            chunks = self.document_chunker.chunk_documents(documents)
            self.logger.info(f"生成了 {len(chunks)} 个文档块")
            
            # 3. 向量化
            chunk_texts = [chunk['content'] for chunk in chunks]
            chunk_metadatas = [chunk['metadata'] for chunk in chunks]
            
            embeddings = self.embedding_service.embed_texts(chunk_texts)
            
            # 4. 存储到向量数据库
            doc_ids = self.vector_store.add_documents(
                texts=chunk_texts,
                embeddings=embeddings.tolist(),
                metadatas=chunk_metadatas
            )
            
            processing_time = time.time() - start_time
            
            return {
                'success': True,
                'document_count': len(documents),
                'chunk_count': len(chunks),
                'processing_time': processing_time,
                'embedding_dimension': self.embedding_service.get_embedding_dimension()
            }
            
        except Exception as e:
            self.logger.error(f"构建知识库失败: {e}")
            return {
                'success': False,
                'error': str(e),
                'processing_time': time.time() - start_time
            }
    
    def query(self, user_question: str, 
              use_hybrid: bool = None,
              return_sources: bool = True) -> Dict[str, Any]:
        """查询RAG系统"""
        if use_hybrid is None:
            use_hybrid = self.config.use_hybrid_search
        
        start_time = time.time()
        
        try:
            # 1. 查询预处理
            processed_query = self.preprocess_query(user_question)
            
            # 2. 检索相关文档
            if use_hybrid:
                retrieved_docs = self.hybrid_retrieval(processed_query)
            else:
                retrieved_docs = self.vector_retrieval(processed_query)
            
            # 3. 构建上下文
            context = self.build_context(retrieved_docs)
            
            # 4. 生成回答
            answer = self.generate_answer(user_question, context)
            
            # 5. 构建响应
            processing_time = time.time() - start_time
            
            response = {
                'answer': answer,
                'question': user_question,
                'retrieved_docs_count': len(retrieved_docs),
                'used_docs_count': len(retrieved_docs),
                'processing_time': processing_time
            }
            
            if return_sources:
                response['sources'] = self.format_sources(retrieved_docs)
            
            return response
            
        except Exception as e:
            self.logger.error(f"查询处理失败: {e}")
            return {
                'answer': '抱歉,处理您的问题时遇到了错误。',
                'question': user_question,
                'error': str(e),
                'processing_time': time.time() - start_time,
                'sources': []
            }
    
    def preprocess_query(self, query: str) -> Dict[str, Any]:
        """查询预处理"""
        # 清理查询
        cleaned_query = query.strip()
        
        # 可以在这里添加查询扩展、意图识别等
        return {
            'original': query,
            'cleaned': cleaned_query,
            'expanded': cleaned_query  # 简单实现,可以扩展
        }
    
    def vector_retrieval(self, processed_query: Dict[str, Any]) -> List[Dict]:
        """向量检索"""
        query_embedding = self.embedding_service.embed_query(processed_query['expanded'])
        
        results = self.vector_store.search(
            query_embeddings=[query_embedding],
            n_results=self.config.top_k
        )
        
        # 转换结果格式
        retrieved_docs = []
        for i in range(len(results['documents'])):
            similarity_score = 1 - results['distances'][i]  # 转换为相似度
            
            if similarity_score >= self.config.similarity_threshold:
                retrieved_docs.append({
                    'id': results['ids'][i],
                    'content': results['documents'][i],
                    'metadata': results['metadatas'][i],
                    'similarity_score': similarity_score
                })
        
        return retrieved_docs
    
    def hybrid_retrieval(self, processed_query: Dict[str, Any]) -> List[Dict]:
        """混合检索"""
        query_embedding = self.embedding_service.embed_query(processed_query['expanded'])
        
        results = self.vector_store.hybrid_search(
            query_embedding=query_embedding,
            keyword_query=processed_query['expanded'],
            alpha=0.7,
            top_k=self.config.top_k
        )
        
        # 过滤低相关性结果
        filtered_results = [
            doc for doc in results 
            if doc.get('hybrid_score', 0) >= self.config.similarity_threshold
        ]
        
        return filtered_results
    
    def build_context(self, retrieved_docs: List[Dict]) -> str:
        """构建上下文"""
        if not retrieved_docs:
            return "未找到相关信息。"
        
        context_parts = []
        current_length = 0
        
        for i, doc in enumerate(retrieved_docs):
            doc_content = doc['content']
            
            # 检查长度限制
            if current_length + len(doc_content) > self.config.max_context_length:
                # 尝试截取部分内容
                remaining_space = self.config.max_context_length - current_length - 100
                if remaining_space > 100:
                    truncated_content = doc_content[:remaining_space] + "..."
                    context_parts.append(f"文档 {i+1}:{truncated_content}")
                break
            
            context_parts.append(f"文档 {i+1}:{doc_content}")
            current_length += len(doc_content)
        
        return '\n\n'.join(context_parts)
    
    def generate_answer(self, question: str, context: str) -> str:
        """生成回答"""
        prompt = f"""
基于以下上下文信息回答用户问题。请确保回答准确、相关且简洁。

上下文信息:
{context}

用户问题:{question}

请基于上下文信息回答问题。如果上下文中没有足够信息,请说明。
"""
        
        answer = self.llm_service.generate_response(prompt)
        return answer
    
    def format_sources(self, retrieved_docs: List[Dict]) -> List[Dict]:
        """格式化来源信息"""
        sources = []
        for doc in retrieved_docs:
            source = {
                'id': doc['id'],
                'similarity_score': doc.get('similarity_score', 0),
                'metadata': doc['metadata']
            }
            
            # 添加有用的元数据
            if 'file_name' in doc['metadata']:
                source['file_name'] = doc['metadata']['file_name']
            if 'title' in doc['metadata']:
                source['title'] = doc['metadata']['title']
            
            sources.append(source)
        
        return sources

4.2 高级RAG功能

高级RAG功能在基础系统之上添加了智能化的增强能力,包括查询扩展、重排序、多路检索等技术,显著提升了系统的检索精度和回答质量。这些功能通过优化查询理解、文档相关性判断和结果排序来克服基础系统的局限性:

查询扩展器:

from typing import Dict, List
import jieba

class QueryExpander:
    def __init__(self):
        self.synonym_dict = self.load_synonym_dict()
        self.related_words_dict = self.load_related_words_dict()
    
    def load_synonym_dict(self) -> Dict[str, List[str]]:
        """加载同义词字典"""
        return {
            '工作': ['职业', '岗位', '职能', '任职'],
            '方法': ['方式', '途径', '手段', '策略'],
            '问题': ['难题', '挑战', '疑问', '困惑'],
            '发展': ['进步', '提升', '成长', '演变'],
            '重要': ['关键', '核心', '主要', '首要'],
            'Python': ['编程', '代码', '脚本', '开发'],
            '数据库': ['MySQL', 'PostgreSQL', 'MongoDB', 'Oracle'],
            '前端': ['HTML', 'CSS', 'JavaScript', 'UI', '用户体验']
        }
    
    def load_related_words_dict(self) -> Dict[str, List[str]]:
        """加载相关词字典"""
        return {
            '机器学习': ['深度学习', '神经网络', 'AI', '人工智能'],
            '数据分析': ['数据挖掘', '统计', '可视化', '报表'],
            '云计算': ['AWS', 'Azure', '阿里云', '服务器'],
            '网络安全': ['防火墙', '加密', '认证', '漏洞']
        }
    
    def expand_query(self, original_query: str, 
                     expansion_methods: List[str] = None) -> List[str]:
        """查询扩展"""
        if expansion_methods is None:
            expansion_methods = ['synonyms', 'related_words']
        
        expanded_queries = [original_query]  # 包含原始查询
        
        # 分词
        words = jieba.lcut(original_query)
        
        for method in expansion_methods:
            if method == 'synonyms':
                expanded_queries.extend(self.expand_with_synonyms(original_query, words))
            elif method == 'related_words':
                expanded_queries.extend(self.expand_with_related_words(words))
        
        # 去重
        expanded_queries = list(set(expanded_queries))
        
        return expanded_queries
    
    def expand_with_synonyms(self, original_query: str, words: List[str]) -> List[str]:
        """同义词扩展"""
        expanded_queries = []
        
        for word in words:
            if word in self.synonym_dict:
                for synonym in self.synonym_dict[word][:2]:  # 最多2个同义词
                    expanded_query = original_query.replace(word, synonym)
                    if expanded_query != original_query:
                        expanded_queries.append(expanded_query)
        
        return expanded_queries
    
    def expand_with_related_words(self, words: List[str]) -> List[str]:
        """相关词扩展"""
        related_queries = []
        
        for word in words:
            if word in self.related_words_dict:
                for related_word in self.related_words_dict[word]:
                    related_queries.append(f"{word} {related_word}")
        
        return related_queries
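QueryExpander 依赖 jieba 分词,但同义词替换本身与分词器无关。下面是一个不依赖 jieba、直接接收已分好词列表的简化验证(示意代码,词典为假设数据):

```python
from typing import Dict, List

SYNONYMS: Dict[str, List[str]] = {'方法': ['方式', '途径']}

def expand_with_synonyms(query: str, words: List[str],
                         synonyms: Dict[str, List[str]],
                         max_per_word: int = 2) -> List[str]:
    """对查询中的每个词做同义词替换,返回去重后的扩展查询(不含原查询)"""
    expanded: List[str] = []
    for word in words:
        for syn in synonyms.get(word, [])[:max_per_word]:
            candidate = query.replace(word, syn)
            if candidate != query and candidate not in expanded:
                expanded.append(candidate)
    return expanded

queries = expand_with_synonyms("学习方法", ["学习", "方法"], SYNONYMS)
```

扩展出的每个查询都会各自向量化检索,再由后续的融合或重排序阶段统一打分。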

重排序器:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np

class NeuralReranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
    
    def rerank(self, query: str, documents: List[str], top_k: int = 5) -> List[Dict]:
        """重排序文档"""
        if not documents:
            return []
        
        # 构建查询-文档对
        query_doc_pairs = [(query, doc) for doc in documents]
        
        # 批量处理
        scores = self.compute_scores(query_doc_pairs)
        
        # 排序并返回top-k
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        
        results = []
        for i, (doc, score) in enumerate(scored_docs[:top_k]):
            results.append({
                'content': doc,
                'rerank_score': float(score),
                'rank': i + 1
            })
        
        return results
    
    def compute_scores(self, query_doc_pairs: List[tuple]) -> np.ndarray:
        """计算相关性分数"""
        batch_size = 16
        all_scores = []
        
        for i in range(0, len(query_doc_pairs), batch_size):
            batch_pairs = query_doc_pairs[i:i + batch_size]
            
            # 以句对形式编码输入(cross-encoder 的标准用法,分隔符由 tokenizer 自动插入)
            inputs = self.tokenizer(
                [q for q, d in batch_pairs],
                [d for q, d in batch_pairs],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # 模型预测
            with torch.no_grad():
                outputs = self.model(**inputs)
                batch_scores = torch.sigmoid(outputs.logits).cpu().numpy()
            
            all_scores.extend(batch_scores.flatten())
        
        return np.array(all_scores)
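NeuralReranker 依赖 torch 和预训练交叉编码器,不便直接演示;但"逐对打分 → 排序 → 取 top-k"这一接口约定可以先用一个朴素的字符重叠打分器验证(纯演示,overlap_score 为假设的替代打分函数,真实系统应换成上面的模型打分):

```python
from typing import Dict, List

def overlap_score(query: str, doc: str) -> float:
    """以查询字符被文档覆盖的比例近似相关性,仅用于演示接口"""
    q = set(query)
    return len(q & set(doc)) / len(q) if q else 0.0

def rerank(query: str, documents: List[str], top_k: int = 5,
           scorer=overlap_score) -> List[Dict]:
    """与 NeuralReranker.rerank 相同的输出结构:content / rerank_score / rank"""
    scored = sorted(((doc, scorer(query, doc)) for doc in documents),
                    key=lambda x: x[1], reverse=True)
    return [{'content': doc, 'rerank_score': score, 'rank': i + 1}
            for i, (doc, score) in enumerate(scored[:top_k])]

results = rerank("向量数据库", ["向量数据库支持相似度检索", "天气预报"], top_k=2)
```

保持打分函数可替换,便于在开发阶段先用廉价打分器跑通流程,再切换到交叉编码器。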

五、RAG优化策略

5.1 检索优化

检索优化是提升RAG系统性能的关键环节,通过多策略融合、智能排序和上下文自适应等技术来改善检索的准确性和效率。这些优化策略解决了单一检索方法的局限性,提高了系统的整体表现:

多路检索融合:

class MultiModalRetrieval:
    def __init__(self, vector_store, keyword_searcher, embedding_service):
        self.vector_store = vector_store
        self.keyword_searcher = keyword_searcher
        self.embedding_service = embedding_service
        self.fusion_methods = ['rrf', 'weighted_sum']
    
    def fused_search(self, query, method='rrf', weights=None):
        """融合多种检索结果"""
        # 向量检索(将扁平结果转换为带 id 的文档列表,便于按排名融合)
        query_embedding = self.embedding_service.embed_query(query)
        raw_results = self.vector_store.search([query_embedding], n_results=20)
        vector_results = [{'id': doc_id} for doc_id in raw_results['ids']]
        
        # 关键词检索
        keyword_results = self.keyword_searcher.search(query, n_results=20)
        
        # 融合结果
        if method == 'rrf':
            return self.reciprocal_rank_fusion(vector_results, keyword_results)
        elif method == 'weighted_sum':
            return self.weighted_sum_fusion(vector_results, keyword_results, weights)
        else:
            return self.reciprocal_rank_fusion(vector_results, keyword_results)
    
    def reciprocal_rank_fusion(self, vector_results, keyword_results, k=60):
        """倒数排名融合"""
        fused_scores = {}
        
        # 向量检索结果
        for rank, doc in enumerate(vector_results):
            doc_id = doc['id']
            score = 1.0 / (k + rank + 1)
            fused_scores[doc_id] = fused_scores.get(doc_id, 0) + score
        
        # 关键词检索结果
        for rank, doc in enumerate(keyword_results):
            doc_id = doc['id']
            score = 1.0 / (k + rank + 1)
            fused_scores[doc_id] = fused_scores.get(doc_id, 0) + score
        
        # 按分数排序
        sorted_docs = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
        
        return [doc for doc, score in sorted_docs[:10]]
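倒数排名融合也可以抽成纯函数单独验证:RRF 只依赖文档在各路结果中的排名,不要求两路的原始分数可比,这正是它适合融合异构检索器的原因(示意实现):

```python
from typing import Dict, List

def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[str]:
    """对多路检索的有序 doc_id 列表做 RRF 融合,返回融合后的降序 id 列表"""
    fused: Dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking):
            # 排名越靠前,1/(k+rank+1) 越大;k 用于平滑头部差距
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
    return [doc_id for doc_id, _ in
            sorted(fused.items(), key=lambda x: x[1], reverse=True)]

# doc_a 在两路中都排第一,融合后应居首;doc_b 两路都出现,优于只出现一次的文档
order = reciprocal_rank_fusion([
    ['doc_a', 'doc_b', 'doc_c'],
    ['doc_a', 'doc_d', 'doc_b'],
])
```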

5.2 上下文优化

上下文优化专注于如何有效地组织和利用检索到的文档信息,确保生成的回答既准确又相关。通过智能的文档选择、内容截取和重要性评估,系统可以在有限的上下文窗口内最大化信息密度和回答质量:

自适应上下文选择:

import re

class AdaptiveContextManager:
    def __init__(self, max_context_length=4000):
        self.max_length = max_context_length
        self.importance_scorer = ImportanceScorer()
    
    def select_context(self, query, documents, max_docs=5):
        """自适应选择上下文"""
        # 计算文档重要性
        scored_docs = []
        for doc in documents:
            importance_score = self.importance_scorer.score(query, doc)
            scored_docs.append((doc, importance_score))
        
        # 按重要性排序
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        
        # 选择文档直到达到长度限制
        selected_docs = []
        current_length = 0
        
        for doc, score in scored_docs:
            doc_length = len(doc['content'])
            
            if current_length + doc_length <= self.max_length and len(selected_docs) < max_docs:
                selected_docs.append(doc)
                current_length += doc_length
            elif current_length + doc_length > self.max_length and selected_docs:
                # 尝试截取文档
                remaining_space = self.max_length - current_length
                if remaining_space > 100:
                    truncated_doc = self.truncate_document(doc, remaining_space)
                    selected_docs.append(truncated_doc)
                break
        
        return selected_docs
    
    def truncate_document(self, doc, max_length):
        """智能截取文档"""
        content = doc['content']
        
        # 尝试在句子边界截取
        sentences = re.split(r'[。!?]', content)
        truncated_content = ""
        
        for sentence in sentences:
            if len(truncated_content + sentence) <= max_length:
                truncated_content += sentence + "。"
            else:
                break
        
        if not truncated_content:
            # 如果没有完整的句子,直接截取
            truncated_content = content[:max_length-3] + "..."
        
        return {
            **doc,
            'content': truncated_content,
            'truncated': True
        }

class ImportanceScorer:
    def __init__(self):
        self.weight_similarity = 0.4
        self.weight_recency = 0.2
        self.weight_length = 0.2
        self.weight_structure = 0.2
    
    def score(self, query, document):
        """计算文档重要性分数"""
        # 相似度分数
        similarity_score = document.get('similarity_score', 0)
        
        # 时效性分数(基于元数据)
        recency_score = self.calculate_recency_score(document)
        
        # 长度分数(偏好适中长度)
        length_score = self.calculate_length_score(document)
        
        # 结构分数(标题、章节等)
        structure_score = self.calculate_structure_score(document)
        
        # 加权求和
        total_score = (
            self.weight_similarity * similarity_score +
            self.weight_recency * recency_score +
            self.weight_length * length_score +
            self.weight_structure * structure_score
        )
        
        return total_score
    
    def calculate_recency_score(self, document):
        """计算时效性分数"""
        metadata = document.get('metadata', {})
        
        # 检查是否有日期信息
        if 'updated_date' in metadata:
            # 这里可以实现更复杂的日期计算
            return 0.8
        elif 'created_at' in metadata:
            return 0.6
        else:
            return 0.3  # 没有时间信息的文档分数较低
    
    def calculate_length_score(self, document):
        """计算长度分数"""
        content_length = len(document['content'])
        
        if content_length < 100:
            return 0.3  # 太短
        elif content_length < 500:
            return 0.7  # 适中
        elif content_length < 1500:
            return 1.0  # 理想
        else:
            return 0.6  # 太长
    
    def calculate_structure_score(self, document):
        """计算结构分数"""
        metadata = document.get('metadata', {})
        score = 0.5  # 基础分数
        
        # 有标题加分
        if 'title' in metadata:
            score += 0.2
        
        # 是主要章节加分
        if metadata.get('section_level', 999) <= 2:
            score += 0.2
        
        # 文件类型加分
        file_type = metadata.get('file_type', '')
        if file_type in ['markdown', 'pdf']:
            score += 0.1
        
        return min(score, 1.0)
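ImportanceScorer 的核心是各分项分数的加权线性组合(权重之和为 1)。下面用假设的分项数值验证权重对排序的影响:两篇文档仅时效性分数不同,时效性更高者总分应更高:

```python
def weighted_score(parts: dict, weights: dict) -> float:
    """各分项分数按权重线性组合,权重之和应为 1"""
    return sum(weights[k] * parts[k] for k in weights)

# 与 ImportanceScorer 一致的权重配置
weights = {'similarity': 0.4, 'recency': 0.2, 'length': 0.2, 'structure': 0.2}

# 两篇文档仅 recency 分项不同(数值为演示用假设)
doc_recent = {'similarity': 0.8, 'recency': 0.8, 'length': 1.0, 'structure': 0.7}
doc_stale = {'similarity': 0.8, 'recency': 0.3, 'length': 1.0, 'structure': 0.7}

score_recent = weighted_score(doc_recent, weights)
score_stale = weighted_score(doc_stale, weights)
```

调整权重即可改变系统偏好,例如新闻类知识库可以调高 recency 的权重。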

六、RAG应用场景

6.1 企业知识库问答

企业知识库问答是RAG技术的重要应用场景,它能够帮助企业快速构建智能化的内部知识服务系统。这类系统需要处理权限控制、数据安全、多部门协作等复杂需求,为员工提供准确、及时的知识支持:

企业级RAG系统:

import re
from enum import Enum
from typing import Dict, List, Optional
from datetime import datetime

class AccessLevel(Enum):
    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"

class EnterpriseRAGSystem:
    def __init__(self, config):
        self.config = config
        self.user_manager = UserManager()
        self.doc_indexer = EnterpriseDocumentIndexer()
        self.access_controller = AccessController()
        self.audit_logger = AuditLogger()
        self.base_rag = BasicRAGSystem(config)
    
    def query_with_permission(self, user_id: str, query: str, 
                             access_level: AccessLevel = AccessLevel.INTERNAL) -> Dict:
        """带权限控制的查询"""
        # 验证用户权限
        user_info = self.user_manager.get_user(user_id)
        if not user_info:
            return {
                'answer': '用户不存在',
                'error': 'user_not_found',
                'question': query
            }
        
        if not self.access_controller.can_access(user_info, access_level):
            return {
                'answer': '您没有权限访问该级别的信息',
                'error': 'permission_denied',
                'question': query
            }
        
        # 记录查询
        self.audit_logger.log_query(user_id, query, access_level)
        
        try:
            # 执行检索(带权限过滤)
            filtered_docs = self.doc_indexer.search_with_permission(
                query, user_info['department'], access_level
            )
            
            if not filtered_docs:
                return {
                    'answer': '未找到相关信息或您没有访问权限',
                    'sources': [],
                    'success': True
                }
            
            # 生成回答
            context = self.build_secure_context(filtered_docs)
            answer = self.base_rag.generate_answer(query, context)
            
            # 记录回答
            self.audit_logger.log_answer(user_id, query, answer, filtered_docs)
            
            return {
                'answer': answer,
                'sources': self.format_enterprise_sources(filtered_docs),
                'success': True
            }
            
        except Exception as e:
            self.audit_logger.log_error(user_id, query, str(e))
            return {
                'answer': '处理您的查询时遇到了问题,请联系管理员',
                'error': str(e),
                'success': False
            }
    
    def build_secure_context(self, documents):
        """构建安全的上下文"""
        context_parts = []
        for i, doc in enumerate(documents):
            # 移除敏感信息
            safe_content = self.sanitize_content(doc['content'])
            context_parts.append(f"文档{i+1}:{safe_content[:500]}")
        
        return '\n\n'.join(context_parts)
    
    def sanitize_content(self, content):
        """清理敏感内容"""
        # 这里可以实现更复杂的内容过滤逻辑
        sensitive_patterns = [
            r'\b\d{4}-\d{4}-\d{4}-\d{4}\b',  # 信用卡号
            r'\b\d{3}-\d{2}-\d{4}\b',       # SSN
            r'password\s*[:=]\s*\w+',        # 密码
        ]
        
        for pattern in sensitive_patterns:
            content = re.sub(pattern, '[已隐藏]', content, flags=re.IGNORECASE)
        
        return content

class UserManager:
    def __init__(self):
        self.users = {
            'user001': {
                'name': '张三',
                'department': '技术部',
                'role': 'developer',
                'access_level': AccessLevel.CONFIDENTIAL,
                'roles': ['developer']
            },
            'user002': {
                'name': '李四',
                'department': '人事部',
                'role': 'hr_manager',
                'access_level': AccessLevel.CONFIDENTIAL,
                'roles': ['hr_manager']
            }
        }
    
    def get_user(self, user_id: str) -> Optional[Dict]:
        return self.users.get(user_id)
    
    def has_permission(self, user_id: str, required_level: AccessLevel) -> bool:
        user = self.get_user(user_id)
        if not user:
            return False
        
        level_hierarchy = {
            AccessLevel.PUBLIC: 0,
            AccessLevel.INTERNAL: 1,
            AccessLevel.CONFIDENTIAL: 2,
            AccessLevel.SECRET: 3
        }
        
        user_level = level_hierarchy.get(user['access_level'], 0)
        required_level_value = level_hierarchy.get(required_level, 0)
        
        return user_level >= required_level_value

class AccessController:
    def can_access(self, user_info: Dict, required_level: AccessLevel) -> bool:
        user_level = user_info.get('access_level', AccessLevel.PUBLIC)
        
        level_hierarchy = {
            AccessLevel.PUBLIC: 0,
            AccessLevel.INTERNAL: 1,
            AccessLevel.CONFIDENTIAL: 2,
            AccessLevel.SECRET: 3
        }
        
        return level_hierarchy.get(user_level, 0) >= level_hierarchy.get(required_level, 0)

class AuditLogger:
    def __init__(self):
        self.logs = []
    
    def log_query(self, user_id: str, query: str, access_level: AccessLevel):
        log_entry = {
            'timestamp': datetime.now(),
            'type': 'query',
            'user_id': user_id,
            'query': query,
            'access_level': access_level.value
        }
        self.logs.append(log_entry)
    
    def log_answer(self, user_id: str, query: str, answer: str, sources: List[Dict]):
        log_entry = {
            'timestamp': datetime.now(),
            'type': 'answer',
            'user_id': user_id,
            'query': query,
            'answer': answer[:100],  # 只记录前100个字符
            'source_count': len(sources)
        }
        self.logs.append(log_entry)
    
    def log_error(self, user_id: str, query: str, error: str):
        log_entry = {
            'timestamp': datetime.now(),
            'type': 'error',
            'user_id': user_id,
            'query': query,
            'error': error
        }
        self.logs.append(log_entry)
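sanitize_content 中的正则脱敏逻辑可以单独跑通验证。下面的模式与上文一致,示例文本为虚构数据:

```python
import re

SENSITIVE_PATTERNS = [
    r'\b\d{4}-\d{4}-\d{4}-\d{4}\b',  # 信用卡号
    r'\b\d{3}-\d{2}-\d{4}\b',        # SSN
    r'password\s*[:=]\s*\w+',         # 密码
]

def sanitize(content: str) -> str:
    """将命中敏感模式的片段替换为占位符"""
    for pattern in SENSITIVE_PATTERNS:
        content = re.sub(pattern, '[已隐藏]', content, flags=re.IGNORECASE)
    return content

text = "卡号 1234-5678-9012-3456,登录信息 password: hunter2"
clean = sanitize(text)
```

注意脱敏应发生在上下文进入 LLM 之前,这样即便模型输出被记录,敏感原文也不会泄露。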

6.2 智能客服系统

智能客服系统是RAG技术的另一个重要应用领域,它通过结合企业产品知识库和对话管理,为用户提供24/7的智能咨询服务。这类系统需要处理意图识别、情感分析、多轮对话等复杂场景,提供个性化、精准的客服体验:

客服RAG实现:

import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional

class CustomerServiceRAG:
    def __init__(self, config):
        self.config = config
        self.conversation_manager = ConversationManager()
        self.intent_classifier = IntentClassifier()
        self.emotion_analyzer = EmotionAnalyzer()
        self.response_generator = ResponseGenerator()
        self.knowledge_base = ProductKnowledgeBase()
        
    def handle_customer_query(self, session_id: str, user_message: str, 
                             user_context: Dict = None) -> Dict:
        """处理客户查询"""
        # 获取对话历史
        conversation_history = self.conversation_manager.get_history(session_id)
        
        # 意图识别
        intent = self.intent_classifier.classify(user_message, conversation_history)
        
        # 情感分析
        emotion = self.emotion_analyzer.analyze(user_message)
        
        # 获取用户上下文
        if user_context is None:
            user_context = {}
        
        # 根据意图处理查询
        if intent == 'complaint':
            return self.handle_complaint(session_id, user_message, emotion, user_context)
        elif intent == 'technical_support':
            return self.handle_technical_support(session_id, user_message, conversation_history)
        elif intent == 'product_inquiry':
            return self.handle_product_inquiry(session_id, user_message, user_context)
        elif intent == 'general_inquiry':
            return self.handle_general_inquiry(session_id, user_message, conversation_history)
        else:
            return self.handle_unknown_intent(session_id, user_message)
    
    def handle_technical_support(self, session_id: str, query: str, 
                               history: List[Dict]) -> Dict:
        """处理技术支持查询"""
        # 检索相关知识
        relevant_docs = self.knowledge_base.search_technical_docs(query)
        
        # 构建上下文
        context = self.build_technical_context(query, relevant_docs, history)
        
        # 生成技术支持回答
        response = self.response_generator.generate_technical_response(
            query, context, history
        )
        
        # 保存对话
        self.conversation_manager.add_message(session_id, query, response, 'technical_support')
        
        # 生成后续建议
        suggestions = self.generate_technical_suggestions(query, response, relevant_docs)
        
        return {
            'response': response,
            'intent': 'technical_support',
            'suggestions': suggestions,
            'escalation_needed': self.should_escalate(response, relevant_docs),
            'sources': self.format_sources(relevant_docs)
        }
    
    def handle_complaint(self, session_id: str, complaint: str, 
                        emotion: str, user_context: Dict) -> Dict:
        """处理投诉"""
        # 立即道歉和共情
        apology_response = self.generate_empathetic_response(complaint, emotion)
        
        # 识别投诉类型
        complaint_type = self.classify_complaint_type(complaint)
        
        # 检索相关政策和解决方案
        solutions = self.knowledge_base.search_complaint_solutions(complaint_type)
        
        # 生成处理方案
        resolution = self.generate_complaint_resolution(complaint, complaint_type, solutions)
        
        # 创建投诉记录
        complaint_id = self.create_complaint_record(
            session_id, complaint, complaint_type, emotion, user_context
        )
        
        # 生成完整回复
        full_response = f"{apology_response}\n\n{resolution}"
        
        # 保存对话
        self.conversation_manager.add_message(session_id, complaint, full_response, 'complaint')
        
        return {
            'response': full_response,
            'intent': 'complaint',
            'complaint_id': complaint_id,
            'complaint_type': complaint_type,
            'escalation_needed': emotion in ['angry', 'frustrated'],
            'follow_up_required': True,
            'estimated_resolution_time': self.get_resolution_time(complaint_type)
        }
    
    def handle_product_inquiry(self, session_id: str, query: str, 
                              user_context: Dict) -> Dict:
        """处理产品咨询"""
        # 识别产品
        products = self.extract_products_from_query(query)
        
        # 检索产品信息
        product_info = []
        for product in products:
            info = self.knowledge_base.get_product_info(product)
            if info:
                product_info.append(info)
        
        # 构建产品上下文
        context = self.build_product_context(query, product_info, user_context)
        
        # 生成产品咨询回答
        response = self.response_generator.generate_product_response(
            query, context, products
        )
        
        # 生成相关产品推荐
        recommendations = self.generate_product_recommendations(products, user_context)
        
        # 保存对话
        self.conversation_manager.add_message(session_id, query, response, 'product_inquiry')
        
        return {
            'response': response,
            'intent': 'product_inquiry',
            'products_found': products,
            'recommendations': recommendations,
            'sources': self.format_product_sources(product_info)
        }
    
    def generate_empathetic_response(self, complaint: str, emotion: str) -> str:
        """生成共情回应"""
        empathetic_templates = {
            'angry': "我理解您的愤怒,给您带来这样的体验我深感抱歉。我会认真对待您的问题并尽快帮助您解决。",
            'frustrated': "我理解您的沮丧,这种情况确实很令人困扰。让我来帮助您解决这个问题。",
            'disappointed': "很抱歉让您感到失望,我们重视每一位客户的体验。请告诉我具体情况,我会尽力协助您。",
            'neutral': "感谢您的反馈,我理解您的顾虑。我会认真处理您提到的问题。"
        }
        
        base_response = empathetic_templates.get(emotion, empathetic_templates['neutral'])
        return base_response

class ConversationManager:
    def __init__(self):
        self.conversations = {}
        self.max_history_length = 20
    
    def get_history(self, session_id: str) -> List[Dict]:
        """获取对话历史"""
        if session_id not in self.conversations:
            self.conversations[session_id] = {
                'session_id': session_id,
                'created_at': datetime.now(),
                'messages': []
            }
        
        return self.conversations[session_id]['messages']
    
    def add_message(self, session_id: str, user_message: str, 
                   bot_response: str, intent: str):
        """添加对话消息"""
        if session_id not in self.conversations:
            self.get_history(session_id)  # 创建新会话
        
        message = {
            'timestamp': datetime.now(),
            'user_message': user_message,
            'bot_response': bot_response,
            'intent': intent
        }
        
        self.conversations[session_id]['messages'].append(message)
        
        # 限制历史长度
        if len(self.conversations[session_id]['messages']) > self.max_history_length:
            self.conversations[session_id]['messages'] = \
                self.conversations[session_id]['messages'][-self.max_history_length:]
    
    def get_conversation_summary(self, session_id: str) -> Dict:
        """获取对话摘要"""
        if session_id not in self.conversations:
            return {}
        
        conversation = self.conversations[session_id]
        messages = conversation['messages']
        
        return {
            'session_id': session_id,
            'duration': (datetime.now() - conversation['created_at']).total_seconds(),
            'message_count': len(messages),
            'intents': list(set(msg['intent'] for msg in messages)),
            'last_activity': messages[-1]['timestamp'] if messages else None
        }
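上面的 ConversationManager 通过切片手动截断历史;一个更简洁的等价做法是用标准库 collections.deque 的定长队列,写满后自动淘汰最旧消息。以下是一个最小示意(max_history_length、消息字段均沿用上文约定):

```python
from collections import deque

# 定长deque:超过maxlen时自动丢弃最旧元素,无需手动切片
max_history_length = 20
messages = deque(maxlen=max_history_length)

for i in range(25):  # 模拟写入25条消息
    messages.append({'user_message': f'消息{i}', 'intent': 'general_inquiry'})

print(len(messages))                 # 长度被固定在20
print(messages[0]['user_message'])   # 最旧的一条是"消息5"
```

如果需要 get_conversation_summary 那样的整体遍历,deque 同样支持迭代和索引,行为与列表一致。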

class IntentClassifier:
    def __init__(self):
        self.intent_patterns = {
            'complaint': ['投诉', '抱怨', '不满', '问题', '错误', '故障', '糟糕'],
            'technical_support': ['技术', '系统', '软件', '硬件', '网络', '连接', '设置'],
            'product_inquiry': ['产品', '价格', '功能', '规格', '购买', '订单'],
            'general_inquiry': ['咨询', '询问', '了解', '信息', '帮助'],
            'billing': ['账单', '费用', '付款', '退款', '发票'],
            'shipping': ['配送', '快递', '物流', '发货', '收货']
        }
    
    def classify(self, message: str, history: List[Dict] = None) -> str:
        """分类意图"""
        message_lower = message.lower()
        
        # 基于关键词分类
        for intent, keywords in self.intent_patterns.items():
            for keyword in keywords:
                if keyword in message_lower:
                    return intent
        
        # 基于历史对话上下文
        if history and len(history) > 0:
            last_intent = history[-1]['intent']
            if last_intent in ['complaint', 'technical_support']:
                # 延续之前的意图
                return last_intent
        
        # 默认意图
        return 'general_inquiry'
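IntentClassifier 的关键词匹配逻辑可以用下面几行独立代码验证(关键词表仅节选上文的一部分作示意):

```python
# 节选自上文intent_patterns的两类意图,仅作演示
intent_patterns = {
    'complaint': ['投诉', '抱怨', '故障'],
    'product_inquiry': ['产品', '价格', '订单'],
}

def classify(message: str, default: str = 'general_inquiry') -> str:
    """按关键词命中顺序返回第一个匹配的意图,未命中返回默认意图"""
    for intent, keywords in intent_patterns.items():
        if any(kw in message for kw in keywords):
            return intent
    return default

print(classify('这个产品价格是多少?'))   # product_inquiry
print(classify('你好'))                   # general_inquiry
```

这种纯关键词方案实现简单但召回有限,生产环境通常会叠加基于模型的意图分类作为兜底。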

七、大模型开发实战笔记

7.1 Deepseek模型集成

Deepseek是一个优秀的国产大语言模型,具有强大的中文理解能力和合理的价格优势。本节介绍如何将Deepseek模型集成到RAG系统中,包括API调用封装、提示管理、流式处理等完整功能,以及成本控制和性能优化策略:

Deepseek API封装:

import openai
import time
from typing import List, Dict, Optional, Generator
import json

class DeepSeekLLMService:
    def __init__(self, api_key: str, base_url: str = "https://api.deepseek.com"):
        self.client = openai.OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        self.model = "deepseek-chat"
        self.request_count = 0
        self.total_tokens = 0
        self.total_cost = 0.0
    
    def generate_response(self, prompt: str, 
                         temperature: float = 0.7,
                         max_tokens: int = 1000,
                         top_p: float = 0.9) -> str:
        """生成回答"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p
            )
            
            # 更新统计信息
            self.request_count += 1
            self.total_tokens += response.usage.total_tokens
            self.total_cost += self.calculate_cost(response.usage)
            
            return response.choices[0].message.content
            
        except Exception as e:
            raise Exception(f"Deepseek API调用失败: {e}")
    
    def chat_with_history(self, messages: List[Dict], 
                         temperature: float = 0.7,
                         max_tokens: int = 1000) -> str:
        """带历史记录的对话"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            
            self.request_count += 1
            self.total_tokens += response.usage.total_tokens
            self.total_cost += self.calculate_cost(response.usage)
            
            return response.choices[0].message.content
            
        except Exception as e:
            raise Exception(f"Deepseek API调用失败: {e}")
    
    def stream_response(self, prompt: str,
                       temperature: float = 0.7,
                       max_tokens: int = 1000) -> Generator[str, None, None]:
        """流式生成回答"""
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )
            
            full_content = ""
            for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_content += content
                    yield content
            
            # 更新统计(流式响应拿不到usage字段,只能粗略估算)
            self.request_count += 1
            # 中文文本无法按空格分词,这里按字符数粗略折算token,并保持计数为整数
            self.total_tokens += int(len(full_content) * 0.6)
            
        except Exception as e:
            raise Exception(f"Deepseek流式调用失败: {e}")
    
    def calculate_cost(self, usage) -> float:
        """计算费用(基于Deepseek定价)"""
        # Deepseek定价示例(实际价格请查看官方文档)
        input_price = 0.14 / 1000000  # 输入价格 $0.14/1M tokens
        output_price = 0.28 / 1000000  # 输出价格 $0.28/1M tokens
        
        return (usage.prompt_tokens * input_price + 
                usage.completion_tokens * output_price)
    
    def get_stats(self) -> Dict:
        """获取使用统计"""
        return {
            'request_count': self.request_count,
            'total_tokens': self.total_tokens,
            'total_cost': self.total_cost,
            'model': self.model,
            'avg_tokens_per_request': self.total_tokens / max(self.request_count, 1)
        }
    
    def reset_stats(self):
        """重置统计"""
        self.request_count = 0
        self.total_tokens = 0
        self.total_cost = 0.0
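calculate_cost 的计价逻辑可以独立验证。下面的单价沿用上文代码中的示意值(实际价格以Deepseek官方文档为准):

```python
# 按每token单价计费(示意价格,与上文calculate_cost一致)
INPUT_PRICE = 0.14 / 1_000_000    # 输入:$0.14 / 1M tokens
OUTPUT_PRICE = 0.28 / 1_000_000   # 输出:$0.28 / 1M tokens

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> float:
    """估算单次调用的美元费用:输入输出分别计价后求和"""
    return prompt_tokens * INPUT_PRICE + completion_tokens * OUTPUT_PRICE

# 1000个输入token + 500个输出token
cost = estimate_cost(1000, 500)
print(f"${cost:.6f}")  # $0.000280
```

可以看到单次调用成本极低,但在高并发场景下累计费用仍然可观,这也是后文缓存和速率限制的动机之一。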

class DeepSeekPromptManager:
    def __init__(self):
        self.templates = self.load_templates()
    
    def load_templates(self) -> Dict:
        """加载提示模板"""
        return {
            'rag_qa': """你是一个专业的问答助手。请基于提供的上下文信息准确回答用户问题。

上下文信息:
{context}

用户问题:{question}

回答要求:
1. 紧密基于上下文信息
2. 如果上下文中没有相关信息,请明确说明
3. 回答简洁明了,条理清晰
4. 引用具体的信息来源

回答:""",
            
            'technical_support': """你是一个专业的技术支持助手。请基于以下技术文档和故障信息帮助用户解决问题。

技术文档:
{context}

用户问题描述:{question}

请提供:
1. 问题的可能原因分析
2. 具体的解决步骤
3. 预防措施
4. 如果问题复杂,建议的联系支持方式

技术支持回答:""",
            
            'customer_service': """你是一个友好的客服助手。请基于产品信息和用户查询提供专业、友好的回复。

产品信息:
{context}

用户咨询:{question}

回复要求:
1. 语气友好、专业
2. 直接回答用户问题
3. 提供有用的额外信息
4. 保持简洁易懂

客服回复:""",
            
            'summarization': """请对以下内容进行总结,提炼关键信息。

内容:
{content}

总结要求:
1. 提取主要观点和关键信息
2. 保持逻辑清晰
3. 控制在200字以内

总结:"""
        }
    
    def build_prompt(self, template_name: str, **kwargs) -> str:
        """构建提示"""
        if template_name not in self.templates:
            raise ValueError(f"未知的模板名称: {template_name}")
        
        template = self.templates[template_name]
        return template.format(**kwargs)
    
    def add_template(self, name: str, template: str):
        """添加自定义模板"""
        self.templates[name] = template
    
    def list_templates(self) -> List[str]:
        """列出所有模板"""
        return list(self.templates.keys())
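build_prompt 的核心就是 Python 的 str.format 占位符替换,可以用一个独立的小例子验证其行为(模板为上文rag_qa模板的节选):

```python
# 模板节选:{context}和{question}是待填充的占位符
templates = {
    'rag_qa': "上下文信息:\n{context}\n\n用户问题:{question}\n\n回答:"
}

def build_prompt(template_name: str, **kwargs) -> str:
    """按名称取模板并用关键字参数填充占位符"""
    if template_name not in templates:
        raise ValueError(f"未知的模板名称: {template_name}")
    return templates[template_name].format(**kwargs)

prompt = build_prompt('rag_qa', context='RAG结合检索与生成', question='什么是RAG?')
print(prompt)
```

注意:如果模板正文中需要出现字面量花括号(例如JSON示例),要写成 {{ 和 }},否则 format 会把它当作占位符而抛出 KeyError。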

7.2 模型性能优化

在实际生产环境中,模型性能优化是确保系统稳定运行的关键。通过智能缓存、速率限制、重试机制等技术,可以显著提升系统的响应速度和稳定性,同时有效控制API调用成本:

智能缓存系统:

import hashlib
import json
import pickle
import time
from typing import Dict, Any, Optional
from dataclasses import dataclass

@dataclass
class CacheConfig:
    max_size: int = 1000
    ttl: int = 3600  # 缓存时间(秒)
    persist_to_file: bool = True
    cache_file: str = "llm_cache.pkl"

class SmartCache:
    def __init__(self, config: CacheConfig):
        self.config = config
        self.cache = {}
        self.access_times = {}
        self.hit_count = 0
        self.miss_count = 0
        
        # 加载持久化缓存
        if config.persist_to_file:
            self.load_cache()
    
    def get(self, prompt: str, model_params: Dict = None) -> Optional[str]:
        """获取缓存结果"""
        cache_key = self.generate_cache_key(prompt, model_params)
        
        if cache_key in self.cache:
            cached_item = self.cache[cache_key]
            
            # 检查是否过期
            if time.time() - cached_item['timestamp'] < self.config.ttl:
                self.access_times[cache_key] = time.time()
                self.hit_count += 1
                return cached_item['response']
            else:
                # 删除过期项
                del self.cache[cache_key]
                if cache_key in self.access_times:
                    del self.access_times[cache_key]
        
        self.miss_count += 1
        return None
    
    def put(self, prompt: str, response: str, model_params: Dict = None):
        """存储到缓存"""
        cache_key = self.generate_cache_key(prompt, model_params)
        
        # 如果缓存已满,删除最少使用的项
        if len(self.cache) >= self.config.max_size:
            self.evict_lru()
        
        self.cache[cache_key] = {
            'response': response,
            'timestamp': time.time()
        }
        self.access_times[cache_key] = time.time()
        
        # 持久化
        if self.config.persist_to_file:
            self.save_cache()
    
    def generate_cache_key(self, prompt: str, model_params: Dict = None) -> str:
        """生成缓存键"""
        key_data = {
            'prompt': prompt,
            'params': model_params or {}
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.md5(key_str.encode()).hexdigest()
    
    def evict_lru(self):
        """删除最少使用的缓存项"""
        if not self.access_times:
            return
        
        lru_key = min(self.access_times.keys(), key=lambda k: self.access_times[k])
        del self.cache[lru_key]
        del self.access_times[lru_key]
    
    def save_cache(self):
        """保存缓存到文件"""
        try:
            with open(self.config.cache_file, 'wb') as f:
                pickle.dump(self.cache, f)
        except Exception as e:
            print(f"保存缓存失败: {e}")
    
    def load_cache(self):
        """从文件加载缓存"""
        try:
            with open(self.config.cache_file, 'rb') as f:
                loaded_cache = pickle.load(f)
                
            # 过滤过期项
            current_time = time.time()
            for key, item in loaded_cache.items():
                if current_time - item['timestamp'] < self.config.ttl:
                    self.cache[key] = item
                    self.access_times[key] = item['timestamp']
            
            print(f"从文件加载了 {len(self.cache)} 个有效缓存项")
        except FileNotFoundError:
            print("缓存文件不存在,将创建新缓存")
        except Exception as e:
            print(f"加载缓存失败: {e}")
    
    def get_stats(self) -> Dict:
        """获取缓存统计"""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / total_requests if total_requests > 0 else 0
        
        return {
            'cache_size': len(self.cache),
            'hit_count': self.hit_count,
            'miss_count': self.miss_count,
            'hit_rate': hit_rate,
            'max_size': self.config.max_size
        }
    
    def clear(self):
        """清空缓存"""
        self.cache.clear()
        self.access_times.clear()
        self.hit_count = 0
        self.miss_count = 0
        
        if self.config.persist_to_file:
            self.save_cache()
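SmartCache 的缓存键由 prompt 与调用参数经 JSON 规范化序列化后取 MD5 得到:输入完全相同则键必然相同,任一参数不同则键不同。下面是同一思路的独立小例:

```python
import hashlib
import json

def cache_key(prompt: str, params: dict = None) -> str:
    """与SmartCache.generate_cache_key同思路:sort_keys保证字典序稳定,再取MD5"""
    key_str = json.dumps({'prompt': prompt, 'params': params or {}},
                         sort_keys=True, ensure_ascii=False)
    return hashlib.md5(key_str.encode('utf-8')).hexdigest()

k1 = cache_key('什么是RAG?', {'temperature': 0.7})
k2 = cache_key('什么是RAG?', {'temperature': 0.7})
k3 = cache_key('什么是RAG?', {'temperature': 0.2})
print(k1 == k2)  # True:相同输入得到相同键,缓存才能命中
print(k1 == k3)  # False:温度不同,视为不同请求
```

sort_keys=True 是关键:没有它,同一组参数以不同顺序传入会生成不同的键,导致缓存永远无法命中。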

class OptimizedLLMService:
    def __init__(self, llm_service: DeepSeekLLMService, cache_config: CacheConfig = None):
        self.llm_service = llm_service
        self.cache = SmartCache(cache_config or CacheConfig())
        self.rate_limiter = RateLimiter()
    
    def generate_with_cache(self, prompt: str, **kwargs) -> str:
        """带缓存的生成"""
        # 检查缓存
        cached_response = self.cache.get(prompt, kwargs)
        if cached_response:
            return cached_response
        
        # 速率限制检查
        self.rate_limiter.wait_if_needed()
        
        # 调用实际API
        response = self.llm_service.generate_response(prompt, **kwargs)
        
        # 存储到缓存
        self.cache.put(prompt, response, kwargs)
        
        return response
    
    def generate_with_fallback(self, prompt: str, 
                             fallback_prompts: List[str] = None,
                             **kwargs) -> str:
        """带回退机制的生成"""
        try:
            # 尝试原始提示
            return self.generate_with_cache(prompt, **kwargs)
        except Exception as e:
            print(f"主提示失败: {e}")
            
            # 尝试备用提示
            if fallback_prompts:
                for fallback_prompt in fallback_prompts:
                    try:
                        return self.generate_with_cache(fallback_prompt, **kwargs)
                    except Exception as fallback_error:
                        print(f"备用提示失败: {fallback_error}")
                        continue
            
            # 所有提示都失败,返回错误信息
            return "抱歉,服务暂时不可用,请稍后再试。"
    
    def get_combined_stats(self) -> Dict:
        """获取综合统计"""
        return {
            'llm_stats': self.llm_service.get_stats(),
            'cache_stats': self.cache.get_stats(),
            'rate_limiter_stats': self.rate_limiter.get_stats()
        }

class RateLimiter:
    def __init__(self, max_requests_per_minute: int = 60):
        self.max_requests = max_requests_per_minute
        self.requests = []
    
    def wait_if_needed(self):
        """如果需要,等待以符合速率限制"""
        current_time = time.time()
        
        # 清理超过1分钟的请求记录
        self.requests = [req_time for req_time in self.requests 
                         if current_time - req_time < 60]
        
        # 如果达到限制,等待
        if len(self.requests) >= self.max_requests:
            oldest_request = min(self.requests)
            wait_time = 60 - (current_time - oldest_request)
            if wait_time > 0:
                print(f"达到速率限制,等待 {wait_time:.1f} 秒")
                time.sleep(wait_time)
        
        # 记录当前请求
        self.requests.append(current_time)
    
    def get_stats(self) -> Dict:
        """获取速率限制统计"""
        current_time = time.time()
        recent_requests = [req_time for req_time in self.requests 
                          if current_time - req_time < 60]
        
        return {
            'max_requests_per_minute': self.max_requests,
            'recent_requests_count': len(recent_requests),
            'remaining_requests': self.max_requests - len(recent_requests)
        }
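RateLimiter 采用的是滑动窗口思路:只统计最近60秒内的请求数。把当前时间抽象成参数后,这一配额逻辑可以离线验证(函数名为示意,非上文类的方法):

```python
def remaining_quota(request_times, now, max_per_minute=60, window=60):
    """滑动窗口剩余配额:只统计now之前window秒内的请求"""
    recent = [t for t in request_times if now - t < window]
    return max_per_minute - len(recent)

# 模拟:从第0秒起每秒发起一个请求,共70个
times = list(range(70))
print(remaining_quota(times, now=70))  # 窗口内是第11~69秒的59个请求,剩余配额1
```

将时钟作为显式参数而不是在函数内调用 time.time(),是让限流逻辑可单元测试的常见做法。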

八、完整项目案例

8.1 企业级RAG问答系统

系统架构和实现:

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from pathlib import Path
import time
import uvicorn
from contextlib import asynccontextmanager

# 请求和响应模型
class QueryRequest(BaseModel):
    question: str
    use_hybrid: bool = True
    return_sources: bool = True
    top_k: Optional[int] = None

class QueryResponse(BaseModel):
    answer: str
    question: str
    retrieved_docs_count: int
    used_docs_count: int
    sources: Optional[List[Dict]] = None
    processing_time: float
    error: Optional[str] = None

class BuildKnowledgeBaseRequest(BaseModel):
    data_path: str

class BuildKnowledgeBaseResponse(BaseModel):
    message: str
    document_count: int
    processing_time: float

# 全局变量
rag_system = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    # 启动时初始化
    global rag_system
    
    # 初始化RAG系统
    config = RAGConfig(
        embedding_model="shibing624/text2vec-base-chinese",
        chunk_size=1000,
        top_k=5,
        use_hybrid_search=True
    )
    
    rag_system = BasicRAGSystem(config)
    
    # 这里应该初始化各个组件
    # embedding_service = EmbeddingService(...)
    # vector_store = ChromaVectorStore(...)
    # llm_service = DeepSeekLLMService(...)
    # document_chunker = DocumentChunker(...)
    # rag_system.initialize(embedding_service, vector_store, llm_service, document_chunker)
    
    yield
    
    # 关闭时清理
    if rag_system:
        rag_system.cleanup()

app = FastAPI(
    title="RAG问答系统",
    description="基于FastAPI和Deepseek的检索增强生成系统",
    version="1.0.0",
    lifespan=lifespan
)

@app.get("/")
async def root():
    return {
        "message": "RAG问答系统API",
        "version": "1.0.0",
        "status": "running",
        "rag_system": rag_system is not None
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "rag_system": rag_system is not None
    }

@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """查询RAG系统"""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    
    try:
        start_time = time.time()
        
        # 临时调整top_k(注意:直接修改共享配置,在并发请求下并不安全)
        original_top_k = rag_system.config.top_k
        if request.top_k:
            rag_system.config.top_k = request.top_k
        
        try:
            # 执行查询
            result = rag_system.query(
                request.question,
                use_hybrid=request.use_hybrid,
                return_sources=request.return_sources
            )
        finally:
            # 恢复配置(即使查询抛出异常也会恢复)
            rag_system.config.top_k = original_top_k
        
        processing_time = time.time() - start_time
        
        return QueryResponse(
            answer=result.get('answer', ''),
            question=result.get('question', request.question),
            retrieved_docs_count=result.get('retrieved_docs_count', 0),
            used_docs_count=result.get('used_docs_count', 0),
            sources=result.get('sources'),
            processing_time=processing_time,
            error=result.get('error')
        )
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"查询处理失败: {str(e)}")

@app.post("/build-knowledge-base", response_model=BuildKnowledgeBaseResponse)
async def build_knowledge_base(request: BuildKnowledgeBaseRequest,
                               background_tasks: BackgroundTasks):
    """构建知识库"""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    
    data_path = Path(request.data_path)
    if not data_path.exists():
        raise HTTPException(status_code=400, detail=f"数据路径不存在: {request.data_path}")
    
    try:
        start_time = time.time()
        
        # 在后台任务中执行构建
        background_tasks.add_task(rag_system.build_knowledge_base, request.data_path)
        
        processing_time = time.time() - start_time
        
        return BuildKnowledgeBaseResponse(
            message="知识库构建任务已启动,请稍后查看结果",
            document_count=0,  # 后台任务完成后才能知道具体数量
            processing_time=processing_time
        )
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"构建任务启动失败: {str(e)}")

@app.get("/stats")
async def get_system_stats():
    """获取系统统计信息"""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    
    try:
        stats = rag_system.get_system_stats()
        return {
            "status": "success",
            "stats": stats,
            "timestamp": time.time()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")

@app.get("/llm-stats")
async def get_llm_stats():
    """获取LLM使用统计"""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    
    try:
        llm_stats = rag_system.llm_service.get_stats()
        return {
            "status": "success",
            "llm_stats": llm_stats,
            "timestamp": time.time()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取LLM统计失败: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

8.2 项目配置文件

# config.yaml
rag:
  embedding:
    model_name: "shibing624/text2vec-base-chinese"
    cache_enabled: true
    batch_size: 32
  
  vector_store:
    type: "chroma"
    persist_directory: "./data/vector_db"
    collection_name: "documents"
  
  llm:
    provider: "deepseek"
    model_name: "deepseek-chat"
    api_key: "${DEEPSEEK_API_KEY}"
    base_url: "https://api.deepseek.com"
    temperature: 0.7
    max_tokens: 1000
  
  chunking:
    chunk_size: 1000
    chunk_overlap: 200
    use_semantic_chunking: true
  
  retrieval:
    top_k: 5
    similarity_threshold: 0.7
    use_hybrid_search: true
    enable_reranking: false
  
  caching:
    enabled: true
    max_size: 1000
    ttl: 3600
  
  rate_limiting:
    max_requests_per_minute: 60

logging:
  level: "INFO"
  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  file: "./logs/rag_system.log"

server:
  host: "0.0.0.0"
  port: 8000
  workers: 1

security:
  enable_auth: false
  secret_key: "your-secret-key-here"
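配置中 `${DEEPSEEK_API_KEY}` 这类占位符不会被YAML解析器自动展开,yaml.safe_load 读出的只是原样字符串,加载后需要自行替换。下面是一个只用标准库的最小示意(假设配置值已读成字符串):

```python
import os

os.environ['DEEPSEEK_API_KEY'] = 'sk-demo-123'  # 示意:实际由部署环境注入

raw_value = "${DEEPSEEK_API_KEY}"               # 从YAML读出的原始字符串
api_key = os.path.expandvars(raw_value)         # 展开 ${VAR} 形式的环境变量
print(api_key)  # sk-demo-123
```

对整份配置做展开时,可以递归遍历字典并对所有字符串值调用 os.path.expandvars。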

8.3 部署脚本

#!/bin/bash
# deploy.sh

set -e  # 任一步骤失败立即退出

echo "开始部署RAG系统..."

# 创建必要目录
mkdir -p data/vector_db logs cache

# 安装依赖
pip install -r requirements.txt

# 下载中文模型
python -c "import spacy; spacy.cli.download('zh_core_web_sm')" || echo "Spacy模型下载失败,请手动下载"

# 设置环境变量(请替换为真实密钥;不要把密钥硬编码后提交到版本库)
export DEEPSEEK_API_KEY="your-deepseek-api-key"

echo "RAG系统部署完成!"
echo "访问 http://localhost:8000/docs 查看API文档"

# 启动服务(前台阻塞运行,--reload 仅适用于开发环境)
python -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload

8.4 依赖文件(requirements.txt)

fastapi==0.104.1
uvicorn==0.24.0
pydantic==2.5.0
sentence-transformers==2.2.2
chromadb==0.4.18
openai==1.3.7
python-multipart==0.0.6
jieba==0.42.1
nltk==3.8.1
spacy==3.7.2
PyPDF2==3.0.1
python-docx==0.8.11
PyYAML==6.0.1
requests==2.31.0
numpy==1.24.3
torch==2.1.1
transformers==4.36.0