【大模型】RAG详细讲解+开发实战笔记
RAG(Retrieval-Augmented Generation,检索增强生成)是一种结合信息检索和文本生成的AI技术架构:先从知识库中检索相关信息,再将检索结果作为上下文提供给大语言模型进行回答生成,从而提高回答的准确性和可靠性。本文详细讲解RAG的原理、核心组件与完整的开发实战示例。
大模型RAG(检索增强生成)知识详解
目录
一、RAG概述
1.1 什么是RAG?
RAG(Retrieval-Augmented Generation,检索增强生成)是一种结合信息检索和文本生成的AI技术架构。它通过先从知识库中检索相关信息,再将检索结果作为上下文提供给大语言模型进行回答生成,从而提高回答的准确性和可靠性。
1.2 RAG的价值
- 减少幻觉:通过真实数据支撑,降低模型产生错误信息的概率
- 增强时效性:可以快速更新知识库,保持信息最新
- 提高透明度:回答基于可验证的数据源
- 降低成本:相比微调,RAG实现成本更低
- 保护隐私:私有数据无需参与模型训练
二、RAG工作原理
2.1 基本工作流程
用户提问 → 向量化检索 → 相关文档获取 → 上下文构建 → LLM生成 → 答案输出
2.2 详细步骤
RAG系统的核心工作流程包含五个关键步骤,每个步骤都有其特定的功能和技术实现:
- **用户查询接收**
- 接收用户自然语言查询
- 对查询进行预处理和清理
- **查询向量化**
- 使用embedding模型将查询转换为向量
- 保持与文档向量的相同语义空间
- **相似度检索**
- 在向量数据库中进行相似度搜索
- 返回最相关的文档片段
- **上下文构建**
- 将检索到的文档与原始查询组合
- 构建完整的prompt模板
- **答案生成**
- 将构建的上下文输入LLM
- 生成基于检索结果的回答
以下是一个完整的RAG工作流程实现示例,展示了如何将这些步骤整合到一个可运行的系统中:
import numpy as np
from typing import List, Dict, Any
import time
class RAGWorkflow:
    """End-to-end RAG pipeline: preprocessing, vectorization, retrieval,
    context construction and answer generation.

    BUG FIX vs. the original listing: three string literals contained raw
    line breaks inside single/f-string quotes (a paste artifact), which is
    a SyntaxError in Python; they are rewritten with explicit ``\\n``.
    """

    def __init__(self, embedding_model, vector_store, llm_service):
        """
        Args:
            embedding_model: must provide ``embed_query(text) -> np.ndarray``.
            vector_store: must provide ``search(query_embeddings, n_results)``
                returning Chroma-style nested result lists.
            llm_service: must provide ``generate_response(prompt) -> str``.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_service = llm_service

    def process_query(self, user_query: str, top_k: int = 5) -> Dict[str, Any]:
        """Run the full pipeline for a single user query.

        Args:
            user_query: the user's natural-language question.
            top_k: number of most-relevant chunks to retrieve.

        Returns:
            Dict with the answer, intermediate artifacts and timing metadata.
        """
        start_time = time.time()
        # Step 1: query preprocessing
        processed_query = self._preprocess_query(user_query)
        # Step 2: query vectorization
        query_vector = self._vectorize_query(processed_query)
        # Step 3: similarity retrieval
        relevant_docs = self._retrieve_documents(query_vector, top_k)
        # Step 4: context construction
        context = self._build_context(relevant_docs, processed_query)
        # Step 5: answer generation (uses the original, un-normalized query)
        answer = self._generate_answer(user_query, context)
        processing_time = time.time() - start_time
        return {
            'answer': answer,
            'original_query': user_query,
            'processed_query': processed_query,
            'retrieved_documents': relevant_docs,
            'context': context,
            'processing_time': processing_time,
            'top_k': top_k
        }

    def _preprocess_query(self, query: str) -> str:
        """Step 1: normalize the query (collapse whitespace, lowercase).

        Lowercasing only affects Latin characters; CJK text is unchanged.
        """
        cleaned_query = ' '.join(query.split())
        cleaned_query = cleaned_query.lower()
        # Hook for further preprocessing (spell check, query expansion, ...).
        return cleaned_query

    def _vectorize_query(self, query: str) -> np.ndarray:
        """Step 2: convert the query text to a (1, dim) embedding.

        Assumes ``embed_query`` returns a numpy array — TODO confirm for
        the concrete embedding service injected at runtime.
        """
        query_vector = self.embedding_model.embed_query(query)
        # Normalize to a 2-D row vector as expected by the vector store.
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        return query_vector

    def _retrieve_documents(self, query_vector: np.ndarray, top_k: int) -> List[Dict]:
        """Step 3: similarity search in the vector store.

        NOTE(review): the double ``[0]`` indexing expects Chroma-style
        *nested* result lists (one inner list per query) — confirm this
        matches the injected vector store's ``search`` contract.
        """
        search_results = self.vector_store.search(
            query_embeddings=query_vector.tolist(),
            n_results=top_k
        )
        retrieved_docs = []
        for i in range(len(search_results['documents'][0])):
            doc = {
                'content': search_results['documents'][0][i],
                'metadata': search_results['metadatas'][0][i] if search_results['metadatas'] else {},
                # Cosine distance -> similarity.
                'similarity_score': 1 - search_results['distances'][0][i],
                'doc_id': search_results['ids'][0][i] if search_results['ids'] else f"doc_{i}"
            }
            retrieved_docs.append(doc)
        return retrieved_docs

    def _build_context(self, documents: List[Dict], query: str, max_length: int = 2000) -> str:
        """Step 4: assemble retrieved documents into an LLM-ready context.

        Documents are sorted by similarity and concatenated up to
        ``max_length`` characters; the last document may be truncated.
        """
        if not documents:
            return "未找到相关文档来回答此问题。"
        context_parts = []
        current_length = 0
        sorted_docs = sorted(documents, key=lambda x: x['similarity_score'], reverse=True)
        for i, doc in enumerate(sorted_docs):
            doc_content = f"文档 {i+1} (相似度: {doc['similarity_score']:.3f}):\n{doc['content']}"
            doc_length = len(doc_content)
            if current_length + doc_length > max_length and context_parts:
                # Over budget: try to fit a truncated version, then stop.
                remaining_space = max_length - current_length - 50
                if remaining_space > 100:
                    truncated_content = doc['content'][:remaining_space] + "..."
                    context_parts.append(f"文档 {i+1} (已截断):\n{truncated_content}")
                break
            context_parts.append(doc_content)
            current_length += doc_length
        return '\n'.join(context_parts)

    def _generate_answer(self, original_query: str, context: str) -> str:
        """Step 5: ask the LLM to answer based on the retrieved context."""
        prompt = f"""基于以下提供的文档内容,请准确回答用户的问题。
上下文信息:
{context}
用户问题:{original_query}
请基于上述上下文信息回答问题。如果上下文中没有足够的信息来回答问题,请明确说明。回答要准确、简洁、有帮助。
回答:"""
        answer = self.llm_service.generate_response(prompt)
        return answer

    def batch_process_queries(self, queries: List[str], top_k: int = 5) -> List[Dict]:
        """Process queries one by one; a failure is recorded, not raised."""
        results = []
        for query in queries:
            try:
                result = self.process_query(query, top_k)
                results.append(result)
            except Exception as e:
                # Best-effort: record the error and continue with the rest.
                error_result = {
                    'answer': f"处理查询时出错: {str(e)}",
                    'original_query': query,
                    'error': str(e),
                    'processing_time': 0
                }
                results.append(error_result)
        return results

    def get_workflow_stats(self) -> Dict[str, Any]:
        """Report which component classes are wired in and whether all are set."""
        return {
            'embedding_model': type(self.embedding_model).__name__,
            'vector_store': type(self.vector_store).__name__,
            'llm_service': type(self.llm_service).__name__,
            'components_initialized': all([
                self.embedding_model is not None,
                self.vector_store is not None,
                self.llm_service is not None
            ])
        }
class RAGWorkflowManager:
    """Wraps a RAGWorkflow with query-history logging and running metrics."""

    def __init__(self, workflow: RAGWorkflow):
        self.workflow = workflow
        self.query_history = []
        self.performance_metrics = {
            'total_queries': 0,
            'successful_queries': 0,
            'failed_queries': 0,
            'average_processing_time': 0,
            'total_processing_time': 0
        }

    def process_and_log_query(self, query: str, top_k: int = 5) -> Dict[str, Any]:
        """Delegate to the workflow, recording history and metrics.

        On failure the error is logged into the history, metrics are
        updated, and the exception is re-raised to the caller.
        """
        try:
            outcome = self.workflow.process_query(query, top_k)
            record = {
                'timestamp': time.time(),
                'query': query,
                'success': True,
                'processing_time': outcome['processing_time'],
                'retrieved_docs_count': len(outcome.get('retrieved_documents', []))
            }
            self.query_history.append(record)
            self._update_performance_metrics(outcome['processing_time'], success=True)
            return outcome
        except Exception as exc:
            self.query_history.append({
                'timestamp': time.time(),
                'query': query,
                'success': False,
                'error': str(exc)
            })
            self._update_performance_metrics(0, success=False)
            raise exc

    def _update_performance_metrics(self, processing_time: float, success: bool):
        """Fold a single query outcome into the running aggregates."""
        metrics = self.performance_metrics
        metrics['total_queries'] += 1
        metrics['total_processing_time'] += processing_time
        bucket = 'successful_queries' if success else 'failed_queries'
        metrics[bucket] += 1
        metrics['average_processing_time'] = (
            metrics['total_processing_time'] / metrics['total_queries']
        )

    def get_performance_report(self) -> Dict[str, Any]:
        """Snapshot of aggregate metrics plus the ten most recent queries."""
        total = max(self.performance_metrics['total_queries'], 1)
        return {
            'performance_metrics': self.performance_metrics.copy(),
            'success_rate': self.performance_metrics['successful_queries'] / total,
            'recent_queries': self.query_history[-10:],
            'query_history_size': len(self.query_history)
        }
三、RAG核心组件
3.1 数据源(Data Source)
类型:
- 结构化数据:数据库、表格、API
- 半结构化数据:JSON、XML、Markdown
- 非结构化数据:PDF、Word、网页、图片
数据预处理实现:
import re
import nltk
import spacy
from typing import List, Dict, Optional
from pathlib import Path
import PyPDF2
from docx import Document
import json
import pandas as pd
class DocumentProcessor:
    """Loads heterogeneous files into a uniform document representation.

    Each document is a dict ``{'content': str, 'metadata': dict}``.

    BUG FIX vs. the original listing: ``load_documents`` registered
    loaders for .txt/.json/.csv/.html that were never defined, so any
    such file raised AttributeError; stdlib implementations are added.
    """

    def __init__(self):
        # Load the Chinese spaCy pipeline; keyword extraction degrades to a
        # frequency count if the model is unavailable.
        try:
            self.nlp = spacy.load("zh_core_web_sm")
        except Exception:  # was a bare `except:` — narrowed
            print("请安装spacy中文模型: python -m spacy download zh_core_web_sm")
            self.nlp = None
        # Tokenizer/stopword data used by NLTK-based processing.
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)

    def load_documents(self, data_path: str) -> List[Dict]:
        """Recursively load every supported file under ``data_path``.

        Raises:
            FileNotFoundError: if the path does not exist.
        """
        documents = []
        data_dir = Path(data_path)
        if not data_dir.exists():
            raise FileNotFoundError(f"数据路径不存在: {data_path}")
        # Dispatch table: extension -> loader method.
        file_patterns = {
            '.pdf': self.load_pdf,
            '.docx': self.load_docx,
            '.md': self.load_markdown,
            '.txt': self.load_text,
            '.json': self.load_json,
            '.csv': self.load_csv,
            '.html': self.load_html
        }
        for file_path in data_dir.rglob('*'):
            if file_path.is_file() and file_path.suffix.lower() in file_patterns:
                try:
                    loader_func = file_patterns[file_path.suffix.lower()]
                    docs = loader_func(file_path)
                    documents.extend(docs)
                    print(f"成功加载文件: {file_path}")
                except Exception as e:
                    # Best-effort: a broken file must not abort the whole run.
                    print(f"加载文件失败 {file_path}: {e}")
        return documents

    def load_pdf(self, file_path: Path) -> List[Dict]:
        """Load a PDF; each non-empty page becomes one document."""
        documents = []
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page_num, page in enumerate(pdf_reader.pages):
                text = page.extract_text()
                if text.strip():
                    documents.append({
                        'content': text.strip(),
                        'metadata': {
                            'file_name': file_path.name,
                            'file_type': 'pdf',
                            'page_number': page_num + 1,
                            'total_pages': len(pdf_reader.pages),
                            'char_count': len(text.strip())
                        }
                    })
        return documents

    def load_docx(self, file_path: Path) -> List[Dict]:
        """Load a Word file; each non-empty paragraph becomes one document."""
        documents = []
        doc = Document(file_path)
        doc_metadata = {
            'file_name': file_path.name,
            'file_type': 'docx',
            'paragraphs_count': len(doc.paragraphs)
        }
        for i, paragraph in enumerate(doc.paragraphs):
            if paragraph.text.strip():
                documents.append({
                    'content': paragraph.text.strip(),
                    'metadata': {
                        **doc_metadata,
                        'paragraph_index': i,
                        'char_count': len(paragraph.text)
                    }
                })
        return documents

    def load_markdown(self, file_path: Path) -> List[Dict]:
        """Load a Markdown file; each heading-delimited section is a document."""
        documents = []
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        sections = self.split_markdown_by_headers(content)
        for section in sections:
            documents.append({
                'content': section['content'],
                'metadata': {
                    'file_name': file_path.name,
                    'file_type': 'markdown',
                    'section_title': section.get('title', '无标题'),
                    'section_level': section.get('level', 0),
                    'char_count': len(section['content'])
                }
            })
        return documents

    def load_text(self, file_path: Path) -> List[Dict]:
        """Load a plain-text file as a single document (added: was missing)."""
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        text = text.strip()
        if not text:
            return []
        return [{
            'content': text,
            'metadata': {
                'file_name': file_path.name,
                'file_type': 'txt',
                'char_count': len(text)
            }
        }]

    def load_json(self, file_path: Path) -> List[Dict]:
        """Load a JSON file (added: was missing).

        A top-level list yields one document per item; any other JSON
        value becomes a single document. Non-string items are serialized.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        items = data if isinstance(data, list) else [data]
        documents = []
        for i, item in enumerate(items):
            content = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
            content = content.strip()
            if not content:
                continue
            documents.append({
                'content': content,
                'metadata': {
                    'file_name': file_path.name,
                    'file_type': 'json',
                    'item_index': i,
                    'char_count': len(content)
                }
            })
        return documents

    def load_csv(self, file_path: Path) -> List[Dict]:
        """Load a CSV file (added: was missing); each row becomes one
        "column: value; ..." document keyed by the header row."""
        import csv  # local import: csv is not imported at module level
        documents = []
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f)
            for i, row in enumerate(reader):
                content = '; '.join(f"{k}: {v}" for k, v in row.items() if v)
                if not content:
                    continue
                documents.append({
                    'content': content,
                    'metadata': {
                        'file_name': file_path.name,
                        'file_type': 'csv',
                        'row_index': i,
                        'char_count': len(content)
                    }
                })
        return documents

    def load_html(self, file_path: Path) -> List[Dict]:
        """Load an HTML file (added: was missing), stripping script/style
        blocks and tags with regexes — a lightweight, dependency-free
        extraction, not a full HTML parser."""
        with open(file_path, 'r', encoding='utf-8') as f:
            html = f.read()
        html = re.sub(r'(?is)<(script|style)[^>]*>.*?</\1>', ' ', html)
        text = re.sub(r'<[^>]+>', ' ', html)
        text = self.clean_text(text)
        if not text:
            return []
        return [{
            'content': text,
            'metadata': {
                'file_name': file_path.name,
                'file_type': 'html',
                'char_count': len(text)
            }
        }]

    def split_markdown_by_headers(self, content: str) -> List[Dict]:
        """Split Markdown into sections at ``#``-style heading lines.

        Returns a list of {'title', 'level', 'content'} dicts; content
        before the first heading is dropped only if empty.
        """
        sections = []
        lines = content.split('\n')
        current_section = []
        current_title = "无标题"
        current_level = 0
        for line in lines:
            if line.startswith('#'):
                # Flush the section accumulated so far.
                if current_section:
                    sections.append({
                        'title': current_title,
                        'level': current_level,
                        'content': '\n'.join(current_section)
                    })
                # Heading level = number of leading '#'.
                current_title = line.strip()
                current_level = len(line) - len(line.lstrip('#'))
                current_section = []
            else:
                current_section.append(line)
        # Flush the trailing section.
        if current_section:
            sections.append({
                'title': current_title,
                'level': current_level,
                'content': '\n'.join(current_section)
            })
        return sections

    def clean_text(self, text: str) -> str:
        """Collapse whitespace and drop characters outside the whitelist
        (CJK, word chars, whitespace, basic punctuation)."""
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\u4e00-\u9fa5\w\s.,!?;:()[\]{}"\'-]', '', text)
        text = text.strip()
        return text

    def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
        """Extract up to ``max_keywords`` keywords.

        Uses spaCy noun/proper-noun lemmas when the model is loaded,
        otherwise falls back to a simple word-frequency count.
        """
        if self.nlp is None:
            # Fallback: frequency count over words longer than 2 chars.
            words = re.findall(r'\b\w+\b', text.lower())
            word_freq = {}
            for word in words:
                if len(word) > 2:
                    word_freq[word] = word_freq.get(word, 0) + 1
            return sorted(word_freq.keys(), key=lambda x: word_freq[x],
                          reverse=True)[:max_keywords]
        doc = self.nlp(text)
        keywords = []
        for token in doc:
            if (token.pos_ in ['NOUN', 'PROPN'] and
                    not token.is_stop and
                    len(token.text) > 2):
                keywords.append(token.lemma_.lower())
        # NOTE: set() de-dup does not preserve frequency or order.
        return list(set(keywords))[:max_keywords]
3.2 向量化模型(Embedding Model)
向量化服务实现:
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import pickle
import hashlib
import time
from dataclasses import dataclass
@dataclass
class EmbeddingConfig:
    """Settings for EmbeddingService."""
    model_name: str = "shibing624/text2vec-base-chinese"  # sentence-transformers model id
    cache_enabled: bool = True                # persist vectors across runs
    cache_file: str = "embedding_cache.pkl"   # pickle file holding the cache dict
    batch_size: int = 32                      # encode() batch size
    max_length: int = 512                     # declared limit; not enforced in this section

class EmbeddingService:
    """Text-embedding wrapper with an on-disk (pickle) vector cache.

    BUG FIX vs. the original listing: ``embed_texts`` appended cache hits
    in encounter order and then wrote newly computed vectors by their
    *original* index, which overwrote cached entries and produced a
    wrongly sized / wrongly ordered array whenever hits and misses were
    interleaved. Vectors are now placed positionally from the start.
    """

    def __init__(self, config: EmbeddingConfig):
        self.config = config
        self.model = SentenceTransformer(config.model_name)
        # text-hash -> vector; optionally persisted to config.cache_file.
        self.cache = {}
        if config.cache_enabled:
            self.load_cache()
        print(f"已加载embedding模型: {config.model_name}")

    def load_cache(self):
        """Load the pickled vector cache; start empty if none exists."""
        try:
            with open(self.config.cache_file, 'rb') as f:
                self.cache = pickle.load(f)
            print(f"已加载缓存,包含 {len(self.cache)} 个向量")
        except FileNotFoundError:
            self.cache = {}
            print("缓存文件不存在,将创建新缓存")

    def save_cache(self):
        """Persist the vector cache; failures are logged, not raised."""
        try:
            with open(self.config.cache_file, 'wb') as f:
                pickle.dump(self.cache, f)
            print(f"已保存 {len(self.cache)} 个向量到缓存")
        except Exception as e:
            print(f"保存缓存失败: {e}")

    def embed_texts(self, texts: List[str], use_cache: bool = True) -> np.ndarray:
        """Vectorize a batch of texts, reusing cached vectors where possible.

        Args:
            texts: input strings; output rows keep the same order.
            use_cache: consult and update the cache.

        Returns:
            Array of shape (len(texts), dim).
        """
        # Positional slots guarantee output order regardless of cache hits.
        embeddings: List[Any] = [None] * len(texts)
        uncached_texts = []
        uncached_indices = []
        if use_cache:
            for i, text in enumerate(texts):
                text_hash = self.get_text_hash(text)
                if text_hash in self.cache:
                    embeddings[i] = self.cache[text_hash]
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(i)
        else:
            uncached_texts = list(texts)
            uncached_indices = list(range(len(texts)))
        if uncached_texts:
            print(f"计算 {len(uncached_texts)} 个新向量...")
            start_time = time.time()
            new_embeddings = self.model.encode(
                uncached_texts,
                batch_size=self.config.batch_size,
                show_progress_bar=True,
                normalize_embeddings=True
            )
            # Place each new vector into its original slot (and the cache).
            for original_index, embedding in zip(uncached_indices, new_embeddings):
                if use_cache:
                    self.cache[self.get_text_hash(texts[original_index])] = embedding
                embeddings[original_index] = embedding
            if use_cache:
                self.save_cache()
            elapsed_time = time.time() - start_time
            print(f"向量化完成,耗时: {elapsed_time:.2f}秒")
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Vectorize a single query, going through the cache if enabled."""
        query_hash = self.get_text_hash(query)
        if self.config.cache_enabled and query_hash in self.cache:
            return self.cache[query_hash]
        embedding = self.model.encode([query], normalize_embeddings=True)[0]
        if self.config.cache_enabled:
            self.cache[query_hash] = embedding
        return embedding

    def get_text_hash(self, text: str) -> str:
        """MD5 of the UTF-8 text — cache key only, not security-relevant."""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def compute_similarity(self, query_embedding: np.ndarray,
                           doc_embeddings: np.ndarray) -> np.ndarray:
        """Dot-product similarity; equals cosine similarity because the
        vectors are produced with normalize_embeddings=True."""
        similarities = np.dot(doc_embeddings, query_embedding)
        return similarities

    def find_most_similar(self, query: str, documents: List[str],
                          top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` documents most similar to ``query``."""
        query_embedding = self.embed_query(query)
        doc_embeddings = self.embed_texts(documents)
        similarities = self.compute_similarity(query_embedding, doc_embeddings)
        # Indices of the k largest similarities, best first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            results.append({
                'index': int(idx),
                'content': documents[idx],
                'similarity_score': float(similarities[idx])
            })
        return results

    def get_embedding_dimension(self) -> int:
        """Dimensionality of the model's sentence embeddings."""
        return self.model.get_sentence_embedding_dimension()

    def benchmark_performance(self, test_texts: List[str]) -> Dict:
        """Measure single-text latency and batch throughput (cache bypassed
        for batch runs so timings reflect real encoding)."""
        print("开始性能基准测试...")
        start_time = time.time()
        single_embedding = self.embed_query(test_texts[0])
        single_time = time.time() - start_time
        batch_sizes = [1, 8, 16, 32]
        batch_results = {}
        for batch_size in batch_sizes:
            batch_texts = test_texts[:batch_size]
            start_time = time.time()
            batch_embeddings = self.embed_texts(batch_texts, use_cache=False)
            batch_time = time.time() - start_time
            batch_results[batch_size] = {
                'time': batch_time,
                'throughput': batch_size / batch_time
            }
        return {
            'embedding_dimension': self.get_embedding_dimension(),
            'single_text_time': single_time,
            'batch_results': batch_results,
            'cache_size': len(self.cache)
        }
3.3 文档分块策略
智能分块实现:
import re
import jieba
from typing import List, Dict, Tuple, Optional
import nltk
from dataclasses import dataclass
@dataclass
class ChunkingConfig:
    """Settings for DocumentChunker."""
    chunk_size: int = 1000          # target max characters per chunk
    chunk_overlap: int = 200        # characters shared between adjacent chunks
    min_chunk_size: int = 100       # below this counts as "small" in analyze_chunks
    max_chunk_size: int = 2000      # declared upper bound; not referenced in this section
    use_semantic_chunking: bool = True  # try sentence-based chunking first
    separators: List[str] = None    # split separators; DocumentChunker fills defaults when None
class DocumentChunker:
    """Splits documents into retrieval-sized chunks.

    Strategy cascade: semantic (sentence-based) chunking first, then
    recursive separator-based chunking, finally fixed-size chunking.
    NOTE(review): `create_chunk` uses `time.time()` and `analyze_chunks`
    uses `np`; both rely on `time`/`numpy` being imported elsewhere in the
    module — they are not imported in this code section.
    """

    def __init__(self, config: ChunkingConfig):
        self.config = config
        # Default separators: paragraph break, newline, then CJK/ASCII sentence enders.
        if config.separators is None:
            self.config.separators = ['\n\n', '\n', '。', '!', '?', ';', ';']
        # Ensure the NLTK sentence tokenizer is available for English text.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)

    def chunk_documents(self, documents: List[Dict]) -> List[Dict]:
        """Chunk every document in the list; returns the flattened chunks."""
        all_chunks = []
        for doc_idx, document in enumerate(documents):
            content = document['content']
            metadata = document.get('metadata', {})
            chunks = self.chunk_single_document(content, metadata, doc_idx)
            all_chunks.extend(chunks)
        print(f"共生成 {len(all_chunks)} 个文档块")
        return all_chunks

    def chunk_single_document(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Apply the strategy cascade; a strategy "wins" if it yields >1 chunk."""
        if self.config.use_semantic_chunking:
            chunks = self.semantic_chunking(content, metadata, doc_idx)
            if len(chunks) > 1:
                return chunks
        chunks = self.recursive_character_chunking(content, metadata, doc_idx)
        if len(chunks) > 1:
            return chunks
        # Last resort: fixed-size windows.
        return self.fixed_size_chunking(content, metadata, doc_idx)

    def semantic_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Accumulate sentences into chunks up to chunk_size, carrying a
        sentence-level overlap between consecutive chunks."""
        sentences = self.split_sentences(content)
        chunks = []
        current_chunk = []
        current_length = 0
        for i, sentence in enumerate(sentences):
            sentence_length = len(sentence)
            # Flush when adding this sentence would exceed chunk_size
            # (only if the chunk already holds at least two sentences).
            if (current_length + sentence_length > self.config.chunk_size and
                    current_chunk and len(current_chunk) > 1):
                chunk_text = ' '.join(current_chunk)
                chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, len(chunks)))
                # Seed the next chunk with trailing sentences as overlap.
                overlap_sentences = self.get_overlap_sentences(current_chunk, i, sentences)
                current_chunk = overlap_sentences
                current_length = sum(len(s) for s in current_chunk)
            current_chunk.append(sentence)
            current_length += sentence_length
        # Flush the trailing chunk.
        if current_chunk:
            chunk_text = ' '.join(current_chunk)
            chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, len(chunks)))
        return chunks

    def split_sentences(self, text: str) -> List[str]:
        """Sentence-split: NLTK for text containing Latin letters,
        punctuation regex for pure CJK text; drops fragments <= 5 chars."""
        if re.search(r'[a-zA-Z]', text):
            sentences = nltk.sent_tokenize(text)
        else:
            sentences = re.split(r'[。!?;]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.strip()) > 5]
        return sentences

    def get_overlap_sentences(self, current_chunk: List[str], current_index: int,
                              all_sentences: List[str]) -> List[str]:
        """Trailing sentences of the finished chunk fitting chunk_overlap.

        NOTE(review): `current_index` and `all_sentences` are unused.
        """
        overlap_sentences = []
        overlap_length = 0
        # Walk backwards, keeping sentences while they fit the budget.
        for sentence in reversed(current_chunk):
            if overlap_length + len(sentence) <= self.config.chunk_overlap:
                overlap_sentences.insert(0, sentence)
                overlap_length += len(sentence)
            else:
                break
        return overlap_sentences

    def recursive_character_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Split on coarse separators first, recursing to finer ones."""
        return self._recursive_split(content, metadata, doc_idx, self.config.separators, 0)

    def _recursive_split(self, content: str, metadata: Dict, doc_idx: int,
                         separators: List[str], separator_index: int) -> List[Dict]:
        """Split by separators[separator_index]; pieces still over
        chunk_size recurse to the next separator; past the last separator,
        fall back to fixed-size chunking."""
        if separator_index >= len(separators):
            return self.fixed_size_chunking(content, metadata, doc_idx)
        separator = separators[separator_index]
        parts = content.split(separator)
        chunks = []
        current_chunk = ""
        for part in parts:
            # Tentatively re-join with the separator to test the length.
            test_chunk = current_chunk + (separator if current_chunk else "")
            test_chunk += part
            if len(test_chunk) <= self.config.chunk_size:
                current_chunk = test_chunk
            else:
                if current_chunk:
                    # NOTE(review): the accumulated chunk already fits
                    # chunk_size yet is recursed to the next separator level
                    # rather than emitted directly — confirm intent.
                    sub_chunks = self._recursive_split(
                        current_chunk, metadata, doc_idx, separators, separator_index + 1
                    )
                    chunks.extend(sub_chunks)
                    current_chunk = part
                else:
                    # A single part exceeds chunk_size: recurse on it alone.
                    sub_chunks = self._recursive_split(
                        part, metadata, doc_idx, separators, separator_index + 1
                    )
                    chunks.extend(sub_chunks)
        # Flush the trailing accumulation.
        if current_chunk:
            sub_chunks = self._recursive_split(
                current_chunk, metadata, doc_idx, separators, separator_index + 1
            )
            chunks.extend(sub_chunks)
        return chunks

    def fixed_size_chunking(self, content: str, metadata: Dict, doc_idx: int) -> List[Dict]:
        """Slice content into chunk_size windows, preferring to end each
        window at a nearby sentence boundary."""
        chunks = []
        content_length = len(content)
        start = 0
        chunk_idx = 0
        while start < content_length:
            end = min(start + self.config.chunk_size, content_length)
            if end < content_length:
                # Look ahead up to 100 chars for a sentence boundary.
                sentence_end = end
                for i in range(end, min(end + 100, content_length)):
                    # NOTE(review): '\n\n' is two characters and can never
                    # equal the single character content[i] — likely '\n'
                    # was intended.
                    if content[i] in ['。', '!', '?', '\n\n']:
                        sentence_end = i + 1
                        break
                if sentence_end > end:
                    end = sentence_end
            chunk_text = content[start:end].strip()
            if chunk_text:
                chunks.append(self.create_chunk(chunk_text, metadata, doc_idx, chunk_idx))
                chunk_idx += 1
            # NOTE(review): max(..., end) makes start jump to `end` whenever
            # end >= start + chunk_size - chunk_overlap, so the configured
            # overlap is effectively never applied here — confirm intent.
            start = max(start + self.config.chunk_size - self.config.chunk_overlap, end)
        return chunks

    def create_chunk(self, content: str, metadata: Dict, doc_idx: int, chunk_idx: int) -> Dict:
        """Wrap chunk text with per-chunk metadata (copies parent metadata)."""
        chunk_metadata = metadata.copy()
        chunk_metadata.update({
            'chunk_index': chunk_idx,
            'char_count': len(content),
            'word_count': len(content.split()),
            'created_at': time.time(),
            'parent_doc_index': doc_idx
        })
        return {
            'content': content,
            'metadata': chunk_metadata
        }

    def analyze_chunks(self, chunks: List[Dict]) -> Dict:
        """Summary statistics of chunk sizes (empty dict for no chunks)."""
        if not chunks:
            return {}
        sizes = [len(chunk['content']) for chunk in chunks]
        return {
            'total_chunks': len(chunks),
            'avg_chunk_size': np.mean(sizes),
            'min_chunk_size': min(sizes),
            'max_chunk_size': max(sizes),
            'total_characters': sum(sizes),
            'size_distribution': {
                'small': sum(1 for s in sizes if s < self.config.min_chunk_size),
                'medium': sum(1 for s in sizes if self.config.min_chunk_size <= s <= self.config.chunk_size),
                'large': sum(1 for s in sizes if s > self.config.chunk_size)
            }
        }
3.4 向量数据库
Chroma向量数据库实现:
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
import uuid
import time
class ChromaVectorStore:
    """Thin wrapper around a persistent Chroma collection.

    Provides add/search/update/delete plus threshold-filtered and hybrid
    (vector + keyword) retrieval helpers. Similarity is computed as
    ``1 - distance``, which presumes the cosine space configured below.
    """

    def __init__(self, persist_directory: str = "./vector_db",
                 collection_name: str = "documents"):
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        # Persistent (on-disk) Chroma client.
        self.client = chromadb.PersistentClient(path=persist_directory)
        # Cosine distance space — required for the 1 - distance conversion.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        print(f"已连接到向量数据库: {persist_directory}")
        print(f"使用集合: {collection_name}")

    def add_documents(self, texts: List[str], embeddings: List[List[float]],
                      metadatas: List[Dict] = None, ids: List[str] = None) -> List[str]:
        """Add documents (with precomputed embeddings) in batches of 100.

        Args:
            texts: chunk texts, parallel to ``embeddings``.
            embeddings: one vector per text.
            metadatas: optional per-document metadata (defaults to {}).
            ids: optional IDs; random 8-hex IDs are generated if omitted.

        Returns:
            The list of IDs actually added.

        Raises:
            ValueError: if ``texts`` and ``embeddings`` differ in length.
        """
        if len(texts) != len(embeddings):
            raise ValueError("文本和向量数量不匹配")
        if ids is None:
            ids = [f"doc_{uuid.uuid4().hex[:8]}" for _ in range(len(texts))]
        if metadatas is None:
            metadatas = [{} for _ in range(len(texts))]
        # Insert in batches to keep individual requests small.
        batch_size = 100
        added_ids = []
        for i in range(0, len(texts), batch_size):
            batch_end = min(i + batch_size, len(texts))
            self.collection.add(
                ids=ids[i:batch_end],
                embeddings=embeddings[i:batch_end],
                documents=texts[i:batch_end],
                metadatas=metadatas[i:batch_end]
            )
            added_ids.extend(ids[i:batch_end])
        print(f"已添加 {len(added_ids)} 个文档到向量数据库")
        return added_ids

    def search(self, query_embeddings: List[List[float]],
               n_results: int = 5, where: Optional[Dict] = None,
               where_document: Optional[Dict] = None) -> Dict[str, Any]:
        """Query the collection by embedding(s).

        Returns FLAT lists for the FIRST query only (Chroma's nested
        per-query lists are unwrapped here). Callers must not index a
        second ``[0]`` into these results.
        """
        results = self.collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document
        )
        # Unwrap the per-query nesting for the single-query case.
        return {
            'documents': results['documents'][0] if results['documents'] else [],
            'metadatas': results['metadatas'][0] if results['metadatas'] else [],
            'distances': results['distances'][0] if results['distances'] else [],
            'ids': results['ids'][0] if results['ids'] else []
        }

    def similarity_search(self, query_embedding: List[float],
                          threshold: float = 0.7, max_results: int = 10) -> List[Dict]:
        """Search, then keep only hits whose similarity >= ``threshold``."""
        results = self.search([query_embedding], n_results=max_results)
        filtered_results = []
        for i in range(len(results['documents'])):
            # Cosine distance -> similarity.
            similarity_score = 1 - results['distances'][i]
            if similarity_score >= threshold:
                filtered_results.append({
                    'id': results['ids'][i],
                    'content': results['documents'][i],
                    'metadata': results['metadatas'][i],
                    'similarity_score': similarity_score
                })
        return filtered_results

    def hybrid_search(self, query_embedding: List[float], keyword_query: str,
                      alpha: float = 0.7, top_k: int = 10) -> List[Dict]:
        """Blend vector and keyword retrieval: score = alpha * vector
        similarity + (1 - alpha) * keyword similarity.

        NOTE(review): ``query_texts`` requires the collection to have an
        embedding function configured — confirm for this setup.
        """
        # Vector branch (over-fetch to give the blend more candidates).
        vector_results = self.search([query_embedding], n_results=top_k * 2)
        # Keyword branch via Chroma's text query.
        keyword_results = self.collection.query(
            query_texts=[keyword_query],
            n_results=top_k * 2
        )
        # Blend per-document scores from both branches.
        merged_scores = {}
        # NOTE(review): self.search returns flattened lists, so this [0]
        # indexing iterates a single hit's fields rather than the hit list
        # — confirm against the search() contract above.
        for i, doc_id in enumerate(vector_results['ids'][0]):
            vector_score = 1 - vector_results['distances'][0][i]
            merged_scores[doc_id] = alpha * vector_score
        for i, doc_id in enumerate(keyword_results['ids'][0]):
            keyword_score = 1 - keyword_results['distances'][0][i]
            if doc_id in merged_scores:
                merged_scores[doc_id] += (1 - alpha) * keyword_score
            else:
                merged_scores[doc_id] = (1 - alpha) * keyword_score
        # Rank by blended score and fetch the document payloads.
        final_results = []
        for doc_id, score in sorted(merged_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]:
            doc_content = self.get_document_by_id(doc_id)
            if doc_content:
                final_results.append({
                    'id': doc_id,
                    'content': doc_content['content'],
                    'metadata': doc_content['metadata'],
                    'hybrid_score': score
                })
        return final_results

    def get_document_by_id(self, doc_id: str) -> Optional[Dict]:
        """Fetch one document by ID; returns None if missing or on error."""
        try:
            results = self.collection.get(ids=[doc_id])
            if results['documents']:
                return {
                    'id': doc_id,
                    'content': results['documents'][0],
                    'metadata': results['metadatas'][0] if results['metadatas'] else {}
                }
        except Exception as e:
            print(f"获取文档失败 {doc_id}: {e}")
        return None

    def update_document(self, doc_id: str, new_content: str,
                        new_embedding: List[float], new_metadata: Dict = None):
        """Replace a document's content, embedding and metadata in place."""
        if new_metadata is None:
            new_metadata = {}
        self.collection.update(
            ids=[doc_id],
            documents=[new_content],
            embeddings=[new_embedding],
            metadatas=[new_metadata]
        )
        print(f"已更新文档: {doc_id}")

    def delete_documents(self, doc_ids: List[str]):
        """Delete the given documents by ID."""
        self.collection.delete(ids=doc_ids)
        print(f"已删除 {len(doc_ids)} 个文档")

    def get_collection_stats(self) -> Dict:
        """Document count plus collection identity; errors returned as dict."""
        try:
            count = self.collection.count()
            return {
                'document_count': count,
                'collection_name': self.collection_name,
                'persist_directory': self.persist_directory
            }
        except Exception as e:
            return {'error': str(e)}

    def clear_collection(self):
        """Delete every document in the collection (best-effort)."""
        try:
            # Fetch all IDs, then delete them in one call.
            all_docs = self.collection.get()
            if all_docs['ids']:
                self.delete_documents(all_docs['ids'])
            print(f"已清空集合 {self.collection_name}")
        except Exception as e:
            print(f"清空集合失败: {e}")
四、RAG实现方式
4.1 基础RAG系统
基础RAG系统是整个检索增强生成架构的核心实现。它整合了文档处理、向量化、检索和生成等关键组件,提供了一个统一的接口来处理用户查询并生成基于知识的回答。以下实现包含了完整的项目结构、配置管理、知识库构建和查询处理功能:
核心RAG类实现:
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import time
@dataclass
class RAGConfig:
    """Tunable settings for BasicRAGSystem."""
    embedding_model: str = "shibing624/text2vec-base-chinese"  # embedding model id
    llm_model: str = "deepseek-chat"      # LLM used for answer generation
    chunk_size: int = 1000                # max characters per chunk
    chunk_overlap: int = 200              # characters shared between adjacent chunks
    top_k: int = 5                        # number of chunks retrieved per query
    similarity_threshold: float = 0.7     # minimum similarity for a chunk to be used
    max_context_length: int = 4000        # character budget for the prompt context
    use_hybrid_search: bool = False       # combine keyword + vector retrieval
    enable_reranking: bool = False        # declared but not referenced in this section
class BasicRAGSystem:
def __init__(self, config: RAGConfig):
self.config = config
# 初始化日志
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
# 初始化组件(稍后实现)
self.embedding_service = None
self.vector_store = None
self.llm_service = None
self.document_chunker = None
self.logger.info("RAG系统初始化完成")
def initialize(self, embedding_service, vector_store, llm_service, document_chunker):
"""初始化RAG组件"""
self.embedding_service = embedding_service
self.vector_store = vector_store
self.llm_service = llm_service
self.document_chunker = document_chunker
self.logger.info("RAG组件初始化完成")
def build_knowledge_base(self, data_path: str) -> Dict[str, Any]:
"""构建知识库"""
self.logger.info(f"开始构建知识库,数据路径: {data_path}")
start_time = time.time()
try:
# 1. 加载文档
doc_processor = DocumentProcessor()
documents = doc_processor.load_documents(data_path)
self.logger.info(f"加载了 {len(documents)} 个文档")
# 2. 文档分块
chunks = self.document_chunker.chunk_documents(documents)
self.logger.info(f"生成了 {len(chunks)} 个文档块")
# 3. 向量化
chunk_texts = [chunk['content'] for chunk in chunks]
chunk_metadatas = [chunk['metadata'] for chunk in chunks]
embeddings = self.embedding_service.embed_texts(chunk_texts)
# 4. 存储到向量数据库
doc_ids = self.vector_store.add_documents(
texts=chunk_texts,
embeddings=embeddings.tolist(),
metadatas=chunk_metadatas
)
processing_time = time.time() - start_time
return {
'success': True,
'document_count': len(documents),
'chunk_count': len(chunks),
'processing_time': processing_time,
'embedding_dimension': self.embedding_service.get_embedding_dimension()
}
except Exception as e:
self.logger.error(f"构建知识库失败: {e}")
return {
'success': False,
'error': str(e),
'processing_time': time.time() - start_time
}
def query(self, user_question: str,
use_hybrid: bool = None,
return_sources: bool = True) -> Dict[str, Any]:
"""查询RAG系统"""
if use_hybrid is None:
use_hybrid = self.config.use_hybrid_search
start_time = time.time()
try:
# 1. 查询预处理
processed_query = self.preprocess_query(user_question)
# 2. 检索相关文档
if use_hybrid:
retrieved_docs = self.hybrid_retrieval(processed_query)
else:
retrieved_docs = self.vector_retrieval(processed_query)
# 3. 构建上下文
context = self.build_context(retrieved_docs)
# 4. 生成回答
answer = self.generate_answer(user_question, context)
# 5. 构建响应
processing_time = time.time() - start_time
response = {
'answer': answer,
'question': user_question,
'retrieved_docs_count': len(retrieved_docs),
'used_docs_count': len(retrieved_docs),
'processing_time': processing_time
}
if return_sources:
response['sources'] = self.format_sources(retrieved_docs)
return response
except Exception as e:
self.logger.error(f"查询处理失败: {e}")
return {
'answer': '抱歉,处理您的问题时遇到了错误。',
'question': user_question,
'error': str(e),
'processing_time': time.time() - start_time,
'sources': []
}
def preprocess_query(self, query: str) -> Dict[str, Any]:
"""查询预处理"""
# 清理查询
cleaned_query = query.strip()
# 可以在这里添加查询扩展、意图识别等
return {
'original': query,
'cleaned': cleaned_query,
'expanded': cleaned_query # 简单实现,可以扩展
}
def vector_retrieval(self, processed_query: Dict[str, Any]) -> List[Dict]:
"""向量检索"""
query_embedding = self.embedding_service.embed_query(processed_query['expanded'])
results = self.vector_store.search(
query_embeddings=[query_embedding],
n_results=self.config.top_k
)
# 转换结果格式
retrieved_docs = []
for i in range(len(results['documents'])):
similarity_score = 1 - results['distances'][i] # 转换为相似度
if similarity_score >= self.config.similarity_threshold:
retrieved_docs.append({
'id': results['ids'][i],
'content': results['documents'][i],
'metadata': results['metadatas'][i],
'similarity_score': similarity_score
})
return retrieved_docs
def hybrid_retrieval(self, processed_query: Dict[str, Any]) -> List[Dict]:
"""混合检索"""
query_embedding = self.embedding_service.embed_query(processed_query['expanded'])
results = self.vector_store.hybrid_search(
query_embedding=query_embedding,
keyword_query=processed_query['expanded'],
alpha=0.7,
top_k=self.config.top_k
)
# 过滤低相关性结果
filtered_results = [
doc for doc in results
if doc.get('hybrid_score', 0) >= self.config.similarity_threshold
]
return filtered_results
def build_context(self, retrieved_docs: List[Dict]) -> str:
"""构建上下文"""
if not retrieved_docs:
return "未找到相关信息。"
context_parts = []
current_length = 0
for i, doc in enumerate(retrieved_docs):
doc_content = doc['content']
# 检查长度限制
if current_length + len(doc_content) > self.config.max_context_length:
# 尝试截取部分内容
remaining_space = self.config.max_context_length - current_length - 100
if remaining_space > 100:
truncated_content = doc_content[:remaining_space] + "..."
context_parts.append(f"文档 {i+1}:{truncated_content}")
break
context_parts.append(f"文档 {i+1}:{doc_content}")
current_length += len(doc_content)
return '\n\n'.join(context_parts)
def generate_answer(self, question: str, context: str) -> str:
"""生成回答"""
prompt = f"""
基于以下上下文信息回答用户问题。请确保回答准确、相关且简洁。
上下文信息:
{context}
用户问题:{question}
请基于上下文信息回答问题。如果上下文中没有足够信息,请说明。
"""
answer = self.llm_service.generate_response(prompt)
return answer
def format_sources(self, retrieved_docs: List[Dict]) -> List[Dict]:
    """Build a compact citation record for each retrieved document.

    Each entry carries the document id, its similarity score (0 when
    absent) and the raw metadata; file_name/title are surfaced as
    top-level keys when the metadata provides them.
    """
    formatted = []
    for doc in retrieved_docs:
        meta = doc['metadata']
        entry = {
            'id': doc['id'],
            'similarity_score': doc.get('similarity_score', 0),
            'metadata': meta
        }
        # Promote convenient metadata fields to the top level.
        for key in ('file_name', 'title'):
            if key in meta:
                entry[key] = meta[key]
        formatted.append(entry)
    return formatted
4.2 高级RAG功能
高级RAG功能在基础系统之上添加了智能化的增强能力,包括查询扩展、重排序、多路检索等技术,显著提升了系统的检索精度和回答质量。这些功能通过优化查询理解、文档相关性判断和结果排序来克服基础系统的局限性:
查询扩展器:
from typing import List, Dict, Set
import jieba
import requests
class QueryExpander:
    """Expands a user query with synonyms and related terms to improve recall."""

    def __init__(self):
        self.synonym_dict = self.load_synonym_dict()
        self.related_words_dict = self.load_related_words_dict()

    def load_synonym_dict(self) -> Dict[str, List[str]]:
        """Return the built-in word -> synonyms table."""
        return {
            '工作': ['职业', '岗位', '职能', '任职'],
            '方法': ['方式', '途径', '手段', '策略'],
            '问题': ['难题', '挑战', '疑问', '困惑'],
            '发展': ['进步', '提升', '成长', '演变'],
            '重要': ['关键', '核心', '主要', '首要'],
            'Python': ['编程', '代码', '脚本', '开发'],
            '数据库': ['MySQL', 'PostgreSQL', 'MongoDB', 'Oracle'],
            '前端': ['HTML', 'CSS', 'JavaScript', 'UI', '用户体验']
        }

    def load_related_words_dict(self) -> Dict[str, List[str]]:
        """Return the built-in word -> related-terms table."""
        return {
            '机器学习': ['深度学习', '神经网络', 'AI', '人工智能'],
            '数据分析': ['数据挖掘', '统计', '可视化', '报表'],
            '云计算': ['AWS', 'Azure', '阿里云', '服务器'],
            '网络安全': ['防火墙', '加密', '认证', '漏洞']
        }

    def expand_query(self, original_query: str,
                     expansion_methods: List[str] = None) -> List[str]:
        """Produce a deduplicated list of query variants.

        The original query is always among the variants. NOTE: dedup via
        set() does not preserve order.
        """
        methods = ['synonyms', 'related_words'] if expansion_methods is None else expansion_methods
        variants = [original_query]
        tokens = jieba.lcut(original_query)
        for method in methods:
            if method == 'synonyms':
                variants += self.expand_with_synonyms(original_query, tokens)
            elif method == 'related_words':
                variants += self.expand_with_related_words(tokens)
        return list(set(variants))

    def expand_with_synonyms(self, original_query: str, words: List[str]) -> List[str]:
        """Rewrite the query once per synonym (at most two per token)."""
        variants = []
        for token in words:
            for synonym in self.synonym_dict.get(token, [])[:2]:
                candidate = original_query.replace(token, synonym)
                if candidate != original_query:
                    variants.append(candidate)
        return variants

    def expand_with_related_words(self, words: List[str]) -> List[str]:
        """Pair each token with every related term ("token related")."""
        variants = []
        for token in words:
            for related in self.related_words_dict.get(token, []):
                variants.append(f"{token} {related}")
        return variants
重排序器:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
class NeuralReranker:
    """Cross-encoder reranker that rescores (query, document) pairs.

    Uses a HuggingFace sequence-classification model trained as a
    cross-encoder; higher scores mean higher relevance.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: move to GPU when available, disable dropout.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def rerank(self, query: str, documents: List[str], top_k: int = 5) -> List[Dict]:
        """Score every document against the query and return the top_k best.

        Returns dicts with 'content', 'rerank_score' and 1-based 'rank'.
        Empty input yields an empty list without touching the model.
        """
        if not documents:
            return []
        query_doc_pairs = [(query, doc) for doc in documents]
        scores = self.compute_scores(query_doc_pairs)
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i, (doc, score) in enumerate(scored_docs[:top_k]):
            results.append({
                'content': doc,
                'rerank_score': float(score),
                'rank': i + 1
            })
        return results

    def compute_scores(self, query_doc_pairs: List[tuple]) -> np.ndarray:
        """Compute a relevance score per (query, doc) pair, batched for memory.

        BUGFIX: the pairs are now passed to the tokenizer as
        (text, text_pair) so it inserts the model's own separator token and
        token_type_ids; the previous manual "query [SEP] doc" single string
        bypassed the sentence-pair encoding cross-encoders are trained with.
        """
        batch_size = 16
        all_scores = []
        for i in range(0, len(query_doc_pairs), batch_size):
            batch_pairs = query_doc_pairs[i:i + batch_size]
            inputs = self.tokenizer(
                [query for query, _ in batch_pairs],
                [doc for _, doc in batch_pairs],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Single-logit head -> sigmoid maps to a (0, 1) relevance score.
            batch_scores = torch.sigmoid(outputs.logits).cpu().numpy()
            all_scores.extend(batch_scores.flatten())
        return np.array(all_scores)
五、RAG优化策略
5.1 检索优化
检索优化是提升RAG系统性能的关键环节,通过多策略融合、智能排序和上下文自适应等技术来改善检索的准确性和效率。这些优化策略解决了单一检索方法的局限性,提高了系统的整体表现:
多路检索融合:
class MultiModalRetrieval:
    """Combines dense-vector and keyword retrieval and fuses their rankings."""

    def __init__(self, vector_store, keyword_searcher, embedding_service=None):
        """Args:
            vector_store: dense retriever exposing search(embedding, n_results).
            keyword_searcher: sparse retriever exposing search(query, n_results).
            embedding_service: optional object exposing embed_query(text).
                BUGFIX: fused_search called self.embed_query, which was never
                defined; the embedding dependency is now injected explicitly
                (defaulted to None for backward compatibility).
        """
        self.vector_store = vector_store
        self.keyword_searcher = keyword_searcher
        self.embedding_service = embedding_service
        self.fusion_methods = ['rrf', 'weighted_sum', 'reciprocal_rank']

    def embed_query(self, query):
        """Delegate query embedding to the injected embedding service."""
        if self.embedding_service is None:
            raise ValueError("embedding_service is required for fused_search")
        return self.embedding_service.embed_query(query)

    def fused_search(self, query, method='rrf', weights=None):
        """Run both retrievers and fuse their rankings with the chosen method.

        Unknown methods fall back to reciprocal-rank fusion.
        """
        vector_results = self.vector_store.search(
            self.embed_query(query), n_results=20
        )
        keyword_results = self.keyword_searcher.search(query, n_results=20)
        if method == 'rrf':
            return self.reciprocal_rank_fusion(vector_results, keyword_results)
        elif method == 'weighted_sum':
            return self.weighted_sum_fusion(vector_results, keyword_results, weights)
        else:
            return self.reciprocal_rank_fusion(vector_results, keyword_results)

    def reciprocal_rank_fusion(self, vector_results, keyword_results, k=60):
        """Reciprocal-rank fusion: score(doc) = sum over lists of 1/(k + rank + 1).

        Returns the ids of the 10 best-fused documents.
        """
        fused_scores = {}
        for rank, doc in enumerate(vector_results):
            fused_scores[doc['id']] = fused_scores.get(doc['id'], 0) + 1.0 / (k + rank + 1)
        for rank, doc in enumerate(keyword_results):
            fused_scores[doc['id']] = fused_scores.get(doc['id'], 0) + 1.0 / (k + rank + 1)
        ordered = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
        return [doc_id for doc_id, _score in ordered[:10]]

    def weighted_sum_fusion(self, vector_results, keyword_results, weights=None):
        """Fuse by a weighted sum of rank-based scores 1/(rank + 1).

        BUGFIX: this method was referenced by fused_search but missing
        entirely. weights is an optional (vector_weight, keyword_weight)
        pair, defaulting to an even (0.5, 0.5) split. Returns the top-10 ids.
        """
        w_vec, w_kw = weights if weights else (0.5, 0.5)
        fused_scores = {}
        for rank, doc in enumerate(vector_results):
            fused_scores[doc['id']] = fused_scores.get(doc['id'], 0) + w_vec / (rank + 1)
        for rank, doc in enumerate(keyword_results):
            fused_scores[doc['id']] = fused_scores.get(doc['id'], 0) + w_kw / (rank + 1)
        ordered = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
        return [doc_id for doc_id, _score in ordered[:10]]
5.2 上下文优化
上下文优化专注于如何有效地组织和利用检索到的文档信息,确保生成的回答既准确又相关。通过智能的文档选择、内容截取和重要性评估,系统可以在有限的上下文窗口内最大化信息密度和回答质量:
自适应上下文选择:
class AdaptiveContextManager:
    """Selects and trims retrieved documents to fit a context-length budget."""

    def __init__(self, max_context_length=4000):
        self.max_length = max_context_length
        self.importance_scorer = ImportanceScorer()

    def select_context(self, query, documents, max_docs=5):
        """Pick the most important documents that fit within max_length.

        Documents are ranked by ImportanceScorer, then added greedily until
        either the character budget or max_docs is reached. The first
        document that overflows the budget (when something was already
        selected) may be truncated into the remaining space; it ends the scan.
        """
        ranked = sorted(
            ((doc, self.importance_scorer.score(query, doc)) for doc in documents),
            key=lambda pair: pair[1],
            reverse=True,
        )
        chosen = []
        used = 0
        for doc, _score in ranked:
            size = len(doc['content'])
            if used + size <= self.max_length and len(chosen) < max_docs:
                chosen.append(doc)
                used += size
            elif used + size > self.max_length and chosen:
                # Overflowing document: keep a truncated copy if it is worth it.
                leftover = self.max_length - used
                if leftover > 100:
                    chosen.append(self.truncate_document(doc, leftover))
                break
        return chosen

    def truncate_document(self, doc, max_length):
        """Cut a document to max_length chars, preferring sentence boundaries.

        Splits on Chinese sentence enders (。!?) and keeps whole sentences;
        falls back to a hard cut plus ellipsis when not even one sentence
        fits. Returns a shallow copy flagged with 'truncated': True.
        """
        body = doc['content']
        pieces = re.split(r'[。!?]', body)
        kept = ""
        for piece in pieces:
            if len(kept + piece) <= max_length:
                kept += piece + "。"
            else:
                break
        if not kept:
            # No full sentence fits; hard-cut with an ellipsis.
            kept = body[:max_length-3] + "..."
        return {
            **doc,
            'content': kept,
            'truncated': True
        }
class ImportanceScorer:
    """Scores a document's usefulness as context via weighted heuristics."""

    def __init__(self):
        # Weights sum to 1.0: similarity dominates, the rest are tie-breakers.
        self.weight_similarity = 0.4
        self.weight_recency = 0.2
        self.weight_length = 0.2
        self.weight_structure = 0.2

    def score(self, query, document):
        """Weighted sum of similarity, recency, length and structure signals."""
        return (
            self.weight_similarity * document.get('similarity_score', 0) +
            self.weight_recency * self.calculate_recency_score(document) +
            self.weight_length * self.calculate_length_score(document) +
            self.weight_structure * self.calculate_structure_score(document)
        )

    def calculate_recency_score(self, document):
        """Prefer documents that carry freshness metadata."""
        meta = document.get('metadata', {})
        if 'updated_date' in meta:
            # Placeholder value; a real implementation would decay by age.
            return 0.8
        if 'created_at' in meta:
            return 0.6
        return 0.3  # undated documents rank lowest

    def calculate_length_score(self, document):
        """Prefer mid-sized chunks; penalize very short or very long ones."""
        size = len(document['content'])
        if size < 100:
            return 0.3  # too short to be informative
        if size < 500:
            return 0.7  # acceptable
        if size < 1500:
            return 1.0  # ideal
        return 0.6  # too long

    def calculate_structure_score(self, document):
        """Reward titled, high-level, well-formatted sections (capped at 1.0)."""
        meta = document.get('metadata', {})
        bonus = 0.5  # base score
        if 'title' in meta:
            bonus += 0.2
        if meta.get('section_level', 999) <= 2:
            bonus += 0.2
        if meta.get('file_type', '') in ['markdown', 'pdf']:
            bonus += 0.1
        return min(bonus, 1.0)
六、RAG应用场景
6.1 企业知识库问答
企业知识库问答是RAG技术的重要应用场景,它能够帮助企业快速构建智能化的内部知识服务系统。这类系统需要处理权限控制、数据安全、多部门协作等复杂需求,为员工提供准确、及时的知识支持:
企业级RAG系统:
from enum import Enum
from typing import Dict, List, Optional
from datetime import datetime
class AccessLevel(Enum):
    """Document sensitivity levels, from least to most restricted."""
    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"
class EnterpriseRAGSystem:
    """RAG facade adding per-user access control, auditing and content redaction.

    NOTE(review): relies on UserManager / EnterpriseDocumentIndexer /
    AccessController / AuditLogger / BaseRAGSystem defined elsewhere in
    the file; `format_enterprise_sources` is not defined in this excerpt.
    """
    def __init__(self, config):
        self.config = config
        self.user_manager = UserManager()
        self.doc_indexer = EnterpriseDocumentIndexer()
        self.access_controller = AccessController()
        self.audit_logger = AuditLogger()
        self.base_rag = BaseRAGSystem(config)
    def query_with_permission(self, user_id: str, query: str,
                              access_level: AccessLevel = AccessLevel.INTERNAL) -> Dict:
        """Answer a query only if the user may read documents at access_level.

        Returns a dict with 'answer' plus either 'error' (unknown user /
        permission denied / exception) or 'sources'/'success' on the happy
        path. Every query, answer and failure is written to the audit log.
        """
        # Reject unknown users before doing any retrieval work.
        user_info = self.user_manager.get_user(user_id)
        if not user_info:
            return {
                'answer': '用户不存在',
                'error': 'user_not_found',
                'question': query
            }
        if not self.access_controller.can_access(user_info, access_level):
            return {
                'answer': '您没有权限访问该级别的信息',
                'error': 'permission_denied',
                'question': query
            }
        # Audit the incoming query regardless of outcome.
        self.audit_logger.log_query(user_id, query, access_level)
        try:
            # Retrieval is filtered by department and clearance level.
            filtered_docs = self.doc_indexer.search_with_permission(
                query, user_info['department'], access_level
            )
            if not filtered_docs:
                return {
                    'answer': '未找到相关信息或您没有访问权限',
                    'sources': [],
                    'success': True
                }
            # Redact sensitive data before it reaches the LLM prompt.
            context = self.build_secure_context(filtered_docs)
            answer = self.base_rag.generate_answer(query, context)
            # Audit the delivered answer and its sources.
            self.audit_logger.log_answer(user_id, query, answer, filtered_docs)
            return {
                'answer': answer,
                'sources': self.format_enterprise_sources(filtered_docs),
                'success': True
            }
        except Exception as e:
            self.audit_logger.log_error(user_id, query, str(e))
            return {
                'answer': '处理您的查询时遇到了问题,请联系管理员',
                'error': str(e),
                'success': False
            }
    def build_secure_context(self, documents):
        """Concatenate redacted, length-capped (500 chars each) snippets."""
        context_parts = []
        for i, doc in enumerate(documents):
            # Strip sensitive patterns before truncating to 500 chars.
            safe_content = self.sanitize_content(doc['content'])
            context_parts.append(f"文档{i+1}:{safe_content[:500]}")
        return '\n\n'.join(context_parts)
    def sanitize_content(self, content):
        """Mask credit-card numbers, SSNs and inline passwords with a placeholder.

        NOTE(review): illustrative pattern set only — real deployments
        need broader, locale-aware redaction rules.
        """
        sensitive_patterns = [
            r'\b\d{4}-\d{4}-\d{4}-\d{4}\b',  # credit-card-like numbers
            r'\b\d{3}-\d{2}-\d{4}\b',  # US SSN format
            r'password\s*[:=]\s*\w+',  # inline password assignments
        ]
        for pattern in sensitive_patterns:
            content = re.sub(pattern, '[已隐藏]', content, flags=re.IGNORECASE)
        return content
class UserManager:
    """In-memory demo user directory with clearance-level checks."""

    def __init__(self):
        # Demo fixtures; a real deployment would back this with a user store.
        self.users = {
            'user001': {
                'name': '张三',
                'department': '技术部',
                'role': 'developer',
                'access_level': AccessLevel.CONFIDENTIAL,
                'roles': ['developer']
            },
            'user002': {
                'name': '李四',
                'department': '人事部',
                'role': 'hr_manager',
                'access_level': AccessLevel.CONFIDENTIAL,
                'roles': ['hr_manager']
            }
        }

    def get_user(self, user_id: str) -> Optional[Dict]:
        """Return the user record, or None for an unknown id."""
        return self.users.get(user_id)

    def has_permission(self, user_id: str, required_level: AccessLevel) -> bool:
        """True when the user's clearance meets or exceeds required_level.

        Unknown users never have permission; unknown levels rank as 0.
        """
        user = self.get_user(user_id)
        if not user:
            return False
        # Numeric ordering of clearance levels, lowest to highest.
        ranking = {
            AccessLevel.PUBLIC: 0,
            AccessLevel.INTERNAL: 1,
            AccessLevel.CONFIDENTIAL: 2,
            AccessLevel.SECRET: 3,
        }
        granted = ranking.get(user['access_level'], 0)
        needed = ranking.get(required_level, 0)
        return granted >= needed
class AccessController:
    """Stateless clearance comparator used by the enterprise RAG facade."""

    def can_access(self, user_info: Dict, required_level: AccessLevel) -> bool:
        """True when user_info's access_level meets or exceeds required_level.

        A missing access_level falls back to PUBLIC; unknown levels rank as 0.
        """
        ranking = {
            AccessLevel.PUBLIC: 0,
            AccessLevel.INTERNAL: 1,
            AccessLevel.CONFIDENTIAL: 2,
            AccessLevel.SECRET: 3,
        }
        held = user_info.get('access_level', AccessLevel.PUBLIC)
        return ranking.get(held, 0) >= ranking.get(required_level, 0)
class AuditLogger:
    """Collects query / answer / error audit events in memory."""

    def __init__(self):
        self.logs = []

    def _record(self, entry_type, user_id, query, **extra):
        """Append one timestamped audit entry with type-specific extras."""
        entry = {
            'timestamp': datetime.now(),
            'type': entry_type,
            'user_id': user_id,
            'query': query,
        }
        entry.update(extra)
        self.logs.append(entry)

    def log_query(self, user_id: str, query: str, access_level: AccessLevel):
        """Audit an incoming query and its requested clearance level."""
        self._record('query', user_id, query, access_level=access_level.value)

    def log_answer(self, user_id: str, query: str, answer: str, sources: List[Dict]):
        """Audit a delivered answer; only its first 100 chars are kept."""
        self._record('answer', user_id, query,
                     answer=answer[:100], source_count=len(sources))

    def log_error(self, user_id: str, query: str, error: str):
        """Audit a failed query."""
        self._record('error', user_id, query, error=error)
6.2 智能客服系统
智能客服系统是RAG技术的另一个重要应用领域,它通过结合企业产品知识库和对话管理,为用户提供24/7的智能咨询服务。这类系统需要处理意图识别、情感分析、多轮对话等复杂场景,提供个性化、精准的客服体验:
客服RAG实现:
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional
class CustomerServiceRAG:
    """Intent-routed customer-service bot built on a RAG knowledge base.

    NOTE(review): several helpers referenced here (handle_general_inquiry,
    handle_unknown_intent, build_technical_context, generate_technical_suggestions,
    should_escalate, format_sources, classify_complaint_type,
    generate_complaint_resolution, create_complaint_record, get_resolution_time,
    extract_products_from_query, build_product_context,
    generate_product_recommendations, format_product_sources) and the
    EmotionAnalyzer / ResponseGenerator / ProductKnowledgeBase classes are
    not defined in this excerpt — presumably implemented elsewhere.
    """
    def __init__(self, config):
        self.config = config
        self.conversation_manager = ConversationManager()
        self.intent_classifier = IntentClassifier()
        self.emotion_analyzer = EmotionAnalyzer()
        self.response_generator = ResponseGenerator()
        self.knowledge_base = ProductKnowledgeBase()
    def handle_customer_query(self, session_id: str, user_message: str,
                              user_context: Dict = None) -> Dict:
        """Route one customer message to the handler matching its intent."""
        # Prior turns inform both intent detection and response generation.
        conversation_history = self.conversation_manager.get_history(session_id)
        intent = self.intent_classifier.classify(user_message, conversation_history)
        emotion = self.emotion_analyzer.analyze(user_message)
        if user_context is None:
            user_context = {}
        # Dispatch by detected intent.
        if intent == 'complaint':
            return self.handle_complaint(session_id, user_message, emotion, user_context)
        elif intent == 'technical_support':
            return self.handle_technical_support(session_id, user_message, conversation_history)
        elif intent == 'product_inquiry':
            return self.handle_product_inquiry(session_id, user_message, user_context)
        elif intent == 'general_inquiry':
            return self.handle_general_inquiry(session_id, user_message, conversation_history)
        else:
            return self.handle_unknown_intent(session_id, user_message)
    def handle_technical_support(self, session_id: str, query: str,
                                 history: List[Dict]) -> Dict:
        """Answer a technical question from the technical-docs knowledge base.

        Returns the response plus follow-up suggestions, an escalation flag
        and the source documents used.
        """
        # Retrieve relevant technical documentation.
        relevant_docs = self.knowledge_base.search_technical_docs(query)
        context = self.build_technical_context(query, relevant_docs, history)
        response = self.response_generator.generate_technical_response(
            query, context, history
        )
        # Persist the exchange for future context.
        self.conversation_manager.add_message(session_id, query, response, 'technical_support')
        suggestions = self.generate_technical_suggestions(query, response, relevant_docs)
        return {
            'response': response,
            'intent': 'technical_support',
            'suggestions': suggestions,
            'escalation_needed': self.should_escalate(response, relevant_docs),
            'sources': self.format_sources(relevant_docs)
        }
    def handle_complaint(self, session_id: str, complaint: str,
                         emotion: str, user_context: Dict) -> Dict:
        """Handle a complaint: empathize first, then propose a resolution.

        Also files a trackable complaint record and flags escalation for
        angry/frustrated customers.
        """
        # Lead with an empathetic acknowledgement before any resolution.
        apology_response = self.generate_empathetic_response(complaint, emotion)
        complaint_type = self.classify_complaint_type(complaint)
        # Look up relevant policies and remedies for this complaint type.
        solutions = self.knowledge_base.search_complaint_solutions(complaint_type)
        resolution = self.generate_complaint_resolution(complaint, complaint_type, solutions)
        # File a trackable complaint record.
        complaint_id = self.create_complaint_record(
            session_id, complaint, complaint_type, emotion, user_context
        )
        full_response = f"{apology_response}\n\n{resolution}"
        self.conversation_manager.add_message(session_id, complaint, full_response, 'complaint')
        return {
            'response': full_response,
            'intent': 'complaint',
            'complaint_id': complaint_id,
            'complaint_type': complaint_type,
            'escalation_needed': emotion in ['angry', 'frustrated'],
            'follow_up_required': True,
            'estimated_resolution_time': self.get_resolution_time(complaint_type)
        }
    def handle_product_inquiry(self, session_id: str, query: str,
                               user_context: Dict) -> Dict:
        """Answer a product question and suggest related products."""
        # Identify which products the message refers to.
        products = self.extract_products_from_query(query)
        product_info = []
        for product in products:
            info = self.knowledge_base.get_product_info(product)
            if info:
                product_info.append(info)
        context = self.build_product_context(query, product_info, user_context)
        response = self.response_generator.generate_product_response(
            query, context, products
        )
        # Cross-sell: recommend related products for this user.
        recommendations = self.generate_product_recommendations(products, user_context)
        self.conversation_manager.add_message(session_id, query, response, 'product_inquiry')
        return {
            'response': response,
            'intent': 'product_inquiry',
            'products_found': products,
            'recommendations': recommendations,
            'sources': self.format_product_sources(product_info)
        }
    def generate_empathetic_response(self, complaint: str, emotion: str) -> str:
        """Pick a canned empathetic opener matching the detected emotion.

        Unknown emotions fall back to the 'neutral' template.
        """
        empathetic_templates = {
            'angry': "我理解您的愤怒,给您带来这样的体验我深感抱歉。我会认真对待您的问题并尽快帮助您解决。",
            'frustrated': "我理解您的沮丧,这种情况确实很令人困扰。让我来帮助您解决这个问题。",
            'disappointed': "很抱歉让您感到失望,我们重视每一位客户的体验。请告诉我具体情况,我会尽力协助您。",
            'neutral': "感谢您的反馈,我理解您的 concerns。我会认真处理您提到的问题。"
        }
        base_response = empathetic_templates.get(emotion, empathetic_templates['neutral'])
        return base_response
class ConversationManager:
    """Per-session conversation store with a bounded message history."""

    def __init__(self):
        self.conversations = {}
        self.max_history_length = 20  # keep only the most recent N exchanges

    def get_history(self, session_id: str) -> List[Dict]:
        """Return the session's message list, creating the session on first use."""
        session = self.conversations.get(session_id)
        if session is None:
            session = {
                'session_id': session_id,
                'created_at': datetime.now(),
                'messages': []
            }
            self.conversations[session_id] = session
        return session['messages']

    def add_message(self, session_id: str, user_message: str,
                    bot_response: str, intent: str):
        """Append one user/bot exchange, trimming history beyond the cap."""
        # get_history lazily creates the session when needed.
        messages = self.get_history(session_id)
        messages.append({
            'timestamp': datetime.now(),
            'user_message': user_message,
            'bot_response': bot_response,
            'intent': intent
        })
        if len(messages) > self.max_history_length:
            self.conversations[session_id]['messages'] = \
                messages[-self.max_history_length:]

    def get_conversation_summary(self, session_id: str) -> Dict:
        """Summarize session duration, volume, intents and last activity.

        Unknown sessions yield an empty dict.
        """
        if session_id not in self.conversations:
            return {}
        session = self.conversations[session_id]
        messages = session['messages']
        return {
            'session_id': session_id,
            'duration': (datetime.now() - session['created_at']).total_seconds(),
            'message_count': len(messages),
            'intents': list(set(msg['intent'] for msg in messages)),
            'last_activity': messages[-1]['timestamp'] if messages else None
        }
class IntentClassifier:
    """Keyword-driven intent detection with conversational carry-over."""

    def __init__(self):
        # First matching category wins, in this declaration order.
        self.intent_patterns = {
            'complaint': ['投诉', '抱怨', '不满', '问题', '错误', '故障', '糟糕'],
            'technical_support': ['技术', '系统', '软件', '硬件', '网络', '连接', '设置'],
            'product_inquiry': ['产品', '价格', '功能', '规格', '购买', '订单'],
            'general_inquiry': ['咨询', '询问', '了解', '信息', '帮助'],
            'billing': ['账单', '费用', '付款', '退款', '发票'],
            'shipping': ['配送', '快递', '物流', '发货', '收货']
        }

    def classify(self, message: str, history: List[Dict] = None) -> str:
        """Map a message to an intent label.

        Tries keyword matching first; otherwise carries over an ongoing
        complaint / technical-support thread from history; defaults to
        'general_inquiry'.
        """
        lowered = message.lower()
        for intent, keywords in self.intent_patterns.items():
            if any(keyword in lowered for keyword in keywords):
                return intent
        if history:
            # No keyword hit: stick with an in-progress support thread.
            previous = history[-1]['intent']
            if previous in ['complaint', 'technical_support']:
                return previous
        return 'general_inquiry'
七、大模型开发实战
7.1 Deepseek模型集成
Deepseek是一个优秀的国产大语言模型,具有强大的中文理解能力和合理的价格优势。本节介绍如何将Deepseek模型集成到RAG系统中,包括API调用封装、提示管理、流式处理等完整功能,以及成本控制和性能优化策略:
Deepseek API封装:
import openai
import time
from typing import List, Dict, Optional, Generator
import json
class DeepSeekLLMService:
    """Thin wrapper around the Deepseek OpenAI-compatible chat API.

    Tracks request count, token usage and estimated cost across calls.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.deepseek.com"):
        self.client = openai.OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        self.model = "deepseek-chat"
        # Cumulative usage counters; see get_stats()/reset_stats().
        self.request_count = 0
        self.total_tokens = 0
        self.total_cost = 0.0

    def generate_response(self, prompt: str,
                          temperature: float = 0.7,
                          max_tokens: int = 1000,
                          top_p: float = 0.9) -> str:
        """Single-turn completion for a plain prompt.

        Raises Exception (chained to the original error) when the API fails.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p
            )
        except Exception as e:
            # BUGFIX: chain the original exception instead of discarding it.
            raise Exception(f"Deepseek API调用失败: {e}") from e
        self._record_usage(response.usage)
        return response.choices[0].message.content

    def chat_with_history(self, messages: List[Dict],
                          temperature: float = 0.7,
                          max_tokens: int = 1000) -> str:
        """Multi-turn completion over an explicit message list."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            raise Exception(f"Deepseek API调用失败: {e}") from e
        self._record_usage(response.usage)
        return response.choices[0].message.content

    def _record_usage(self, usage):
        """Fold one response's usage object into the cumulative counters."""
        self.request_count += 1
        self.total_tokens += usage.total_tokens
        self.total_cost += self.calculate_cost(usage)

    def stream_response(self, prompt: str,
                        temperature: float = 0.7,
                        max_tokens: int = 1000) -> Generator[str, None, None]:
        """Yield the completion incrementally as content chunks arrive.

        Streaming responses carry no usage data, so token stats from this
        path are a rough word-count estimate.
        """
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )
            full_content = ""
            for chunk in stream:
                # BUGFIX: some stream chunks carry no choices; skip them
                # instead of raising IndexError.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    full_content += content
                    yield content
            self.request_count += 1
            # BUGFIX: cast the estimate to int so total_tokens does not
            # silently become a float.
            self.total_tokens += int(len(full_content.split()) * 1.3)
        except Exception as e:
            raise Exception(f"Deepseek流式调用失败: {e}") from e

    def calculate_cost(self, usage) -> float:
        """Estimate USD cost from token usage (verify against official pricing)."""
        input_price = 0.14 / 1000000  # $0.14 per 1M input tokens
        output_price = 0.28 / 1000000  # $0.28 per 1M output tokens
        return (usage.prompt_tokens * input_price +
                usage.completion_tokens * output_price)

    def get_stats(self) -> Dict:
        """Snapshot of cumulative usage since the last reset."""
        return {
            'request_count': self.request_count,
            'total_tokens': self.total_tokens,
            'total_cost': self.total_cost,
            'model': self.model,
            'avg_tokens_per_request': self.total_tokens / max(self.request_count, 1)
        }

    def reset_stats(self):
        """Zero all usage counters."""
        self.request_count = 0
        self.total_tokens = 0
        self.total_cost = 0.0
class DeepSeekPromptManager:
    """Registry of named prompt templates with {placeholder} slots."""
    def __init__(self):
        self.templates = self.load_templates()
    def load_templates(self) -> Dict:
        """Return the built-in templates: rag_qa, technical_support,
        customer_service and summarization."""
        return {
            'rag_qa': """你是一个专业的问答助手。请基于提供的上下文信息准确回答用户问题。
上下文信息:
{context}
用户问题:{question}
回答要求:
1. 紧密基于上下文信息
2. 如果上下文中没有相关信息,请明确说明
3. 回答简洁明了,条理清晰
4. 引用具体的信息来源
回答:""",
            'technical_support': """你是一个专业的技术支持助手。请基于以下技术文档和故障信息帮助用户解决问题。
技术文档:
{context}
用户问题描述:{question}
请提供:
1. 问题的可能原因分析
2. 具体的解决步骤
3. 预防措施
4. 如果问题复杂,建议的联系支持方式
技术支持回答:""",
            'customer_service': """你是一个友好的客服助手。请基于产品信息和用户查询提供专业、友好的回复。
产品信息:
{context}
用户咨询:{question}
回复要求:
1. 语气友好、专业
2. 直接回答用户问题
3. 提供有用的额外信息
4. 保持简洁易懂
客服回复:""",
            'summarization': """请对以下内容进行总结,提炼关键信息。
内容:
{content}
总结要求:
1. 提取主要观点和关键信息
2. 保持逻辑清晰
3. 控制在200字以内
总结:"""
        }
    def build_prompt(self, template_name: str, **kwargs) -> str:
        """Fill a template's placeholders; raises ValueError for unknown names."""
        if template_name not in self.templates:
            raise ValueError(f"未知的模板名称: {template_name}")
        template = self.templates[template_name]
        return template.format(**kwargs)
    def add_template(self, name: str, template: str):
        """Register (or overwrite) a custom template."""
        self.templates[name] = template
    def list_templates(self) -> List[str]:
        """Names of all registered templates."""
        return list(self.templates.keys())
7.2 模型性能优化
在实际生产环境中,模型性能优化是确保系统稳定运行的关键。通过智能缓存、速率限制、重试机制等技术,可以显著提升系统的响应速度和稳定性,同时有效控制API调用成本:
智能缓存系统:
import hashlib
import pickle
import time
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class CacheConfig:
    """Tuning knobs for SmartCache."""
    max_size: int = 1000  # maximum number of cached entries (LRU-evicted)
    ttl: int = 3600  # time-to-live per entry, in seconds
    persist_to_file: bool = True  # pickle the cache to disk on every write
    cache_file: str = "llm_cache.pkl"  # persistence location
class SmartCache:
    """TTL + LRU response cache for LLM calls, with optional pickle persistence."""

    def __init__(self, config: CacheConfig):
        self.config = config
        self.cache = {}  # key -> {'response', 'timestamp'}
        self.access_times = {}  # key -> last-access time, drives LRU eviction
        self.hit_count = 0
        self.miss_count = 0
        # Warm the cache from disk when persistence is enabled.
        if config.persist_to_file:
            self.load_cache()

    def get(self, prompt: str, model_params: Dict = None) -> Optional[str]:
        """Return the cached response for (prompt, params), or None.

        Expired entries are evicted lazily on lookup and count as misses.
        """
        key = self.generate_cache_key(prompt, model_params)
        entry = self.cache.get(key)
        if entry is not None:
            if time.time() - entry['timestamp'] < self.config.ttl:
                self.access_times[key] = time.time()
                self.hit_count += 1
                return entry['response']
            # Stale: drop it and fall through to a miss.
            del self.cache[key]
            self.access_times.pop(key, None)
        self.miss_count += 1
        return None

    def put(self, prompt: str, response: str, model_params: Dict = None):
        """Insert a response, evicting the LRU entry when the cache is full."""
        key = self.generate_cache_key(prompt, model_params)
        if len(self.cache) >= self.config.max_size:
            self.evict_lru()
        self.cache[key] = {
            'response': response,
            'timestamp': time.time()
        }
        self.access_times[key] = time.time()
        if self.config.persist_to_file:
            self.save_cache()

    def generate_cache_key(self, prompt: str, model_params: Dict = None) -> str:
        """Stable MD5 key over the prompt plus sorted model parameters."""
        payload = json.dumps(
            {'prompt': prompt, 'params': model_params or {}},
            sort_keys=True,
        )
        return hashlib.md5(payload.encode()).hexdigest()

    def evict_lru(self):
        """Remove the least-recently-used entry (no-op when empty)."""
        if not self.access_times:
            return
        victim = min(self.access_times, key=self.access_times.get)
        del self.cache[victim]
        del self.access_times[victim]

    def save_cache(self):
        """Best-effort pickle of the cache dict to disk."""
        try:
            with open(self.config.cache_file, 'wb') as f:
                pickle.dump(self.cache, f)
        except Exception as e:
            print(f"保存缓存失败: {e}")

    def load_cache(self):
        """Load persisted entries, silently skipping any that have expired.

        NOTE: pickle is only safe for trusted, locally produced cache files.
        """
        try:
            with open(self.config.cache_file, 'rb') as f:
                persisted = pickle.load(f)
            now = time.time()
            for key, entry in persisted.items():
                if now - entry['timestamp'] < self.config.ttl:
                    self.cache[key] = entry
                    self.access_times[key] = entry['timestamp']
            print(f"从文件加载了 {len(self.cache)} 个有效缓存项")
        except FileNotFoundError:
            print("缓存文件不存在,将创建新缓存")
        except Exception as e:
            print(f"加载缓存失败: {e}")

    def get_stats(self) -> Dict:
        """Hit/miss counters plus current occupancy."""
        total = self.hit_count + self.miss_count
        return {
            'cache_size': len(self.cache),
            'hit_count': self.hit_count,
            'miss_count': self.miss_count,
            'hit_rate': self.hit_count / total if total > 0 else 0,
            'max_size': self.config.max_size
        }

    def clear(self):
        """Drop all entries and counters (and persist the empty state)."""
        self.cache.clear()
        self.access_times.clear()
        self.hit_count = 0
        self.miss_count = 0
        if self.config.persist_to_file:
            self.save_cache()
class OptimizedLLMService:
    """Adds caching, rate limiting and prompt fallback on top of the LLM client."""

    def __init__(self, llm_service: DeepSeekLLMService, cache_config: CacheConfig = None):
        self.llm_service = llm_service
        self.cache = SmartCache(cache_config or CacheConfig())
        self.rate_limiter = RateLimiter()

    def generate_with_cache(self, prompt: str, **kwargs) -> str:
        """Serve from cache when possible; otherwise rate-limit, call, store."""
        hit = self.cache.get(prompt, kwargs)
        if hit:
            return hit
        # Cache miss: respect the rate limit before touching the API.
        self.rate_limiter.wait_if_needed()
        answer = self.llm_service.generate_response(prompt, **kwargs)
        self.cache.put(prompt, answer, kwargs)
        return answer

    def generate_with_fallback(self, prompt: str,
                               fallback_prompts: List[str] = None,
                               **kwargs) -> str:
        """Try the main prompt, then each fallback, then a canned apology."""
        try:
            return self.generate_with_cache(prompt, **kwargs)
        except Exception as e:
            print(f"主提示失败: {e}")
        for alternative in fallback_prompts or []:
            try:
                return self.generate_with_cache(alternative, **kwargs)
            except Exception as fallback_error:
                print(f"备用提示失败: {fallback_error}")
        # Every prompt failed: degrade gracefully.
        return "抱歉,服务暂时不可用,请稍后再试。"

    def get_combined_stats(self) -> Dict:
        """Aggregate LLM, cache and rate-limiter statistics."""
        return {
            'llm_stats': self.llm_service.get_stats(),
            'cache_stats': self.cache.get_stats(),
            'rate_limiter_stats': self.rate_limiter.get_stats()
        }
class RateLimiter:
    """Sliding-window limiter: at most N requests per rolling 60 seconds."""

    def __init__(self, max_requests_per_minute: int = 60):
        self.max_requests = max_requests_per_minute
        self.requests = []  # timestamps of requests in the current window

    def wait_if_needed(self):
        """Block until a request slot frees up, then record this request."""
        now = time.time()
        # Keep only timestamps from the last 60 seconds.
        self.requests = [stamp for stamp in self.requests if now - stamp < 60]
        if len(self.requests) >= self.max_requests:
            # Sleep until the oldest request ages out of the window.
            pause = 60 - (now - min(self.requests))
            if pause > 0:
                print(f"达到速率限制,等待 {pause:.1f} 秒")
                time.sleep(pause)
        self.requests.append(now)

    def get_stats(self) -> Dict:
        """Current window occupancy and remaining quota."""
        now = time.time()
        in_window = [stamp for stamp in self.requests if now - stamp < 60]
        return {
            'max_requests_per_minute': self.max_requests,
            'recent_requests_count': len(in_window),
            'remaining_requests': self.max_requests - len(in_window)
        }
八、完整项目案例
8.1 企业级RAG问答系统
系统架构和实现:
import asyncio
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Optional, Dict, Any

import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
# 请求和响应模型
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    question: str  # the user's natural-language question
    use_hybrid: bool = True  # combine vector + keyword retrieval
    return_sources: bool = True  # include citation info in the response
    top_k: Optional[int] = None  # per-request override of the configured top_k
class QueryResponse(BaseModel):
    """Response body for POST /query."""
    answer: str  # generated answer text
    question: str  # echoed question
    retrieved_docs_count: int  # documents returned by retrieval
    used_docs_count: int  # documents that made it into the context
    sources: Optional[List[Dict]] = None  # citation records, when requested
    processing_time: float  # wall-clock seconds for the request
    error: Optional[str] = None  # populated when the pipeline reported an error
class BuildKnowledgeBaseRequest(BaseModel):
    """Request body for POST /build-knowledge-base."""
    data_path: str  # filesystem path of the documents to index
class BuildKnowledgeBaseResponse(BaseModel):
    """Response body for POST /build-knowledge-base."""
    message: str  # human-readable status
    document_count: int  # 0 here: the build runs as a background task
    processing_time: float  # time to schedule (not complete) the task
# 全局变量
rag_system = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the RAG system on startup, clean up on shutdown.

    NOTE(review): RAGConfig / BasicRAGSystem are defined elsewhere in the file;
    the component wiring below is intentionally left as a commented template.
    """
    global rag_system
    # Startup: construct the system from a fixed configuration.
    config = RAGConfig(
        embedding_model="shibing624/text2vec-base-chinese",
        chunk_size=1000,
        top_k=5,
        use_hybrid_search=True
    )
    rag_system = BasicRAGSystem(config)
    # Component initialization template (fill in real instances to enable):
    # embedding_service = EmbeddingService(...)
    # vector_store = ChromaVectorStore(...)
    # llm_service = DeepSeekLLMService(...)
    # document_chunker = DocumentChunker(...)
    # rag_system.initialize(embedding_service, vector_store, llm_service, document_chunker)
    yield
    # Shutdown: release resources held by the RAG system.
    if rag_system:
        rag_system.cleanup()
# FastAPI application; the lifespan hook handles RAG system startup/shutdown.
app = FastAPI(
    title="RAG问答系统",
    # BUGFIX: the description claimed "SpringAI", but this service is built
    # on FastAPI + Deepseek.
    description="基于FastAPI和Deepseek的检索增强生成系统",
    version="1.0.0",
    lifespan=lifespan
)
@app.get("/")
async def root():
    """Service banner: version and whether the RAG system is initialized."""
    return {
        "message": "RAG问答系统API",
        "version": "1.0.0",
        "status": "running",
        "rag_system": rag_system is not None
    }
@app.get("/health")
async def health_check():
    """Liveness probe: always healthy, also reports RAG initialization state."""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "rag_system": rag_system is not None
    }
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Answer a question via the RAG pipeline.

    Optionally overrides the configured top_k for this request only.
    Raises 503 when the system is not initialized and 500 on any
    processing failure.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    try:
        start_time = time.time()
        # Apply the per-request top_k override to the shared config.
        original_top_k = rag_system.config.top_k
        if request.top_k:
            rag_system.config.top_k = request.top_k
        try:
            result = rag_system.query(
                request.question,
                use_hybrid=request.use_hybrid,
                return_sources=request.return_sources
            )
        finally:
            # BUGFIX: always restore the configured top_k — previously it
            # was only restored on success, so a failing query permanently
            # leaked the per-request override into the shared config.
            rag_system.config.top_k = original_top_k
        processing_time = time.time() - start_time
        return QueryResponse(
            answer=result.get('answer', ''),
            question=result.get('question', request.question),
            retrieved_docs_count=result.get('retrieved_docs_count', 0),
            used_docs_count=result.get('used_docs_count', 0),
            sources=result.get('sources'),
            processing_time=processing_time,
            error=result.get('error')
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"查询处理失败: {str(e)}")
@app.post("/build-knowledge-base", response_model=BuildKnowledgeBaseResponse)
async def build_knowledge_base(request: BuildKnowledgeBaseRequest,
                               background_tasks: BackgroundTasks):
    """Kick off knowledge-base construction as a FastAPI background task.

    NOTE(review): `Path` is used here but `from pathlib import Path` does
    not appear in this file's import block — verify it is imported.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    # Validate the path up front so the client gets an immediate 400.
    data_path = Path(request.data_path)
    if not data_path.exists():
        raise HTTPException(status_code=400, detail=f"数据路径不存在: {request.data_path}")
    try:
        start_time = time.time()
        # Indexing runs after the response is sent, so the reported
        # processing_time only covers task scheduling.
        background_tasks.add_task(rag_system.build_knowledge_base, request.data_path)
        processing_time = time.time() - start_time
        return BuildKnowledgeBaseResponse(
            message="知识库构建任务已启动,请稍后查看结果",
            document_count=0,  # unknown until the background task finishes
            processing_time=processing_time
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"构建任务启动失败: {str(e)}")
@app.get("/stats")
async def get_system_stats():
    """Return aggregate statistics collected from the RAG system."""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    try:
        snapshot = rag_system.get_system_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
    return {
        "status": "success",
        "stats": snapshot,
        "timestamp": time.time()
    }
@app.get("/llm-stats")
async def get_llm_stats():
    """Return usage statistics from the underlying LLM service."""
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG系统未初始化")
    try:
        usage = rag_system.llm_service.get_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取LLM统计失败: {str(e)}")
    return {
        "status": "success",
        "llm_stats": usage,
        "timestamp": time.time()
    }
if __name__ == "__main__":
    # Run the FastAPI app with uvicorn, listening on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
8.2 项目配置文件
# config.yaml — configuration for the RAG service.
# NOTE(review): the original snippet lost all indentation during extraction
# (invalid YAML); the nesting below is reconstructed from the key names.
rag:
  embedding:
    model_name: "shibing624/text2vec-base-chinese"  # Chinese text2vec embedding model
    cache_enabled: true
    batch_size: 32
  vector_store:
    type: "chroma"
    persist_directory: "./data/vector_db"
    collection_name: "documents"
  llm:
    provider: "deepseek"
    model_name: "deepseek-chat"
    api_key: "${DEEPSEEK_API_KEY}"  # read from the environment; do not commit keys
    base_url: "https://api.deepseek.com"
    temperature: 0.7
    max_tokens: 1000
  chunking:
    chunk_size: 1000
    chunk_overlap: 200
    use_semantic_chunking: true
  retrieval:
    top_k: 5
    similarity_threshold: 0.7
    use_hybrid_search: true
    enable_reranking: false
  caching:
    enabled: true
    max_size: 1000
    ttl: 3600  # seconds
  rate_limiting:
    max_requests_per_minute: 60
logging:
  level: "INFO"
  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  file: "./logs/rag_system.log"
server:
  host: "0.0.0.0"
  port: 8000
  workers: 1
security:
  enable_auth: false
  secret_key: "your-secret-key-here"  # placeholder — replace before enabling auth
8.3 部署脚本
#!/bin/bash
# deploy.sh — set up directories, install dependencies, and launch the RAG service.
# Bug fixes vs. the original:
#   * the completion echoes came AFTER uvicorn, which runs in the foreground,
#     so they never printed until the server exited — moved before the launch;
#   * the hardcoded `export DEEPSEEK_API_KEY=...` clobbered any real key already
#     in the environment — now only used as a fallback;
#   * added `set -euo pipefail` so a failed step aborts the deployment.
set -euo pipefail

echo "开始部署RAG系统..."

# Create runtime directories.
mkdir -p data/vector_db
mkdir -p logs
mkdir -p cache

# Install Python dependencies.
pip install -r requirements.txt

# Download the Chinese spaCy model (best effort; `|| echo` keeps -e from aborting).
python -c "import spacy; spacy.cli.download('zh_core_web_sm')" || echo "Spacy模型下载失败,请手动下载"

# Use the API key from the environment if present; otherwise fall back to a placeholder.
export DEEPSEEK_API_KEY="${DEEPSEEK_API_KEY:-your-deepseek-api-key}"

echo "RAG系统部署完成!"
echo "访问 http://localhost:8000/docs 查看API文档"

# Start the server last — this call blocks until the server stops.
python -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
8.4 依赖文件(requirements.txt)
# Web API stack
fastapi==0.104.1
uvicorn==0.24.0
pydantic==2.5.0
# Embeddings and vector database
sentence-transformers==2.2.2
chromadb==0.4.18
# OpenAI-compatible LLM client (used here for the DeepSeek API)
openai==1.3.7
# Form/file upload support for FastAPI endpoints
python-multipart==0.0.6
# NLP / tokenization (Chinese segmentation, general NLP)
jieba==0.42.1
nltk==3.8.1
spacy==3.7.2
# Document parsing (PDF, Word)
PyPDF2==3.0.1
python-docx==0.8.11
# Config parsing, HTTP, numerics
PyYAML==6.0.1
requests==2.31.0
numpy==1.24.3
# Deep learning stack backing the embedding models
torch==2.1.1
transformers==4.36.0
更多推荐


所有评论(0)