阿里开源Qwen3模型实战:Qwen3-RAG轻量级系统
目录:
1. 配置模块 (config.py)
2. 文档处理模块 (document_processor.py)
3. 向量检索模块 (vector_retriever.py)
4. 重排序模块 (reranker.py)
5. 答案生成模块 (answer_generator.py)
6. 主应用程序 (main.py)
7. 测试脚本 (test_rag.py)
8. API服务 (api_server.py)
9. 配置文件
10. 使用说明
11. 项目特点
12. 故障排除
Qwen3-RAG轻量级系统
项目结构
qwen-rag-system/
├── config.py # 配置文件
├── main.py # 主应用程序
├── test_rag.py # 系统测试脚本
├── api_server.py # FastAPI服务
├── requirements.txt # 依赖包列表
├── .env # 环境配置
├── README.md # 项目说明
├── documents/ # 存放PDF/TXT/MD文档
├── chroma_db/ # 向量数据库
├── document_processor.py # 文档处理模块
├── vector_retriever.py # 向量检索模块
├── reranker.py # 重排序模块
└── answer_generator.py # 答案生成模块
1. 配置模块 (config.py)
import os
from dotenv import load_dotenv
load_dotenv()
class Config:
    """Central, environment-driven configuration for the RAG system.

    Every setting is an UPPER_CASE class attribute read once from the process
    environment, with a hard-coded fallback. UPPER_CASE naming is also what
    ``print_config`` uses to tell settings apart from methods.
    """

    # Model configuration (Ollama model tags)
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "dengcao/Qwen3-Embedding-0.6B:Q4_K_M")
    RERANKER_MODEL = os.getenv("RERANKER_MODEL", "dengcao/Qwen3-Reranker-0.6B:Q4_K_M")
    LLM_MODEL = os.getenv("LLM_MODEL", "qwen2.5:3b")

    # Vector database configuration
    CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
    COLLECTION_NAME = os.getenv("COLLECTION_NAME", "knowledge_base")

    # Text processing / retrieval configuration
    CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "500"))
    CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "50"))
    TOP_K_RETRIEVAL = int(os.getenv("TOP_K_RETRIEVAL", "10"))
    TOP_N_RERANK = int(os.getenv("TOP_N_RERANK", "3"))

    # Ollama API configuration
    OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

    # File paths
    DOCUMENT_PATH = os.getenv("DOCUMENT_PATH", "./documents")

    @classmethod
    def validate(cls):
        """Create the document directory and the vector-store directory if they are missing."""
        for path, label in (
            (cls.DOCUMENT_PATH, "创建文档目录"),
            (cls.CHROMA_PERSIST_DIR, "创建向量数据库目录"),
        ):
            if not os.path.exists(path):
                # exist_ok tolerates another process creating it between the check and the call.
                os.makedirs(path, exist_ok=True)
                print(f"{label}: {path}")

    @classmethod
    def print_config(cls):
        """Print every configuration setting (the UPPER_CASE class attributes)."""
        print("=" * 50)
        print("系统配置信息:")
        print("=" * 50)
        for key, value in vars(cls).items():
            # Filter on naming rather than callable(): classmethod objects stored in
            # the class __dict__ are not callable, so they would otherwise be listed.
            if key.isupper():
                print(f"{key}: {value}")
        print("=" * 50)
def main():
    """Stand-alone smoke test for the configuration module.

    Ensures the working directories exist, dumps the configuration and echoes
    two environment-derived values. Always returns True.
    """
    print("测试配置模块...")
    Config.validate()
    Config.print_config()

    # Echo a couple of values that come from the environment / .env file
    print("\n环境变量测试:")
    checks = (
        ("Ollama URL", Config.OLLAMA_BASE_URL),
        ("文档路径", Config.DOCUMENT_PATH),
    )
    for label, value in checks:
        print(f"{label}: {value}")
    return True


if __name__ == "__main__":
    main()
2. 文档处理模块 (document_processor.py)
import os
import json
from typing import List, Dict, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
PyPDFLoader,
TextLoader,
UnstructuredMarkdownLoader
)
class DocumentProcessor:
    """Load PDF / TXT / Markdown files from a directory and split them into chunks."""

    # Split points, coarsest first. Full-width (Chinese) punctuation is listed next
    # to its ASCII counterpart so CJK text is split at sentence/clause boundaries
    # instead of falling through to a space or an arbitrary character cut.
    DEFAULT_SEPARATORS = ["\n\n", "\n", "。", "！", "!", "？", "?", "；", ";", "，", ",", " ", ""]

    def __init__(self, chunk_size=500, chunk_overlap=50, separators=None):
        """
        Args:
            chunk_size: maximum chunk length in characters (``len``).
            chunk_overlap: characters shared between consecutive chunks.
            separators: optional custom separator list; defaults to DEFAULT_SEPARATORS.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=list(separators) if separators is not None else list(self.DEFAULT_SEPARATORS),
        )

    def load_documents(self, directory_path: str) -> List[Dict[str, Any]]:
        """Load every supported file directly inside ``directory_path``.

        Returns:
            A list of ``{'content': str, 'metadata': {'source': filename, 'page': int}}``
            entries (one per page for PDFs). Files that fail to load are reported
            and skipped.
        """
        documents = []
        if not os.path.exists(directory_path):
            print(f"目录不存在: {directory_path}")
            return documents

        # sorted(): os.listdir order is arbitrary; sorting makes chunk order reproducible.
        for filename in sorted(os.listdir(directory_path)):
            file_path = os.path.join(directory_path, filename)
            if not os.path.isfile(file_path):
                continue  # sub-directories are not traversed
            # Case-insensitive so that e.g. "REPORT.PDF" is not skipped.
            extension = os.path.splitext(filename)[1].lower()
            try:
                if extension == '.pdf':
                    loader = PyPDFLoader(file_path)
                elif extension == '.txt':
                    loader = TextLoader(file_path, encoding='utf-8')
                elif extension == '.md':
                    loader = UnstructuredMarkdownLoader(file_path)
                else:
                    print(f"跳过不支持的文件类型: {filename}")
                    continue
                for doc in loader.load():
                    documents.append({
                        'content': doc.page_content,
                        'metadata': {
                            'source': filename,
                            'page': doc.metadata.get('page', 0)
                        }
                    })
                print(f"✓ 已加载: {filename}")
            except Exception as e:
                print(f"✗ 加载文件 {filename} 时出错: {str(e)}")
        print(f"总计加载 {len(documents)} 个文档片段")
        return documents

    def chunk_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Split each loaded document into chunks, keeping its metadata.

        Each chunk gets an ``id`` of the form ``{source}_p{page}_chunk_{i}``. The
        page is part of the id because the chunk index restarts at 0 for every
        page of a PDF, so ``{source}_chunk_{i}`` alone is not unique.
        """
        chunked_docs = []
        for doc in documents:
            metadata = doc['metadata']
            chunks = self.text_splitter.split_text(doc['content'])
            for i, chunk in enumerate(chunks):
                chunked_docs.append({
                    'id': f"{metadata['source']}_p{metadata.get('page', 0)}_chunk_{i}",
                    'content': chunk,
                    'metadata': {
                        **metadata,
                        'chunk_index': i,
                        'total_chunks': len(chunks)
                    }
                })
        print(f"✓ 文档分块完成,共 {len(chunked_docs)} 个块")
        return chunked_docs

    def save_chunks_to_json(self, chunks: List[Dict[str, Any]], output_path: str):
        """Write ``chunks`` as pretty-printed UTF-8 JSON, creating the parent directory if needed."""
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)
        print(f"✓ 分块结果已保存到: {output_path}")

    def process_directory(self, input_dir: str, output_json: str = None) -> List[Dict[str, Any]]:
        """Load and chunk every document in ``input_dir``.

        Args:
            input_dir: directory containing the source documents.
            output_json: if given, the chunks are also saved to this JSON file.

        Returns:
            The list of chunks (empty if no document could be loaded).
        """
        print(f"开始处理目录: {input_dir}")
        documents = self.load_documents(input_dir)
        if not documents:
            print("没有找到可处理的文档")
            return []
        chunks = self.chunk_documents(documents)
        if output_json:
            self.save_chunks_to_json(chunks, output_json)
        return chunks
def main():
    """Stand-alone smoke test for the document processing module.

    Writes a small UTF-8 text file into a throw-away temporary directory,
    processes it, and prints a preview of the resulting chunks. Everything is
    created under ``tempfile.TemporaryDirectory``, so a user's own
    ``./test_documents`` directory or ``./test_chunks.json`` file is never
    touched. Cleanup also happens if processing raises.

    Returns:
        True if at least one chunk was produced.
    """
    import tempfile

    print("测试文档处理模块...")
    test_content = (
        "这是一个测试文档。\n"
        "用于测试文档处理模块的功能。\n"
        "文档处理包括加载、分块等步骤。\n"
        "RAG系统需要将文档分块以便后续处理。"
    )

    with tempfile.TemporaryDirectory() as work_dir:
        test_dir = os.path.join(work_dir, "test_documents")
        os.makedirs(test_dir)
        with open(os.path.join(test_dir, "test.txt"), "w", encoding="utf-8") as f:
            f.write(test_content)

        processor = DocumentProcessor(chunk_size=100, chunk_overlap=20)
        chunks = processor.process_directory(test_dir, os.path.join(work_dir, "test_chunks.json"))

        if chunks:
            print("\n处理结果示例:")
            for i, chunk in enumerate(chunks[:3]):  # show the first 3 chunks
                print(f"块 {i+1}: {chunk['content'][:50]}...")
            print(f"总共处理了 {len(chunks)} 个文档块")

    print("\n文档处理模块测试完成!")
    return len(chunks) > 0


if __name__ == "__main__":
    main()
3. 向量检索模块 (vector_retriever.py)
import requests
import numpy as np
from typing import List, Dict, Any
import chromadb
from chromadb.config import Settings
import hashlib
import time
class QwenEmbeddingClient:
    """Minimal client for Ollama's ``/api/embeddings`` endpoint."""

    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
        """
        Args:
            model_name: Ollama tag of the embedding model.
            base_url: Ollama server URL (a trailing slash is tolerated).
        """
        self.model_name = model_name
        self.base_url = base_url
        self.embedding_url = f"{base_url.rstrip('/')}/api/embeddings"

    def get_embedding(self, text: str, instruction: str = None) -> List[float]:
        """Return the embedding of ``text``, or ``[]`` on any failure.

        Args:
            text: text to embed.
            instruction: optional task instruction prepended to the text.
        """
        prompt = f"指令:{instruction}\n文本:{text}" if instruction else text
        # /api/embeddings reads the text from "prompt". An "input" field is only
        # understood by the newer /api/embed endpoint, and the text sent under
        # "input" here would not be embedded.
        payload = {
            "model": self.model_name,
            "prompt": prompt
        }
        try:
            response = requests.post(self.embedding_url, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            embedding = result.get("embedding")
            if not embedding:
                # Also accept the /api/embed response shape: {"embeddings": [[...]]}.
                embedding = (result.get("embeddings") or [[]])[0]
            return embedding or []
        except requests.exceptions.RequestException as e:
            print(f"获取嵌入向量时网络错误: {str(e)}")
            return []
        except Exception as e:
            print(f"获取嵌入向量时出错: {str(e)}")
            return []

    def get_embeddings_batch(self, texts: List[str], batch_size: int = 5,
                             placeholder_dim: int = 1024) -> List[List[float]]:
        """Embed ``texts`` one request at a time, reporting progress per batch.

        The result is aligned with ``texts``. A text whose embedding fails gets an
        all-zero placeholder with the same dimension as the successful
        embeddings. ``placeholder_dim`` is used only if every request failed.

        Args:
            texts: texts to embed.
            batch_size: how many texts to process between progress reports/pauses.
            placeholder_dim: fallback dimension for zero placeholders.
        """
        embeddings = []
        print(f"开始批量处理 {len(texts)} 个文本...")
        for i in range(0, len(texts), batch_size):
            for text in texts[i:i + batch_size]:
                embeddings.append(self.get_embedding(text))
            processed = min(i + batch_size, len(texts))
            print(f"进度: {processed}/{len(texts)}")
            time.sleep(0.1)  # avoid hammering the server

        # Take the placeholder dimension from real output so every vector can go
        # into the same collection even when the model is not 1024-dimensional.
        dim = next((len(e) for e in embeddings if e), placeholder_dim)
        return [e if e else [0.0] * dim for e in embeddings]
class VectorRetriever:
    """Dense retriever backed by a persistent ChromaDB collection (cosine space)."""

    def __init__(self, config):
        """
        Args:
            config: object providing EMBEDDING_MODEL, OLLAMA_BASE_URL,
                CHROMA_PERSIST_DIR and COLLECTION_NAME (e.g. Config).
        """
        self.config = config
        self.embedding_client = QwenEmbeddingClient(
            model_name=config.EMBEDDING_MODEL,
            base_url=config.OLLAMA_BASE_URL
        )
        self.client = chromadb.PersistentClient(
            path=config.CHROMA_PERSIST_DIR,
            settings=Settings(anonymized_telemetry=False)
        )
        self.collection = self.client.get_or_create_collection(
            name=config.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"}
        )

    def generate_doc_id(self, content: str, metadata: Dict) -> str:
        """Build a stable id from source, chunk index and an MD5 prefix of the content."""
        content_hash = hashlib.md5(content.encode()).hexdigest()
        source = metadata.get('source', 'unknown')
        chunk_idx = metadata.get('chunk_index', 0)
        return f"{source}_chunk{chunk_idx}_{content_hash[:8]}"

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Embed ``documents`` and upsert them into the collection.

        Documents that map to an id already seen in this call are skipped, and so
        are documents whose embedding failed (all-zero placeholder vector).

        Returns:
            True if at least one document was written, False otherwise.
        """
        if not documents:
            print("没有文档可添加")
            return False
        print(f"开始添加 {len(documents)} 个文档到向量数据库...")

        ids, contents, metadatas = [], [], []
        seen_ids = set()
        for doc in documents:
            doc_id = self.generate_doc_id(doc['content'], doc['metadata'])
            if doc_id in seen_ids:
                # Same source, chunk index and content: Chroma rejects repeated ids in one call.
                continue
            seen_ids.add(doc_id)
            ids.append(doc_id)
            contents.append(doc['content'])
            metadatas.append(doc['metadata'])

        print("生成嵌入向量...")
        embeddings = self.embedding_client.get_embeddings_batch(contents)
        if not embeddings:
            print("无法生成嵌入向量")
            return False

        # An all-zero vector marks a failed embedding; cosine distance to it is undefined.
        rows = [row for row in zip(ids, contents, metadatas, embeddings) if any(row[3])]
        skipped = len(ids) - len(rows)
        if skipped:
            print(f"警告: {skipped} 个文档的嵌入向量生成失败,已跳过")
        if not rows:
            print("无法生成嵌入向量")
            return False
        ids, contents, metadatas, embeddings = (list(column) for column in zip(*rows))

        try:
            # upsert makes re-indexing (e.g. a forced reload) idempotent: existing ids
            # are updated in place instead of being passed to add() again.
            self.collection.upsert(
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
                ids=ids
            )
            print(f"✓ 已成功添加 {len(ids)} 个文档到向量数据库")
            return True
        except Exception as e:
            print(f"✗ 添加文档时出错: {str(e)}")
            return False

    def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` nearest chunks to ``query``.

        Each hit is ``{'content', 'metadata', 'score'}`` with ``score = 1 - cosine distance``.
        Returns ``[]`` if the query cannot be embedded or the search fails.
        """
        print(f"检索查询: {query}")
        query_embedding = self.embedding_client.get_embedding(query)
        if not query_embedding:
            print("无法获取查询的嵌入向量")
            return []

        try:
            # Never ask for more neighbours than the collection holds. Depending on the
            # Chroma version that is rejected or only warned about, and an empty
            # collection has nothing to return.
            n_results = min(top_k, self.collection.count())
            retrieved_docs = []
            if n_results > 0:
                results = self.collection.query(
                    query_embeddings=[query_embedding],
                    n_results=n_results,
                    include=["documents", "metadatas", "distances"]
                )
                docs = (results.get('documents') or [[]])[0]
                metas = (results.get('metadatas') or [[]])[0]
                dists = (results.get('distances') or [[]])[0]
                retrieved_docs = [
                    {'content': content, 'metadata': meta or {}, 'score': 1 - dist}
                    for content, meta, dist in zip(docs, metas, dists)
                ]
            print(f"✓ 检索到 {len(retrieved_docs)} 个相关文档")
            return retrieved_docs
        except Exception as e:
            print(f"✗ 检索时出错: {str(e)}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return ``{'total_documents', 'collection_name'}``, or ``{}`` on error."""
        try:
            count = self.collection.count()
            return {
                "total_documents": count,
                "collection_name": self.config.COLLECTION_NAME
            }
        except Exception as e:
            print(f"获取集合统计时出错: {str(e)}")
            return {}
def main():
    """Stand-alone smoke test for the vector retrieval module.

    Prints the collection stats, runs a sample search and embeds a sample text.
    Needs a running Ollama service; the search only finds hits once documents
    have been indexed.

    Returns:
        True if the sample search returned at least one hit.
    """
    print("测试向量检索模块...")
    from config import Config
    cfg = Config()
    cfg.validate()

    retriever = VectorRetriever(cfg)
    print(f"集合状态: {retriever.get_collection_stats()}")

    query_text = "人工智能"
    hits = retriever.search(query_text, top_k=3)
    if not hits:
        print("没有检索到结果,可能需要先添加文档")
    else:
        print(f"\n查询 '{query_text}' 的检索结果:")
        for rank, hit in enumerate(hits, start=1):
            print(f"\n结果 {rank}:")
            print(f" 分数: {hit['score']:.4f}")
            print(f" 来源: {hit['metadata'].get('source', '未知')}")
            print(f" 内容: {hit['content'][:100]}...")

    print("\n测试嵌入生成...")
    vector = retriever.embedding_client.get_embedding("这是一个测试文本")
    if not vector:
        print("无法生成嵌入向量,请检查Ollama服务")
    else:
        print(f"嵌入向量维度: {len(vector)}")
        print(f"前10个值: {vector[:10]}")

    print("\n向量检索模块测试完成!")
    return bool(hits)


if __name__ == "__main__":
    main()
4. 重排序模块 (reranker.py)
import requests
import json
from typing import List, Dict, Any
import re
class QwenRerankerClient:
    """Rerank documents by asking an Ollama model to score query/document relevance (0-10)."""

    # NOTE(review): the model is prompted to emit a bare 0-10 number via
    # /api/generate; whether the configured reranker model follows that format
    # is not established here. Unparseable replies fall back to 5.0.

    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.generate_url = f"{base_url}/api/generate"

    def rerank(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Score every document against ``query`` and return copies sorted by score (desc).

        Each returned dict is the input document plus a ``rerank_score`` key.
        """
        if not documents:
            print("没有文档需要重排序")
            return []

        total = len(documents)
        print(f"开始重排序 {total} 个文档...")

        scored = []
        for position, document in enumerate(documents, start=1):
            relevance = self._get_relevance_score(
                self._build_rerank_prompt(query, document['content'])
            )
            scored.append(dict(document, rerank_score=relevance))
            print(f" 进度: {position}/{total} - 分数: {relevance:.2f}")

        # sorted() is stable, so equally scored documents keep their retrieval order.
        ranked = sorted(scored, key=lambda item: item['rerank_score'], reverse=True)
        print(f"✓ 重排序完成,最佳分数: {ranked[0]['rerank_score']:.2f}")
        return ranked

    def _build_rerank_prompt(self, query: str, document: str) -> str:
        """Build the scoring prompt; documents longer than 800 chars are truncated with '...'."""
        preview = document if len(document) <= 800 else document[:800] + "..."
        return (
            "请评估以下查询与文档的相关性,只输出一个0-10的分数,不要有任何其他文本。\n"
            f"查询:{query}\n"
            f"文档:{preview}\n"
            "相关性分数:"
        )

    def _get_relevance_score(self, prompt: str) -> float:
        """Ask the model for a score and parse the first number of its reply.

        Returns the number clamped to [0, 10], 5.0 if no number is found, and
        0.0 on timeout or any other error.
        """
        request_body = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.1, "num_predict": 10},
        }
        try:
            reply = requests.post(self.generate_url, json=request_body, timeout=30)
            reply.raise_for_status()
            text = reply.json().get("response", "").strip()
            match = re.search(r"\d+\.?\d*", text)
            if match is None:
                print(f"警告: 无法从响应中提取分数: {text}")
                return 5.0
            return min(max(float(match.group()), 0), 10)
        except requests.exceptions.Timeout:
            print("重排序请求超时")
            return 0.0
        except Exception as e:
            print(f"重排序时出错: {str(e)}")
            return 0.0
class Reranker:
    """Pipeline stage that reranks retrieved documents and keeps the best TOP_N_RERANK."""

    def __init__(self, config):
        """
        Args:
            config: object providing RERANKER_MODEL, OLLAMA_BASE_URL and TOP_N_RERANK.
        """
        self.config = config
        self.reranker_client = QwenRerankerClient(
            model_name=config.RERANKER_MODEL,
            base_url=config.OLLAMA_BASE_URL
        )

    def process(self, query: str, retrieved_docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank ``retrieved_docs`` for ``query`` and return the top ``config.TOP_N_RERANK``."""
        if not retrieved_docs:
            return []

        print(f"开始重排序,处理 {len(retrieved_docs)} 个文档...")
        ranked = self.reranker_client.rerank(query, retrieved_docs)
        keep = min(self.config.TOP_N_RERANK, len(ranked))
        final_docs = ranked[:keep]
        print(f"✓ 重排序完成,保留 {len(final_docs)} 个最相关文档")

        print("\n重排序结果:")
        for rank, doc in enumerate(final_docs, start=1):
            score = doc.get('rerank_score', 0)
            source = doc['metadata'].get('source', '未知')
            print(f"{rank}. 分数: {score:.2f}, 源: {source}")
        return final_docs
def main():
    """Stand-alone smoke test for the reranking module (requires a running Ollama service).

    Returns:
        True if reranking produced at least one document.
    """
    print("测试重排序模块...")
    from config import Config
    cfg = Config()
    cfg.validate()

    reranker = Reranker(cfg)

    query_text = "人工智能的应用领域"
    samples = [
        ('人工智能在医疗领域有广泛应用,如疾病诊断和治疗方案推荐。', 'test1.txt', 0.85),
        ('机器学习是人工智能的核心技术,包括监督学习和无监督学习。', 'test2.txt', 0.78),
        ('深度学习在计算机视觉和自然语言处理中取得显著成果。', 'test3.txt', 0.92),
        ('自动驾驶技术结合了计算机视觉、传感器融合和路径规划。', 'test4.txt', 0.65),
    ]
    sample_docs = [
        {'content': text, 'metadata': {'source': source, 'score': score}}
        for text, source, score in samples
    ]
    print(f"测试查询: {query_text}")
    print(f"测试文档数量: {len(sample_docs)}")

    ranked = reranker.process(query_text, sample_docs)
    if ranked:
        print("\n重排序结果对比:")
        print("原始顺序 vs 重排序后:")
        for rank, (before, after) in enumerate(zip(sample_docs, ranked), start=1):
            before_src = before['metadata']['source']
            after_src = after['metadata']['source']
            after_score = after.get('rerank_score', 0)
            print(f" {rank}. {before_src} (原始) -> {after_src} (分数: {after_score:.2f})")

    print("\n测试单个文档评分...")
    client = reranker.reranker_client
    single_prompt = client._build_rerank_prompt(
        "机器学习",
        "机器学习是人工智能的重要分支,让计算机从数据中学习规律。"
    )
    single_score = client._get_relevance_score(single_prompt)
    print(f"单个文档评分: {single_score:.2f}")

    print("\n重排序模块测试完成!")
    return len(ranked) > 0


if __name__ == "__main__":
    main()
5. 答案生成模块 (answer_generator.py)
import requests
from typing import List, Dict, Any
class AnswerGenerator:
    """Generate the final answer from reranked context documents via Ollama's /api/generate."""

    def __init__(self, config):
        """
        Args:
            config: object providing OLLAMA_BASE_URL and LLM_MODEL (e.g. Config).
        """
        self.config = config
        self.generate_url = f"{config.OLLAMA_BASE_URL}/api/generate"

    def build_prompt(self, query: str, context_docs: List[Dict[str, Any]]) -> str:
        """Build the LLM prompt, numbering each context chunk and naming its source."""
        if not context_docs:
            return f"问题:{query}\n\n请回答这个问题:"

        context_text = "以下是相关的上下文信息:\n\n" + "".join(
            f"[信息{i} 来自 {doc['metadata'].get('source', '未知来源')}]:\n{doc['content']}\n\n"
            for i, doc in enumerate(context_docs, start=1)
        )
        return (
            "基于以下上下文信息,回答问题。如果上下文信息不足,请说明无法从提供的信息中获取完整答案。\n"
            f"{context_text}\n"
            f"问题:{query}\n"
            "请提供准确、简洁的答案,并注明答案的来源信息:"
        )

    def generate_answer(self, query: str, context_docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Ask the LLM to answer ``query`` from ``context_docs``.

        Returns:
            A dict with ``answer``, ``sources`` (unique source names, most relevant
            first), ``context_used`` and ``confidence`` (mean rerank score / 10,
            clamped to [0, 1]; documents without a score count as 5). On success it
            also includes ``context_count``, ``prompt_length`` and ``answer_length``.
            Timeouts and other errors are reported in ``answer`` with
            ``context_used=False``.
        """
        print(f"生成答案,使用 {len(context_docs)} 个上下文文档...")
        if not context_docs:
            print("警告:没有上下文文档")
            return {
                "answer": "未找到相关文档信息,无法回答问题。",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }

        prompt = self.build_prompt(query, context_docs)
        payload = {
            "model": self.config.LLM_MODEL,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.3,
                "num_predict": 500,
                "top_p": 0.9
            }
        }

        try:
            response = requests.post(self.generate_url, json=payload, timeout=120)
            response.raise_for_status()
            answer = response.json().get("response", "").strip()

            # dict.fromkeys de-duplicates while keeping rank order (most relevant source
            # first). A set() would make the order depend on string hash randomization.
            sources = list(dict.fromkeys(
                doc['metadata'].get('source', '未知来源') for doc in context_docs
            ))

            avg_rerank_score = sum(doc.get('rerank_score', 5) for doc in context_docs) / len(context_docs)
            confidence = min(max(avg_rerank_score / 10, 0.0), 1.0)  # normalise to [0, 1]

            print(f"✓ 答案生成完成,置信度: {confidence:.2f}")
            return {
                "answer": answer,
                "sources": sources,
                "context_used": True,
                "context_count": len(context_docs),
                "confidence": confidence,
                "prompt_length": len(prompt),
                "answer_length": len(answer)
            }
        except requests.exceptions.Timeout:
            print("生成答案超时")
            return {
                "answer": "生成答案超时,请稍后重试。",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }
        except Exception as e:
            print(f"生成答案时出错: {str(e)}")
            return {
                "answer": f"生成答案时出错: {str(e)}",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }

    def test_generation(self, test_query: str, test_contexts: List[str] = None) -> Dict[str, Any]:
        """Run ``generate_answer`` on built-in sample contexts (or ``test_contexts``).

        Sample documents get descending mock rerank scores (8.5, 7.5, ...).
        """
        if test_contexts is None:
            test_contexts = [
                "人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
                "机器学习是人工智能的一种方法,使计算机能够从数据中学习而无需明确编程。",
                "深度学习是机器学习的一个子集,使用多层神经网络来模拟人脑的决策过程。"
            ]
        test_docs = [
            {
                'content': context,
                'metadata': {'source': f'test_source_{i+1}'},
                'rerank_score': 8.5 - i  # mock rerank score
            }
            for i, context in enumerate(test_contexts)
        ]
        print(f"测试查询: {test_query}")
        print(f"测试上下文数量: {len(test_docs)}")
        return self.generate_answer(test_query, test_docs)
def main():
    """Stand-alone smoke test for the answer generation module (requires Ollama).

    Returns:
        The ``context_used`` flag of the sample generation.
    """
    print("测试答案生成模块...")
    from config import Config
    cfg = Config()
    cfg.validate()

    gen = AnswerGenerator(cfg)

    question = "什么是人工智能?"
    outcome = gen.test_generation(question)

    separator = "=" * 50
    print("\n" + separator)
    print("测试结果:")
    print(separator)
    print(f"查询: {question}")
    print(f"答案: {outcome['answer']}")
    print(f"使用上下文: {outcome['context_used']}")
    print(f"上下文数量: {outcome.get('context_count', 0)}")
    print(f"置信度: {outcome.get('confidence', 0):.2f}")
    print(f"答案长度: {outcome.get('answer_length', 0)} 字符")
    source_list = ', '.join(outcome['sources']) if outcome['sources'] else '无'
    print(f"来源: {source_list}")
    print(separator)

    print("\n测试提示词构建...")
    sample_docs = [{'content': '测试内容', 'metadata': {'source': 'test.txt'}}]
    rendered = gen.build_prompt("测试问题", sample_docs)
    print(f"提示词长度: {len(rendered)} 字符")
    print(f"提示词预览:\n{rendered[:200]}...")

    print("\n答案生成模块测试完成!")
    return outcome['context_used']


if __name__ == "__main__":
    main()
6. 主应用程序 (main.py)
import os
import time
from config import Config
from document_processor import DocumentProcessor
from vector_retriever import VectorRetriever
from reranker import Reranker
from answer_generator import AnswerGenerator
class RAGSystem:
    """Retrieval-augmented QA pipeline that wires the modules together.

    Flow: DocumentProcessor (load + chunk) -> VectorRetriever (embed + index /
    search) -> Reranker (keep the best TOP_N_RERANK) -> AnswerGenerator (LLM answer).
    """

    def __init__(self):
        self.config = Config()
        self.config.validate()
        self.doc_processor = DocumentProcessor(
            chunk_size=self.config.CHUNK_SIZE,
            chunk_overlap=self.config.CHUNK_OVERLAP
        )
        # All stages share the same Config instance.
        self.retriever = VectorRetriever(self.config)
        self.reranker = Reranker(self.config)
        self.generator = AnswerGenerator(self.config)
        # Becomes True once the vector store holds documents; query() refuses to run before that.
        self.is_initialized = False

    def initialize_knowledge_base(self, force_reload: bool = False):
        """Build (or reuse) the vector index from ``config.DOCUMENT_PATH``.

        Args:
            force_reload: reload, re-chunk and re-index the documents even if the
                collection already contains data.
        """
        print("正在初始化知识库...")

        # Reuse an existing, non-empty collection unless a rebuild was requested.
        if not force_reload and os.path.exists(self.config.CHROMA_PERSIST_DIR):
            stats = self.retriever.get_collection_stats()
            if stats.get('total_documents', 0) > 0:
                print(f"✓ 检测到已有的向量数据库 ({stats['total_documents']} 个文档),跳过重新构建...")
                self.is_initialized = True
                return

        documents = self.doc_processor.load_documents(self.config.DOCUMENT_PATH)
        if not documents:
            print("✗ 未找到文档,请将文档放入 documents/ 目录")
            print(f" 当前文档路径: {self.config.DOCUMENT_PATH}")
            return

        chunked_docs = self.doc_processor.chunk_documents(documents)
        # Keep a JSON copy of the chunks for inspection/debugging.
        self.doc_processor.save_chunks_to_json(
            chunked_docs,
            "./document_chunks.json"
        )

        if self.retriever.add_documents(chunked_docs):
            self.is_initialized = True
            print("✓ 知识库初始化完成!")
        else:
            print("✗ 知识库初始化失败")

    # Builtin ``dict`` annotation: ``Dict``/``Any`` are not imported in this module, and a
    # ``Dict[str, Any]`` annotation here raises NameError when the class is created.
    def query(self, question: str) -> dict:
        """Answer ``question`` with retrieve -> rerank -> generate.

        Returns:
            ``{"error": ...}`` if the knowledge base is not initialized. Otherwise the
            generator's result dict plus a ``timing`` dict (seconds for
            ``retrieval``, ``reranking``, ``generation`` and ``total``).
        """
        if not self.is_initialized:
            return {"error": "知识库未初始化"}

        print(f"\n{'='*60}")
        print(f"处理查询: {question}")
        print(f"{'='*60}")

        # Step 1: dense retrieval (perf_counter is monotonic, unlike time.time).
        start_time = time.perf_counter()
        retrieved_docs = self.retriever.search(
            question,
            top_k=self.config.TOP_K_RETRIEVAL
        )
        retrieval_time = time.perf_counter() - start_time

        if not retrieved_docs:
            return {
                "answer": "未找到相关文档。",
                "sources": [],
                "context_used": False,
                "timing": {
                    "retrieval": retrieval_time,
                    "reranking": 0,
                    "generation": 0,
                    "total": retrieval_time
                }
            }
        print(f"✓ 初步检索到 {len(retrieved_docs)} 个文档,耗时: {retrieval_time:.2f}秒")

        # Step 2: rerank
        start_rerank = time.perf_counter()
        reranked_docs = self.reranker.process(question, retrieved_docs)
        rerank_time = time.perf_counter() - start_rerank

        # Step 3: generate
        start_gen = time.perf_counter()
        result = self.generator.generate_answer(question, reranked_docs)
        gen_time = time.perf_counter() - start_gen

        result["timing"] = {
            "retrieval": retrieval_time,
            "reranking": rerank_time,
            "generation": gen_time,
            "total": retrieval_time + rerank_time + gen_time
        }
        print(f"\n✓ 答案生成完成!总耗时: {result['timing']['total']:.2f}秒")
        print(f"{'='*60}")
        return result

    def interactive_mode(self):
        """Run a REPL: answer questions and handle quit/reload/stats/config commands."""
        if not self.is_initialized:
            print("正在初始化知识库...")
            self.initialize_knowledge_base()
            if not self.is_initialized:
                print("无法初始化知识库,退出交互模式")
                return

        print("\n" + "="*60)
        print("RAG 系统已启动!")
        print("="*60)
        print("命令说明:")
        print(" 'quit', 'exit', 'q' - 退出系统")
        print(" 'reload' - 重新加载知识库")
        print(" 'stats' - 显示系统统计")
        print(" 'config' - 显示当前配置")
        print("="*60 + "\n")

        while True:
            try:
                question = input("\n请输入问题: ").strip()
                command = question.lower()

                if command in ['quit', 'exit', 'q']:
                    print("再见!")
                    break
                if command == 'reload':
                    print("重新加载知识库...")
                    self.initialize_knowledge_base(force_reload=True)
                    continue
                if command == 'stats':
                    stats = self.retriever.get_collection_stats()
                    print("系统统计:")
                    print(f" 文档数量: {stats.get('total_documents', 0)}")
                    print(f" 集合名称: {stats.get('collection_name', '未知')}")
                    continue
                if command == 'config':
                    self.config.print_config()
                    continue
                if not question:
                    continue

                result = self.query(question)

                if "error" in result:
                    print(f"错误: {result['error']}")
                else:
                    print("\n📝 答案:")
                    print(f"{result['answer']}")
                    if result.get('sources'):
                        print(f"\n📚 参考来源: {', '.join(result['sources'])}")
                    if result.get('confidence', 0) > 0:
                        print(f"\n📊 置信度: {result['confidence']:.2%}")
                    print("\n⏱️ 时间统计:")
                    for stage, time_taken in result['timing'].items():
                        if stage != 'total':
                            print(f" {stage}: {time_taken:.2f}秒")
                    print(f" 总计: {result['timing']['total']:.2f}秒")
            except KeyboardInterrupt:
                print("\n\n程序被中断")
                break
            except Exception as e:
                print(f"处理查询时出错: {str(e)}")
def test_system():
    """Initialize the knowledge base, run three sample questions and print a summary.

    Returns:
        True if at least one question was answered using retrieved context.
    """
    print("测试RAG系统...")
    rag = RAGSystem()

    print("\n1. 初始化知识库...")
    rag.initialize_knowledge_base()
    if not rag.is_initialized:
        print("知识库初始化失败,无法继续测试")
        return False

    print("\n2. 测试查询功能...")
    sample_questions = [
        "人工智能是什么?",
        "机器学习和深度学习有什么区别?",
        "RAG系统的工作原理是什么?"
    ]
    outcomes = []
    for question in sample_questions:
        print(f"\n测试查询: {question}")
        outcome = rag.query(question)
        outcomes.append(outcome)
        if "error" not in outcome:
            print(f" 成功生成答案 (长度: {len(outcome.get('answer', ''))} 字符)")

    print("\n" + "="*60)
    print("测试总结:")
    print("="*60)
    succeeded = len([o for o in outcomes if "error" not in o and o.get('context_used', False)])
    print(f"成功查询: {succeeded}/{len(sample_questions)}")
    if outcomes:
        total_seconds = sum(o.get('timing', {}).get('total', 0) for o in outcomes)
        print(f"平均查询时间: {total_seconds / len(outcomes):.2f}秒")
    return succeeded > 0
def main():
    """Command-line entry point.

    ``python main.py``        -> interactive mode (default)
    ``python main.py test``   -> run test_system(), exit 0 on success else 1
    ``python main.py init``   -> rebuild the knowledge base and exit
    ``python main.py serve``  -> start the FastAPI app on 0.0.0.0:8000
    Any other argument prints the available commands and exits with status 1.
    """
    print("Qwen3-RAG 轻量级系统")
    print("="*60)
    import sys

    if len(sys.argv) > 1:
        command = sys.argv[1].lower()
        if command == "test":
            sys.exit(0 if test_system() else 1)
        if command == "init":
            RAGSystem().initialize_knowledge_base(force_reload=True)
            sys.exit(0)
        if command == "serve":
            # Importing api_server builds and initializes its own RAGSystem.
            from api_server import app
            import uvicorn
            uvicorn.run(app, host="0.0.0.0", port=8000)
            sys.exit(0)
        print(f"未知命令: {command}")
        print("可用命令: test, init, serve")
        sys.exit(1)

    RAGSystem().interactive_mode()


if __name__ == "__main__":
    main()
7. 测试脚本 (test_rag.py)
#!/usr/bin/env python3
"""
RAG系统测试脚本
"""
import sys
import time
import json
from typing import List, Dict, Any
# 添加当前目录到路径
sys.path.append('.')
def test_config_module():
    """Run config.main(); any exception is reported and counted as a failure."""
    print("\n".join(["=" * 60, "测试配置模块", "=" * 60]))
    try:
        from config import Config, main as config_main
        return config_main()
    except Exception as e:
        print(f"配置模块测试失败: {str(e)}")
        return False
def test_document_processor():
    """Run document_processor.main(); any exception is reported and counted as a failure."""
    print("\n".join(["\n" + "=" * 60, "测试文档处理模块", "=" * 60]))
    try:
        from document_processor import DocumentProcessor, main as processor_main
        return processor_main()
    except Exception as e:
        print(f"文档处理模块测试失败: {str(e)}")
        return False
def test_vector_retriever():
    """Smoke-test the vector retrieval module.

    First checks that the Ollama service at ``Config.OLLAMA_BASE_URL`` answers
    ``/api/tags`` with HTTP 200, then runs ``vector_retriever.main()``.

    Returns:
        The module test's result, or False if Ollama is unreachable/unhealthy or
        anything raises.
    """
    print("\n" + "=" * 60)
    print("测试向量检索模块")
    print("=" * 60)
    try:
        from vector_retriever import VectorRetriever, main as retriever_main
        from config import Config
        import requests

        # Use the configured Ollama URL rather than assuming localhost.
        tags_url = f"{Config.OLLAMA_BASE_URL.rstrip('/')}/api/tags"
        try:
            response = requests.get(tags_url, timeout=5)
        except requests.exceptions.RequestException:
            # Catch only request failures; a bare except would also swallow KeyboardInterrupt.
            print("✗ 无法连接到Ollama服务")
            print("请先启动Ollama服务: ollama serve")
            return False
        if response.status_code != 200:
            print("✗ Ollama服务异常")
            return False
        print("✓ Ollama服务运行正常")

        return retriever_main()
    except Exception as e:
        print(f"向量检索模块测试失败: {str(e)}")
        return False
def test_reranker():
    """Run reranker.main(); any exception is reported and counted as a failure."""
    print("\n".join(["\n" + "=" * 60, "测试重排序模块", "=" * 60]))
    try:
        from reranker import Reranker, main as reranker_main
        return reranker_main()
    except Exception as e:
        print(f"重排序模块测试失败: {str(e)}")
        return False
def test_answer_generator():
    """Smoke-test the answer generation module.

    Verifies the Ollama service at ``Config.OLLAMA_BASE_URL`` answers
    ``/api/tags`` with HTTP 200, then runs ``answer_generator.main()``.

    Returns:
        The module test's result, or False if Ollama is unreachable/unhealthy or
        anything raises.
    """
    print("\n" + "=" * 60)
    print("测试答案生成模块")
    print("=" * 60)
    try:
        from answer_generator import AnswerGenerator, main as generator_main
        from config import Config
        import requests

        # Use the configured Ollama URL rather than assuming localhost.
        tags_url = f"{Config.OLLAMA_BASE_URL.rstrip('/')}/api/tags"
        try:
            response = requests.get(tags_url, timeout=5)
        except requests.exceptions.RequestException:
            # Catch only request failures; a bare except would also swallow KeyboardInterrupt.
            print("✗ 无法连接到Ollama服务")
            return False
        if response.status_code != 200:
            print("✗ Ollama服务异常")
            return False

        return generator_main()
    except Exception as e:
        print(f"答案生成模块测试失败: {str(e)}")
        return False
def test_full_system():
    """Run main.test_system(); any exception is reported and counted as a failure."""
    print("\n".join(["\n" + "=" * 60, "测试完整RAG系统", "=" * 60]))
    try:
        from main import test_system
        return test_system()
    except Exception as e:
        print(f"完整系统测试失败: {str(e)}")
        return False
def run_performance_test():
    """Time three identical queries against the full pipeline and print statistics.

    Returns:
        True if at least one run succeeded. False if none did, if the knowledge
        base could not be initialized, or if an exception occurred.
    """
    print("\n" + "=" * 60)
    print("性能测试")
    print("=" * 60)
    try:
        from main import RAGSystem
        rag = RAGSystem()
        rag.initialize_knowledge_base()
        if not rag.is_initialized:
            print("知识库未初始化,跳过性能测试")
            return False

        query_text = "人工智能的主要应用领域"
        durations, replies = [], []
        print(f"运行性能测试 (查询: '{query_text}')")
        for run in range(1, 4):  # three runs, averaged below
            print(f"\n第 {run} 次运行...")
            t0 = time.time()
            outcome = rag.query(query_text)
            t1 = time.time()
            if "error" in outcome:
                print(f" 失败: {outcome.get('error', '未知错误')}")
                continue
            # Prefer the pipeline's own timing; fall back to wall-clock time.
            elapsed = outcome.get('timing', {}).get('total', t1 - t0)
            durations.append(elapsed)
            replies.append(outcome.get('answer', ''))
            print(f" 时间: {elapsed:.2f}秒")
            print(f" 答案长度: {len(replies[-1])} 字符")
            print(f" 置信度: {outcome.get('confidence', 0):.2%}")

        if not durations:
            print("性能测试失败: 没有成功的查询")
            return False

        print("\n性能测试结果:")
        print(f" 平均时间: {sum(durations)/len(durations):.2f}秒")
        print(f" 最短时间: {min(durations):.2f}秒")
        print(f" 最长时间: {max(durations):.2f}秒")
        if len(replies) >= 2:
            # Crude heuristic only: every answer longer than 50 characters.
            stable = all(len(reply) > 50 for reply in replies)
            print(f" 答案一致性: {'良好' if stable else '不稳定'}")
        return True
    except Exception as e:
        print(f"性能测试失败: {str(e)}")
        return False
def main():
    """Run every test in order, print a report and write it to test_results.json.

    Returns:
        True only if every test passed.
    """
    print("Qwen3-RAG 系统测试套件")
    print("=" * 60)

    suite = (
        ("配置模块", test_config_module),
        ("文档处理模块", test_document_processor),
        ("向量检索模块", test_vector_retriever),
        ("重排序模块", test_reranker),
        ("答案生成模块", test_answer_generator),
        ("完整系统", test_full_system),
        ("性能测试", run_performance_test),
    )

    results = []
    for name, runner in suite:
        print(f"\n▶ 开始测试: {name}")
        try:
            outcome = runner()
        except Exception as e:
            print(f"✗ 异常: {name} - {str(e)}")
            results.append((name, False))
            continue
        results.append((name, outcome))
        print(f"{'✓ 通过' if outcome else '✗ 失败'}: {name}")

    print("\n" + "=" * 60)
    print("测试报告")
    print("=" * 60)
    total = len(results)
    passed = len([1 for _, ok in results if ok])
    print(f"通过率: {passed}/{total} ({passed/total*100:.1f}%)")
    print("\n详细结果:")
    for name, ok in results:
        print(f" {'✓ 通过' if ok else '✗ 失败'}: {name}")

    with open("test_results.json", "w", encoding="utf-8") as fh:
        report = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "total_tests": total,
            "passed_tests": passed,
            "results": [{"test": name, "passed": ok} for name, ok in results],
        }
        json.dump(report, fh, ensure_ascii=False, indent=2)
    print(f"\n详细结果已保存到: test_results.json")
    return passed == total


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
8. API服务 (api_server.py)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
import uvicorn
import sys
import os
# 添加当前目录到路径
sys.path.append('.')
from main import RAGSystem
# FastAPI application; the endpoint decorators below register on it.
app = FastAPI(
    version="1.0.0",
    title="Qwen3-RAG API",
    description="基于Qwen3-Embedding和Qwen3-Reranker的RAG系统API",
)

# A single shared pipeline, created and initialized once at import time and
# reused by every request handler in this module.
rag_system = RAGSystem()
rag_system.initialize_knowledge_base()
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    # The question passed to RAGSystem.query().
    question: str
    # When truthy, overrides config.TOP_K_RETRIEVAL (number of vector-search hits).
    top_k: Optional[int] = 10
    # When truthy, overrides config.TOP_N_RERANK (documents kept after reranking).
    top_n: Optional[int] = 3
    # NOTE(review): not read by process_query in this file; context is never returned.
    include_context: Optional[bool] = False
class QueryResponse(BaseModel):
    """Response body for POST /query."""
    # Generated answer text ("" when the pipeline reported an error).
    answer: str
    # Source file names of the context documents that were used.
    sources: List[str]
    # Whether retrieved context was passed to the LLM.
    context_used: bool
    # Confidence reported by the answer generator (0.0 when unavailable).
    confidence: Optional[float] = 0.0
    # Per-stage durations in seconds, e.g. retrieval / reranking / generation / total.
    timing: dict
    # Error message from the pipeline, if any.
    error: Optional[str] = None
class SystemStatus(BaseModel):
    """Response body for GET /status."""
    # Whether the RAG system's knowledge base has been initialized.
    initialized: bool
    # Number of chunks stored in the vector collection.
    document_count: int
    # Name of the vector collection.
    collection_name: str
    # Selected configuration values (models, chunk size, top_k / top_n).
    config: dict
@app.get("/")
async def root():
    """Describe the service and list every endpoint it exposes."""
    return {
        "service": "Qwen3-RAG API",
        "version": "1.0.0",
        "endpoints": {
            "/query": "POST - 处理查询",
            "/status": "GET - 系统状态",
            "/health": "GET - 健康检查",
            # /config is also served by this app and was missing from the listing.
            "/config": "GET - 当前配置",
        }
    }
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Answer one question with the shared RAG pipeline.

    ``top_k`` / ``top_n`` override the retrieval and rerank depth for this request
    only. The previous values are restored afterwards, so one client's overrides
    do not leak into later requests (for example, a request that sends ``null``).

    Raises:
        HTTPException(422): if ``top_k`` or ``top_n`` is not a positive integer.
        HTTPException(500): on any unexpected pipeline error.
    """
    # Validate before the try block so this error is not rewrapped as a 500.
    for field_name in ("top_k", "top_n"):
        value = getattr(request, field_name)
        if value is not None and value < 1:
            raise HTTPException(status_code=422, detail=f"{field_name} 必须是正整数")

    config = rag_system.config
    saved = (config.TOP_K_RETRIEVAL, config.TOP_N_RERANK)
    try:
        # Temporary per-request overrides (restored in `finally`).
        if request.top_k:
            config.TOP_K_RETRIEVAL = request.top_k
        if request.top_n:
            config.TOP_N_RERANK = request.top_n

        result = rag_system.query(request.question)

        if "error" in result:
            return QueryResponse(
                answer="",
                sources=[],
                context_used=False,
                confidence=0.0,
                timing={"total": 0},
                error=result["error"]
            )

        return QueryResponse(
            answer=result.get("answer", ""),
            sources=result.get("sources", []),
            context_used=result.get("context_used", False),
            confidence=result.get("confidence", 0.0),
            timing=result.get("timing", {})
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        config.TOP_K_RETRIEVAL, config.TOP_N_RERANK = saved
@app.get("/status", response_model=SystemStatus)
async def get_status():
    """Report initialization state, collection size and the effective configuration.

    Reuses the shared pipeline's retriever and config instead of constructing a
    new VectorRetriever (and with it a new persistent Chroma client) on every call.

    Raises:
        HTTPException(500): if the status cannot be collected.
    """
    try:
        config = rag_system.config
        stats = rag_system.retriever.get_collection_stats()
        return SystemStatus(
            initialized=rag_system.is_initialized,
            document_count=stats.get("total_documents", 0),
            collection_name=stats.get("collection_name", ""),
            config={
                "embedding_model": config.EMBEDDING_MODEL,
                "reranker_model": config.RERANKER_MODEL,
                "llm_model": config.LLM_MODEL,
                "chunk_size": config.CHUNK_SIZE,
                "top_k": config.TOP_K_RETRIEVAL,
                "top_n": config.TOP_N_RERANK
            }
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Report whether the RAG pipeline is initialized and Ollama is reachable.

    Returns ``"healthy"`` when both are OK, ``"degraded"`` when Ollama is reachable
    but either returns a non-200 status or the pipeline is not initialized, and
    ``"unhealthy"`` when the Ollama request fails.
    """
    # Local imports: at module level `time` is only bound under
    # `if __name__ == "__main__"`, so without this the endpoint raises NameError
    # when the app is served via `main.py serve` or an external uvicorn process.
    import time
    import requests

    try:
        ollama_response = requests.get(
            f"{rag_system.config.OLLAMA_BASE_URL}/api/tags",
            timeout=5
        )
    except requests.exceptions.RequestException:
        return {
            "status": "unhealthy",
            "rag_system_initialized": rag_system.is_initialized,
            "ollama_service": "unavailable",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

    ollama_ok = ollama_response.status_code == 200
    return {
        "status": "healthy" if rag_system.is_initialized and ollama_ok else "degraded",
        "rag_system_initialized": rag_system.is_initialized,
        "ollama_service": "running" if ollama_ok else "unavailable",
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }
@app.get("/config")
async def get_config():
    """Return the configured models, chunking/retrieval parameters and storage paths."""
    from config import Config
    cfg = Config()
    # Response key -> Config attribute, in response order.
    exported = (
        ("embedding_model", "EMBEDDING_MODEL"),
        ("reranker_model", "RERANKER_MODEL"),
        ("llm_model", "LLM_MODEL"),
        ("chunk_size", "CHUNK_SIZE"),
        ("chunk_overlap", "CHUNK_OVERLAP"),
        ("top_k_retrieval", "TOP_K_RETRIEVAL"),
        ("top_n_rerank", "TOP_N_RERANK"),
        ("document_path", "DOCUMENT_PATH"),
        ("chroma_persist_dir", "CHROMA_PERSIST_DIR"),
    )
    return {key: getattr(cfg, attr) for key, attr in exported}
def main():
    """Start the API server with uvicorn on 0.0.0.0:8000 (blocks until stopped)."""
    host, port = "0.0.0.0", 8000
    print("启动 Qwen3-RAG API 服务...")
    print(f"服务地址: http://{host}:{port}")
    print(f"API文档: http://{host}:{port}/docs")
    print("\n按 Ctrl+C 停止服务\n")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # Module-level `time` import, bound only when this file is run as a script.
    import time
    main()
9. 配置文件
requirements.txt
ollama>=0.1.0
chromadb>=0.4.0
langchain>=0.1.0
langchain-community>=0.0.10
unstructured[md]>=0.10.0
sentence-transformers>=2.2.0
pypdf>=3.0.0
python-dotenv>=1.0.0
fastapi>=0.104.0
uvicorn>=0.24.0
requests>=2.31.0
numpy>=1.24.0
.env 文件
# 模型配置
EMBEDDING_MODEL=dengcao/Qwen3-Embedding-0.6B:Q4_K_M
RERANKER_MODEL=dengcao/Qwen3-Reranker-0.6B:Q4_K_M
LLM_MODEL=qwen2.5:3b
# 向量数据库配置
CHROMA_PERSIST_DIR=./chroma_db
COLLECTION_NAME=knowledge_base
# 文本处理配置
CHUNK_SIZE=500
CHUNK_OVERLAP=50
TOP_K_RETRIEVAL=10
TOP_N_RERANK=3
# Ollama API配置
OLLAMA_BASE_URL=http://localhost:11434
# 文件路径
DOCUMENT_PATH=./documents
10. 使用说明
10.1 安装与配置
# 1. 克隆或创建项目
mkdir qwen-rag-system && cd qwen-rag-system
# 2. 创建虚拟环境
python -m venv venv
source venv/bin/activate # Linux/Mac
# 或
venv\Scripts\activate # Windows
# 3. 安装依赖
pip install -r requirements.txt
# 4. 启动Ollama服务
ollama serve &
# 5. 拉取模型(必需:Ollama 不会在调用 API 时自动拉取模型)
ollama pull dengcao/Qwen3-Embedding-0.6B:Q4_K_M
ollama pull dengcao/Qwen3-Reranker-0.6B:Q4_K_M
ollama pull qwen2.5:3b
10.2 准备文档
# 将文档放入documents目录
mkdir documents
# 复制PDF、TXT或MD文件到documents/目录
10.3 运行方式
方式1:交互模式(默认)
python main.py
方式2:测试模式
# 运行所有测试
python test_rag.py
# 或使用主程序的测试命令
python main.py test
方式3:API服务
python main.py serve
# 或直接运行API服务
python api_server.py
方式4:仅初始化知识库
python main.py init
10.4 独立测试各模块
# 测试配置模块
python -c "from config import main; main()"
# 测试文档处理模块
python -c "from document_processor import main; main()"
# 测试向量检索模块
python -c "from vector_retriever import main; main()"
# 测试重排序模块
python -c "from reranker import main; main()"
# 测试答案生成模块
python -c "from answer_generator import main; main()"
11. 项目特点
模块化设计
- 每个模块都有独立的 main() 方法,便于单独测试和调试
- 清晰的依赖关系,易于理解和维护
- 配置集中管理,便于调整参数
易于扩展
- 支持多种文档格式(PDF、TXT、MD)
- 可替换不同规模的Qwen3模型
- 可扩展其他向量数据库或LLM模型
完整测试
- 每个模块都有独立的测试功能
- 提供完整的系统测试套件
- 包含性能测试和健康检查
多种部署方式
- 交互式命令行界面
- RESTful API服务
- 可集成到其他应用
12. 故障排除
常见问题
- Ollama连接失败
  错误:无法连接到Ollama服务
  解决:确保Ollama服务正在运行:ollama serve &
- 模型未找到
  错误:模型不存在
  解决:先拉取模型:ollama pull dengcao/Qwen3-Embedding-0.6B:Q4_K_M
- 内存不足
  错误:内存分配失败
  解决:使用更小的量化版本,并修改 .env 文件中的模型配置
- 文档加载失败
  错误:无法加载PDF文档
  解决:安装正确的依赖:pip install pypdf
调试建议
- 先运行 python test_rag.py 检查各模块功能
- 查看日志输出了解处理进度
- 使用 python main.py init 重新初始化知识库
总结
本项目提供了一个完整的、可独立测试的Qwen3-RAG轻量级系统,具有以下优势:
- 完整的RAG流程:涵盖文档处理、向量检索、重排序和答案生成
- 易于使用:提供交互式界面和API服务
- 模块化设计:每个模块可独立测试和调试
- 灵活配置:支持不同规模的模型和参数调整
- 良好的可扩展性:易于添加新功能或替换组件
系统特别适合:
- 个人知识库管理
- 中小型企业文档检索
- RAG技术学习和研究
- 原型验证和概念测试
通过这个系统,您可以快速体验Qwen3-Embedding和Qwen3-Reranker的强大功能,并根据实际需求进行调整和优化。
更多推荐



所有评论(0)