Qwen3-RAG Lightweight System

Project Structure

qwen-rag-system/
├── config.py             # Configuration
├── main.py               # Main application
├── test_rag.py           # System test script
├── api_server.py         # FastAPI service
├── requirements.txt      # Dependency list
├── .env                  # Environment configuration
├── README.md             # Project README
├── documents/            # PDF/TXT/MD documents
├── chroma_db/            # Vector database
├── document_processor.py # Document processing module
├── vector_retriever.py   # Vector retrieval module
├── reranker.py           # Reranking module
└── answer_generator.py   # Answer generation module

1. Configuration Module (config.py)

import os
from dotenv import load_dotenv

load_dotenv()

class Config:
    """系统配置类"""
    # 模型配置
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "dengcao/Qwen3-Embedding-0.6B:Q4_K_M")
    RERANKER_MODEL = os.getenv("RERANKER_MODEL", "dengcao/Qwen3-Reranker-0.6B:Q4_K_M")
    LLM_MODEL = os.getenv("LLM_MODEL", "qwen2.5:3b")
    
    # 向量数据库配置
    CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
    COLLECTION_NAME = os.getenv("COLLECTION_NAME", "knowledge_base")
    
    # 文本处理配置
    CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "500"))
    CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "50"))
    TOP_K_RETRIEVAL = int(os.getenv("TOP_K_RETRIEVAL", "10"))
    TOP_N_RERANK = int(os.getenv("TOP_N_RERANK", "3"))
    
    # Ollama API配置
    OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    
    # 文件路径
    DOCUMENT_PATH = os.getenv("DOCUMENT_PATH", "./documents")
    
    @classmethod
    def validate(cls):
        """验证配置"""
        if not os.path.exists(cls.DOCUMENT_PATH):
            os.makedirs(cls.DOCUMENT_PATH)
            print(f"创建文档目录: {cls.DOCUMENT_PATH}")
        
        if not os.path.exists(cls.CHROMA_PERSIST_DIR):
            os.makedirs(cls.CHROMA_PERSIST_DIR)
            print(f"创建向量数据库目录: {cls.CHROMA_PERSIST_DIR}")
    
    @classmethod
    def print_config(cls):
        """打印当前配置"""
        print("=" * 50)
        print("系统配置信息:")
        print("=" * 50)
        for key, value in cls.__dict__.items():
            if not key.startswith("__") and not callable(value):
                print(f"{key}: {value}")
        print("=" * 50)

def main():
    """配置模块独立测试"""
    print("测试配置模块...")
    Config.validate()
    Config.print_config()
    
    # 测试环境变量读取
    print("\n环境变量测试:")
    print(f"Ollama URL: {Config.OLLAMA_BASE_URL}")
    print(f"文档路径: {Config.DOCUMENT_PATH}")
    
    return True

if __name__ == "__main__":
    main()
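
Because `load_dotenv()` does not overwrite variables that already exist in the process environment, any setting can also be overridden for a single run by setting the variable before `config.py` is imported. A minimal sketch (the `qwen2.5:7b` tag is only a hypothetical substitute model):

import os

# Must run before importing config: Config reads os.getenv at class-definition time
os.environ["LLM_MODEL"] = "qwen2.5:7b"  # hypothetical larger model tag

from config import Config
Config.print_config()  # LLM_MODEL now reflects the override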

2. Document Processing Module (document_processor.py)

import os
import json
from typing import List, Dict, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Note: on newer LangChain releases these loaders live in
# langchain_community.document_loaders instead of langchain.document_loaders
from langchain.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader
)

class DocumentProcessor:
    """文档处理器"""
    def __init__(self, chunk_size=500, chunk_overlap=50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
        )
    
    def load_documents(self, directory_path: str) -> List[Dict[str, Any]]:
        """加载目录中的所有文档"""
        documents = []
        
        if not os.path.exists(directory_path):
            print(f"目录不存在: {directory_path}")
            return documents
        
        for filename in os.listdir(directory_path):
            file_path = os.path.join(directory_path, filename)
            
            try:
                if filename.endswith('.pdf'):
                    loader = PyPDFLoader(file_path)
                elif filename.endswith('.txt'):
                    loader = TextLoader(file_path, encoding='utf-8')
                elif filename.endswith('.md'):
                    loader = UnstructuredMarkdownLoader(file_path)
                else:
                    print(f"跳过不支持的文件类型: {filename}")
                    continue
                
                loaded_docs = loader.load()
                for doc in loaded_docs:
                    documents.append({
                        'content': doc.page_content,
                        'metadata': {
                            'source': filename,
                            'page': doc.metadata.get('page', 0)
                        }
                    })
                print(f"✓ 已加载: {filename}")
                
            except Exception as e:
                print(f"✗ 加载文件 {filename} 时出错: {str(e)}")
        
        print(f"总计加载 {len(documents)} 个文档片段")
        return documents
    
    def chunk_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Split documents into chunks"""
        chunked_docs = []
        
        for doc in documents:
            chunks = self.text_splitter.split_text(doc['content'])
            
            for i, chunk in enumerate(chunks):
                chunked_docs.append({
                    'id': f"{doc['metadata']['source']}_chunk_{i}",
                    'content': chunk,
                    'metadata': {
                        **doc['metadata'],
                        'chunk_index': i,
                        'total_chunks': len(chunks)
                    }
                })
        
        print(f"✓ Chunking complete: {len(chunked_docs)} chunks")
        return chunked_docs
    
    def save_chunks_to_json(self, chunks: List[Dict[str, Any]], output_path: str):
        """Save chunking results to a JSON file"""
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)
        print(f"✓ Chunks saved to: {output_path}")
    
    def process_directory(self, input_dir: str, output_json: str = None) -> List[Dict[str, Any]]:
        """Run the full pipeline on a directory of documents"""
        print(f"Processing directory: {input_dir}")
        
        # Load documents
        documents = self.load_documents(input_dir)
        if not documents:
            print("No processable documents found")
            return []
        
        # Chunk them
        chunks = self.chunk_documents(documents)
        
        # Optionally save the result
        if output_json:
            self.save_chunks_to_json(chunks, output_json)
        
        return chunks

def main():
    """Standalone test for the document processing module"""
    print("Testing document processing module...")
    
    # Create a test document
    test_dir = "./test_documents"
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
        
        # Create a test text file
        test_content = """This is a test document.
It is used to exercise the document processing module.
Document processing covers loading and chunking.
A RAG system must chunk documents for downstream processing."""
        
        with open(os.path.join(test_dir, "test.txt"), "w", encoding="utf-8") as f:
            f.write(test_content)
    
    # Initialize the processor
    processor = DocumentProcessor(chunk_size=100, chunk_overlap=20)
    
    # Test the pipeline
    chunks = processor.process_directory(test_dir, "./test_chunks.json")
    
    if chunks:
        print("\nSample results:")
        for i, chunk in enumerate(chunks[:3]):  # show the first 3 chunks
            print(f"Chunk {i+1}: {chunk['content'][:50]}...")
        print(f"Processed {len(chunks)} chunks in total")
    
    # Clean up test files
    import shutil
    if os.path.exists("./test_chunks.json"):
        os.remove("./test_chunks.json")
    if os.path.exists(test_dir):
        shutil.rmtree(test_dir)
    
    print("\nDocument processing module test complete!")
    return len(chunks) > 0

if __name__ == "__main__":
    main()

3. Vector Retrieval Module (vector_retriever.py)

import requests
import numpy as np
from typing import List, Dict, Any
import chromadb
from chromadb.config import Settings
import hashlib
import time

class QwenEmbeddingClient:
    """Qwen Embedding客户端"""
    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.embedding_url = f"{base_url}/api/embeddings"
    
    def get_embedding(self, text: str, instruction: str = None) -> List[float]:
        """获取单个文本的嵌入向量"""
        payload = {
            "model": self.model_name,
            "input": text
        }
        
        if instruction:
            payload["input"] = f"指令:{instruction}\n文本:{text}"
        
        try:
            response = requests.post(self.embedding_url, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            return result.get("embedding", [])
        except requests.exceptions.RequestException as e:
            print(f"获取嵌入向量时网络错误: {str(e)}")
            return []
        except Exception as e:
            print(f"获取嵌入向量时出错: {str(e)}")
            return []
    
    def get_embeddings_batch(self, texts: List[str], batch_size: int = 5) -> List[List[float]]:
        """Fetch embeddings in batches"""
        embeddings = []
        
        print(f"Embedding {len(texts)} texts in batches...")
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            batch_embeddings = []
            
            for text in batch:
                embedding = self.get_embedding(text)
                if embedding:
                    batch_embeddings.append(embedding)
                else:
                    # Zero-vector placeholder (Qwen3-Embedding-0.6B emits 1024 dims)
                    batch_embeddings.append([0.0] * 1024)
            
            embeddings.extend(batch_embeddings)
            processed = min(i+batch_size, len(texts))
            print(f"Progress: {processed}/{len(texts)}")
            time.sleep(0.1)  # throttle requests
        
        return embeddings

class VectorRetriever:
    """向量检索器"""
    def __init__(self, config):
        self.config = config
        self.embedding_client = QwenEmbeddingClient(
            model_name=config.EMBEDDING_MODEL,
            base_url=config.OLLAMA_BASE_URL
        )
        
        # 初始化ChromaDB
        self.client = chromadb.PersistentClient(
            path=config.CHROMA_PERSIST_DIR,
            settings=Settings(anonymized_telemetry=False)
        )
        
        # 获取或创建集合
        self.collection = self.client.get_or_create_collection(
            name=config.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"}
        )
    
    def generate_doc_id(self, content: str, metadata: Dict) -> str:
        """生成文档ID"""
        content_hash = hashlib.md5(content.encode()).hexdigest()
        source = metadata.get('source', 'unknown')
        chunk_idx = metadata.get('chunk_index', 0)
        return f"{source}_chunk{chunk_idx}_{content_hash[:8]}"
    
    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents to the vector database"""
        if not documents:
            print("No documents to add")
            return False
        
        print(f"Adding {len(documents)} documents to the vector database...")
        
        # Prepare the data
        ids = []
        contents = []
        metadatas = []
        
        for doc in documents:
            doc_id = self.generate_doc_id(doc['content'], doc['metadata'])
            ids.append(doc_id)
            contents.append(doc['content'])
            metadatas.append(doc['metadata'])
        
        # Compute the embeddings
        print("Generating embeddings...")
        embeddings = self.embedding_client.get_embeddings_batch(contents)
        
        if not embeddings:
            print("Could not generate embeddings")
            return False
        
        # Add everything to the collection
        try:
            self.collection.add(
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
                ids=ids
            )
            print(f"✓ Added {len(documents)} documents to the vector database")
            return True
        except Exception as e:
            print(f"✗ Error adding documents: {str(e)}")
            return False
    
    def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]:
        """Retrieve relevant documents"""
        print(f"Retrieval query: {query}")
        
        # Embed the query
        query_embedding = self.embedding_client.get_embedding(query)
        
        if not query_embedding:
            print("Could not embed the query")
            return []
        
        # Run the search
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                include=["documents", "metadatas", "distances"]
            )
            
            # Collect the results
            retrieved_docs = []
            if results['documents']:
                for i in range(len(results['documents'][0])):
                    retrieved_docs.append({
                        'content': results['documents'][0][i],
                        'metadata': results['metadatas'][0][i],
                        # Convert cosine distance into a similarity score
                        'score': 1 - results['distances'][0][i]
                    })
            
            print(f"✓ Retrieved {len(retrieved_docs)} relevant documents")
            return retrieved_docs
            
        except Exception as e:
            print(f"✗ Retrieval error: {str(e)}")
            return []
    
    def get_collection_stats(self) -> Dict[str, Any]:
        """Get collection statistics"""
        try:
            count = self.collection.count()
            return {
                "total_documents": count,
                "collection_name": self.config.COLLECTION_NAME
            }
        except Exception as e:
            print(f"Error fetching collection stats: {str(e)}")
            return {}

def main():
    """Standalone test for the vector retrieval module"""
    print("Testing vector retrieval module...")
    
    # Import the configuration
    from config import Config
    config = Config()
    config.validate()
    
    # Initialize the retriever
    retriever = VectorRetriever(config)
    
    # Check the collection state
    stats = retriever.get_collection_stats()
    print(f"Collection stats: {stats}")
    
    # Test retrieval (requires previously ingested data)
    test_query = "artificial intelligence"
    results = retriever.search(test_query, top_k=3)
    
    if results:
        print(f"\nResults for query '{test_query}':")
        for i, doc in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"  Score: {doc['score']:.4f}")
            print(f"  Source: {doc['metadata'].get('source', 'unknown')}")
            print(f"  Content: {doc['content'][:100]}...")
    else:
        print("No results; you may need to add documents first")
    
    # Test embedding generation
    print("\nTesting embedding generation...")
    test_text = "This is a test text"
    embedding = retriever.embedding_client.get_embedding(test_text)
    if embedding:
        print(f"Embedding dimension: {len(embedding)}")
        print(f"First 10 values: {embedding[:10]}")
    else:
        print("Could not generate an embedding; check the Ollama service")
    
    print("\nVector retrieval module test complete!")
    return bool(results)

if __name__ == "__main__":
    main()

4. Reranking Module (reranker.py)

import requests
import json
from typing import List, Dict, Any
import re

class QwenRerankerClient:
    """Qwen Reranker客户端"""
    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.generate_url = f"{base_url}/api/generate"
    
    def rerank(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """对检索结果进行重排序"""
        if not documents:
            print("没有文档需要重排序")
            return []
        
        print(f"开始重排序 {len(documents)} 个文档...")
        reranked_docs = []
        
        for idx, doc in enumerate(documents):
            # 构建重排序提示
            prompt = self._build_rerank_prompt(query, doc['content'])
            
            # 调用模型获取分数
            score = self._get_relevance_score(prompt)
            
            reranked_docs.append({
                **doc,
                'rerank_score': score
            })
            
            # 显示进度
            print(f"  进度: {idx+1}/{len(documents)} - 分数: {score:.2f}")
        
        # 按重排序分数降序排列
        reranked_docs.sort(key=lambda x: x['rerank_score'], reverse=True)
        
        print(f"✓ 重排序完成,最佳分数: {reranked_docs[0]['rerank_score']:.2f}")
        return reranked_docs
    
    def _build_rerank_prompt(self, query: str, document: str) -> str:
        """Build the reranking prompt"""
        # Truncate long documents
        doc_preview = document[:800] + "..." if len(document) > 800 else document
        
        return f"""Rate how relevant the following document is to the query. Output a single score from 0 to 10 and nothing else.

Query: {query}

Document: {doc_preview}

Relevance score:"""
    
    def _get_relevance_score(self, prompt: str) -> float:
        """Get a relevance score from the model"""
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.1,
                "num_predict": 10
            }
        }
        
        try:
            response = requests.post(self.generate_url, json=payload, timeout=30)
            response.raise_for_status()
            result = response.json()
            
            # Extract the numeric score
            response_text = result.get("response", "").strip()
            
            # Try to pull a number out of the response
            numbers = re.findall(r"\d+\.?\d*", response_text)
            if numbers:
                score = float(numbers[0])
                # Clamp the score to the 0-10 range
                return min(max(score, 0), 10)
            else:
                print(f"Warning: could not extract a score from: {response_text}")
                return 5.0
                
        except requests.exceptions.Timeout:
            print("Rerank request timed out")
            return 0.0
        except Exception as e:
            print(f"Rerank error: {str(e)}")
            return 0.0

class Reranker:
    """重排序器"""
    def __init__(self, config):
        self.config = config
        self.reranker_client = QwenRerankerClient(
            model_name=config.RERANKER_MODEL,
            base_url=config.OLLAMA_BASE_URL
        )
    
    def process(self, query: str, retrieved_docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """执行重排序"""
        if not retrieved_docs:
            return []
        
        print(f"开始重排序,处理 {len(retrieved_docs)} 个文档...")
        
        # 执行重排序
        reranked_docs = self.reranker_client.rerank(query, retrieved_docs)
        
        # 只保留top_n个结果
        top_n = min(self.config.TOP_N_RERANK, len(reranked_docs))
        final_docs = reranked_docs[:top_n]
        
        print(f"✓ 重排序完成,保留 {len(final_docs)} 个最相关文档")
        
        # 打印重排序结果
        print("\n重排序结果:")
        for i, doc in enumerate(final_docs):
            print(f"{i+1}. 分数: {doc.get('rerank_score', 0):.2f}, "
                  f"源: {doc['metadata'].get('source', '未知')}")
        
        return final_docs

def main():
    """Standalone test for the reranking module"""
    print("Testing reranking module...")
    
    # Import the configuration
    from config import Config
    config = Config()
    config.validate()
    
    # Initialize the reranker
    reranker = Reranker(config)
    
    # Build test data
    test_query = "Application areas of artificial intelligence"
    test_documents = [
        {
            'content': 'AI is widely used in healthcare, e.g. for disease diagnosis and treatment recommendation.',
            'metadata': {'source': 'test1.txt', 'score': 0.85}
        },
        {
            'content': 'Machine learning is the core technology of AI, covering supervised and unsupervised learning.',
            'metadata': {'source': 'test2.txt', 'score': 0.78}
        },
        {
            'content': 'Deep learning has achieved notable results in computer vision and natural language processing.',
            'metadata': {'source': 'test3.txt', 'score': 0.92}
        },
        {
            'content': 'Autonomous driving combines computer vision, sensor fusion and path planning.',
            'metadata': {'source': 'test4.txt', 'score': 0.65}
        }
    ]
    
    print(f"Test query: {test_query}")
    print(f"Number of test documents: {len(test_documents)}")
    
    # Test reranking
    reranked_docs = reranker.process(test_query, test_documents)
    
    if reranked_docs:
        print("\nOriginal order vs. reranked order:")
        for i, (orig, reranked) in enumerate(zip(test_documents, reranked_docs)):
            orig_source = orig['metadata']['source']
            reranked_source = reranked['metadata']['source']
            reranked_score = reranked.get('rerank_score', 0)
            print(f"  {i+1}. {orig_source} (original) -> {reranked_source} (score: {reranked_score:.2f})")
    
    # Test scoring a single document
    print("\nTesting single-document scoring...")
    test_prompt = reranker.reranker_client._build_rerank_prompt(
        "machine learning",
        "Machine learning is an important branch of AI that lets computers learn patterns from data."
    )
    score = reranker.reranker_client._get_relevance_score(test_prompt)
    print(f"Single document score: {score:.2f}")
    
    print("\nReranking module test complete!")
    return len(reranked_docs) > 0

if __name__ == "__main__":
    main()

5. Answer Generation Module (answer_generator.py)

import requests
from typing import List, Dict, Any

class AnswerGenerator:
    """答案生成器"""
    def __init__(self, config):
        self.config = config
        self.generate_url = f"{config.OLLAMA_BASE_URL}/api/generate"
    
    def build_prompt(self, query: str, context_docs: List[Dict[str, Any]]) -> str:
        """构建生成提示"""
        if not context_docs:
            return f"问题:{query}\n\n请回答这个问题:"
        
        context_text = "以下是相关的上下文信息:\n\n"
        for i, doc in enumerate(context_docs):
            source = doc['metadata'].get('source', '未知来源')
            context_text += f"[信息{i+1} 来自 {source}]:\n{doc['content']}\n\n"
        
        prompt = f"""基于以下上下文信息,回答问题。如果上下文信息不足,请说明无法从提供的信息中获取完整答案。

{context_text}

问题:{query}

请提供准确、简洁的答案,并注明答案的来源信息:"""
        
        return prompt
    
    def generate_answer(self, query: str, context_docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate an answer"""
        print(f"Generating answer with {len(context_docs)} context documents...")
        
        if not context_docs:
            print("Warning: no context documents")
            return {
                "answer": "No relevant documents were found, so the question cannot be answered.",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }
        
        # Build the prompt
        prompt = self.build_prompt(query, context_docs)
        
        # Call the LLM to generate the answer
        payload = {
            "model": self.config.LLM_MODEL,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.3,
                "num_predict": 500,
                "top_p": 0.9
            }
        }
        
        try:
            response = requests.post(self.generate_url, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            
            answer = result.get("response", "").strip()
            
            # Collect the sources
            sources = list(set([
                doc['metadata'].get('source', 'unknown source') 
                for doc in context_docs
            ]))
            
            # Estimate confidence from the average rerank score
            total_rerank_score = sum(doc.get('rerank_score', 5) for doc in context_docs)
            avg_rerank_score = total_rerank_score / len(context_docs) if context_docs else 0
            confidence = min(avg_rerank_score / 10, 1.0)  # normalize to 0-1
            
            print(f"✓ Answer generated, confidence: {confidence:.2f}")
            
            return {
                "answer": answer,
                "sources": sources,
                "context_used": True,
                "context_count": len(context_docs),
                "confidence": confidence,
                "prompt_length": len(prompt),
                "answer_length": len(answer)
            }
            
        except requests.exceptions.Timeout:
            print("Answer generation timed out")
            return {
                "answer": "Answer generation timed out, please try again later.",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }
        except Exception as e:
            print(f"Error while generating the answer: {str(e)}")
            return {
                "answer": f"Error while generating the answer: {str(e)}",
                "sources": [],
                "context_used": False,
                "confidence": 0.0
            }
    
    def test_generation(self, test_query: str, test_contexts: List[str] = None) -> Dict[str, Any]:
        """Test the generation path"""
        if test_contexts is None:
            test_contexts = [
                "Artificial intelligence is a branch of computer science devoted to building systems that perform tasks normally requiring human intelligence.",
                "Machine learning is an approach to AI that lets computers learn from data without being explicitly programmed.",
                "Deep learning is a subset of machine learning that uses multi-layer neural networks to mimic human decision making."
            ]
        
        # Build test documents
        test_docs = []
        for i, context in enumerate(test_contexts):
            test_docs.append({
                'content': context,
                'metadata': {'source': f'test_source_{i+1}'},
                'rerank_score': 8.5 - i  # simulated rerank scores
            })
        
        print(f"Test query: {test_query}")
        print(f"Number of test contexts: {len(test_docs)}")
        
        # Generate an answer
        result = self.generate_answer(test_query, test_docs)
        
        return result

def main():
    """Standalone test for the answer generation module"""
    print("Testing answer generation module...")
    
    # Import the configuration
    from config import Config
    config = Config()
    config.validate()
    
    # Initialize the generator
    generator = AnswerGenerator(config)
    
    # Test generation
    test_query = "What is artificial intelligence?"
    result = generator.test_generation(test_query)
    
    # Display the result
    print("\n" + "="*50)
    print("Test result:")
    print("="*50)
    print(f"Query: {test_query}")
    print(f"Answer: {result['answer']}")
    print(f"Context used: {result['context_used']}")
    print(f"Context count: {result.get('context_count', 0)}")
    print(f"Confidence: {result.get('confidence', 0):.2f}")
    print(f"Answer length: {result.get('answer_length', 0)} characters")
    print(f"Sources: {', '.join(result['sources']) if result['sources'] else 'none'}")
    print("="*50)
    
    # Test prompt construction
    print("\nTesting prompt construction...")
    test_docs = [{'content': 'test content', 'metadata': {'source': 'test.txt'}}]
    prompt = generator.build_prompt("test question", test_docs)
    print(f"Prompt length: {len(prompt)} characters")
    print(f"Prompt preview:\n{prompt[:200]}...")
    
    print("\nAnswer generation module test complete!")
    return result['context_used']

if __name__ == "__main__":
    main()

6. Main Application (main.py)

import os
import time
from typing import Dict, Any
from config import Config
from document_processor import DocumentProcessor
from vector_retriever import VectorRetriever
from reranker import Reranker
from answer_generator import AnswerGenerator

class RAGSystem:
    """RAG系统主类"""
    def __init__(self):
        self.config = Config()
        self.config.validate()
        
        self.doc_processor = DocumentProcessor(
            chunk_size=self.config.CHUNK_SIZE,
            chunk_overlap=self.config.CHUNK_OVERLAP
        )
        
        self.retriever = VectorRetriever(self.config)
        self.reranker = Reranker(self.config)
        self.generator = AnswerGenerator(self.config)
        
        self.is_initialized = False
    
    def initialize_knowledge_base(self, force_reload: bool = False):
        """初始化知识库"""
        print("正在初始化知识库...")
        
        # 检查是否已存在向量数据库
        if not force_reload and os.path.exists(self.config.CHROMA_PERSIST_DIR):
            # 检查集合中是否有数据
            stats = self.retriever.get_collection_stats()
            if stats.get('total_documents', 0) > 0:
                print(f"✓ 检测到已有的向量数据库 ({stats['total_documents']} 个文档),跳过重新构建...")
                self.is_initialized = True
                return
        
        # 加载和分块文档
        documents = self.doc_processor.load_documents(self.config.DOCUMENT_PATH)
        
        if not documents:
            print("✗ 未找到文档,请将文档放入 documents/ 目录")
            print(f"  当前文档路径: {self.config.DOCUMENT_PATH}")
            return
        
        chunked_docs = self.doc_processor.chunk_documents(documents)
        
        # 保存分块结果(可选)
        self.doc_processor.save_chunks_to_json(
            chunked_docs, 
            "./document_chunks.json"
        )
        
        # 添加到向量数据库
        success = self.retriever.add_documents(chunked_docs)
        
        if success:
            self.is_initialized = True
            print("✓ 知识库初始化完成!")
        else:
            print("✗ 知识库初始化失败")
    
    def query(self, question: str) -> Dict[str, Any]:
        """Handle a user query"""
        if not self.is_initialized:
            return {"error": "Knowledge base not initialized"}
        
        print(f"\n{'='*60}")
        print(f"Processing query: {question}")
        print(f"{'='*60}")
        
        # Step 1: retrieve relevant documents
        start_time = time.time()
        retrieved_docs = self.retriever.search(
            question, 
            top_k=self.config.TOP_K_RETRIEVAL
        )
        retrieval_time = time.time() - start_time
        
        if not retrieved_docs:
            return {
                "answer": "No relevant documents found.",
                "sources": [],
                "context_used": False,
                "timing": {
                    "retrieval": retrieval_time,
                    "reranking": 0,
                    "generation": 0,
                    "total": retrieval_time
                }
            }
        
        print(f"✓ Initial retrieval: {len(retrieved_docs)} documents in {retrieval_time:.2f}s")
        
        # Step 2: rerank
        start_rerank = time.time()
        reranked_docs = self.reranker.process(question, retrieved_docs)
        rerank_time = time.time() - start_rerank
        
        # Step 3: generate the answer
        start_gen = time.time()
        result = self.generator.generate_answer(question, reranked_docs)
        gen_time = time.time() - start_gen
        
        # Attach timing information
        result["timing"] = {
            "retrieval": retrieval_time,
            "reranking": rerank_time,
            "generation": gen_time,
            "total": retrieval_time + rerank_time + gen_time
        }
        
        print(f"\n✓ Answer generated! Total time: {result['timing']['total']:.2f}s")
        print(f"{'='*60}")
        
        return result
    
    def interactive_mode(self):
        """Interactive query mode"""
        if not self.is_initialized:
            print("Initializing knowledge base...")
            self.initialize_knowledge_base()
            if not self.is_initialized:
                print("Could not initialize the knowledge base; leaving interactive mode")
                return
        
        print("\n" + "="*60)
        print("RAG system is up!")
        print("="*60)
        print("Commands:")
        print("  'quit', 'exit', 'q' - quit")
        print("  'reload' - reload the knowledge base")
        print("  'stats' - show system statistics")
        print("  'config' - show the current configuration")
        print("="*60 + "\n")
        
        while True:
            try:
                question = input("\nEnter a question: ").strip()
                
                if question.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye!")
                    break
                
                if question.lower() == 'reload':
                    print("Reloading knowledge base...")
                    self.initialize_knowledge_base(force_reload=True)
                    continue
                
                if question.lower() == 'stats':
                    stats = self.retriever.get_collection_stats()
                    print("System statistics:")
                    print(f"  Documents: {stats.get('total_documents', 0)}")
                    print(f"  Collection: {stats.get('collection_name', 'unknown')}")
                    continue
                
                if question.lower() == 'config':
                    self.config.print_config()
                    continue
                
                if not question:
                    continue
                
                # Handle the query
                result = self.query(question)
                
                # Display the result
                if "error" in result:
                    print(f"Error: {result['error']}")
                else:
                    print("\n📝 Answer:")
                    print(f"{result['answer']}")
                    
                    if result.get('sources'):
                        print(f"\n📚 Sources: {', '.join(result['sources'])}")
                    
                    if result.get('confidence', 0) > 0:
                        print(f"\n📊 Confidence: {result['confidence']:.2%}")
                    
                    print("\n⏱️  Timing:")
                    for stage, time_taken in result['timing'].items():
                        if stage != 'total':
                            print(f"  {stage}: {time_taken:.2f}s")
                    print(f"  total: {result['timing']['total']:.2f}s")
                
            except KeyboardInterrupt:
                print("\n\nInterrupted")
                break
            except Exception as e:
                print(f"Error while handling the query: {str(e)}")

def test_system():
    """Test the whole RAG system"""
    print("Testing the RAG system...")
    
    rag = RAGSystem()
    
    # Initialize the knowledge base
    print("\n1. Initializing knowledge base...")
    rag.initialize_knowledge_base()
    
    if not rag.is_initialized:
        print("Knowledge base initialization failed; cannot continue testing")
        return False
    
    # Test queries
    print("\n2. Testing queries...")
    test_queries = [
        "What is artificial intelligence?",
        "What is the difference between machine learning and deep learning?",
        "How does a RAG system work?"
    ]
    
    results = []
    for query in test_queries:
        print(f"\nTest query: {query}")
        result = rag.query(query)
        results.append(result)
        
        if "error" not in result:
            print(f"  Answer generated ({len(result.get('answer', ''))} characters)")
    
    # Summarize the test results
    print("\n" + "="*60)
    print("Test summary:")
    print("="*60)
    success_count = sum(1 for r in results if "error" not in r and r.get('context_used', False))
    print(f"Successful queries: {success_count}/{len(test_queries)}")
    
    if results:
        avg_time = sum(r.get('timing', {}).get('total', 0) for r in results) / len(results)
        print(f"Average query time: {avg_time:.2f}s")
    
    return success_count > 0

def main():
    """Entry point"""
    print("Qwen3-RAG Lightweight System")
    print("="*60)
    
    import sys
    
    if len(sys.argv) > 1:
        command = sys.argv[1].lower()
        
        if command == "test":
            # Test mode
            success = test_system()
            sys.exit(0 if success else 1)
        
        elif command == "init":
            # Only initialize the knowledge base
            rag = RAGSystem()
            rag.initialize_knowledge_base(force_reload=True)
            sys.exit(0)
        
        elif command == "serve":
            # Start the API service
            from api_server import app
            import uvicorn
            uvicorn.run(app, host="0.0.0.0", port=8000)
            sys.exit(0)
        
        else:
            print(f"Unknown command: {command}")
            print("Available commands: test, init, serve")
            sys.exit(1)
    
    # Default: interactive mode
    rag_system = RAGSystem()
    rag_system.interactive_mode()

if __name__ == "__main__":
    main()

7. Test Script (test_rag.py)

#!/usr/bin/env python3
"""
RAG系统测试脚本
"""

import sys
import time
import json
from typing import List, Dict, Any

# 添加当前目录到路径
sys.path.append('.')

def test_config_module():
    """测试配置模块"""
    print("=" * 60)
    print("测试配置模块")
    print("=" * 60)
    
    try:
        from config import Config, main as config_main
        success = config_main()
        return success
    except Exception as e:
        print(f"配置模块测试失败: {str(e)}")
        return False

def test_document_processor():
    """Test the document processing module"""
    print("\n" + "=" * 60)
    print("Testing the document processing module")
    print("=" * 60)
    
    try:
        from document_processor import DocumentProcessor, main as processor_main
        success = processor_main()
        return success
    except Exception as e:
        print(f"Document processing module test failed: {str(e)}")
        return False

def test_vector_retriever():
    """Test the vector retrieval module"""
    print("\n" + "=" * 60)
    print("Testing the vector retrieval module")
    print("=" * 60)
    
    try:
        from vector_retriever import VectorRetriever, main as retriever_main
        
        # Make sure the Ollama service is running first
        import requests
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=5)
            if response.status_code == 200:
                print("✓ Ollama service is running")
            else:
                print("✗ Ollama service is misbehaving")
                return False
        except requests.exceptions.RequestException:
            print("✗ Cannot connect to the Ollama service")
            print("Start it first with: ollama serve")
            return False
        
        success = retriever_main()
        return success
    except Exception as e:
        print(f"Vector retrieval module test failed: {str(e)}")
        return False

def test_reranker():
    """Test the reranking module"""
    print("\n" + "=" * 60)
    print("Testing the reranking module")
    print("=" * 60)
    
    try:
        from reranker import Reranker, main as reranker_main
        success = reranker_main()
        return success
    except Exception as e:
        print(f"Reranking module test failed: {str(e)}")
        return False

def test_answer_generator():
    """Test the answer generation module"""
    print("\n" + "=" * 60)
    print("Testing the answer generation module")
    print("=" * 60)
    
    try:
        from answer_generator import AnswerGenerator, main as generator_main
        
        # Check the Ollama service
        import requests
        try:
            response = requests.get("http://localhost:11434/api/tags", timeout=5)
            if response.status_code != 200:
                print("✗ Ollama service is misbehaving")
                return False
        except requests.exceptions.RequestException:
            print("✗ Cannot connect to the Ollama service")
            return False
        
        success = generator_main()
        return success
    except Exception as e:
        print(f"Answer generation module test failed: {str(e)}")
        return False

def test_full_system():
    """Test the complete system"""
    print("\n" + "=" * 60)
    print("Testing the complete RAG system")
    print("=" * 60)
    
    try:
        from main import test_system
        success = test_system()
        return success
    except Exception as e:
        print(f"Full system test failed: {str(e)}")
        return False

def run_performance_test():
    """Run a performance test"""
    print("\n" + "=" * 60)
    print("Performance test")
    print("=" * 60)
    
    try:
        from main import RAGSystem
        
        rag = RAGSystem()
        rag.initialize_knowledge_base()
        
        if not rag.is_initialized:
            print("Knowledge base not initialized; skipping the performance test")
            return False
        
        # Test query
        test_query = "Main application areas of artificial intelligence"
        
        times = []
        answers = []
        
        print(f"Running performance test (query: '{test_query}')")
        for i in range(3):  # three runs, averaged
            print(f"\nRun {i+1}...")
            start_time = time.time()
            result = rag.query(test_query)
            end_time = time.time()
            
            if "error" not in result:
                total_time = result.get('timing', {}).get('total', end_time - start_time)
                times.append(total_time)
                answers.append(result.get('answer', ''))
                
                print(f"  Time: {total_time:.2f}s")
                print(f"  Answer length: {len(answers[-1])} characters")
                print(f"  Confidence: {result.get('confidence', 0):.2%}")
            else:
                print(f"  Failed: {result.get('error', 'unknown error')}")
        
        if times:
            print("\nPerformance results:")
            print(f"  Average time: {sum(times)/len(times):.2f}s")
            print(f"  Fastest: {min(times):.2f}s")
            print(f"  Slowest: {max(times):.2f}s")
            
            # Rough check of answer consistency
            if len(answers) >= 2:
                similar = all(len(a) > 50 for a in answers)  # crude length check
                print(f"  Answer consistency: {'good' if similar else 'unstable'}")
            
            return True
        else:
            print("Performance test failed: no successful queries")
            return False
            
    except Exception as e:
        print(f"Performance test failed: {str(e)}")
        return False

def main():
    """Main test entry point"""
    print("Qwen3-RAG test suite")
    print("=" * 60)
    
    # Run the per-module tests
    tests = [
        ("Configuration module", test_config_module),
        ("Document processing module", test_document_processor),
        ("Vector retrieval module", test_vector_retriever),
        ("Reranking module", test_reranker),
        ("Answer generation module", test_answer_generator),
        ("Full system", test_full_system),
        ("Performance test", run_performance_test),
    ]
    
    results = []
    for test_name, test_func in tests:
        print(f"\n▶ Starting test: {test_name}")
        try:
            success = test_func()
            results.append((test_name, success))
            status = "✓ passed" if success else "✗ failed"
            print(f"{status}: {test_name}")
        except Exception as e:
            print(f"✗ exception: {test_name} - {str(e)}")
            results.append((test_name, False))
    
    # Print the test report
    print("\n" + "=" * 60)
    print("Test report")
    print("=" * 60)
    
    passed = sum(1 for _, success in results if success)
    total = len(results)
    
    print(f"Pass rate: {passed}/{total} ({passed/total*100:.1f}%)")
    print("\nDetails:")
    for test_name, success in results:
        status = "✓ passed" if success else "✗ failed"
        print(f"  {status}: {test_name}")
    
    # Save the test results
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump({
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "total_tests": total,
            "passed_tests": passed,
            "results": [
                {"test": name, "passed": success}
                for name, success in results
            ]
        }, f, ensure_ascii=False, indent=2)
    
    print("\nDetailed results saved to: test_results.json")
    
    return passed == total

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)

8. API Service (api_server.py)

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
import uvicorn
import sys
import os
import time

# Add the current directory to the path
sys.path.append('.')

from main import RAGSystem

app = FastAPI(
    title="Qwen3-RAG API",
    description="RAG system API built on Qwen3-Embedding and Qwen3-Reranker",
    version="1.0.0"
)

# Initialize the RAG system
rag_system = RAGSystem()
rag_system.initialize_knowledge_base()

class QueryRequest(BaseModel):
    """Query request model"""
    question: str
    top_k: Optional[int] = 10
    top_n: Optional[int] = 3
    include_context: Optional[bool] = False

class QueryResponse(BaseModel):
    """Query response model"""
    answer: str
    sources: List[str]
    context_used: bool
    confidence: Optional[float] = 0.0
    timing: dict
    error: Optional[str] = None

class SystemStatus(BaseModel):
    """System status model"""
    initialized: bool
    document_count: int
    collection_name: str
    config: dict

@app.get("/")
async def root():
    """根端点"""
    return {
        "service": "Qwen3-RAG API",
        "version": "1.0.0",
        "endpoints": {
            "/query": "POST - 处理查询",
            "/status": "GET - 系统状态",
            "/health": "GET - 健康检查"
        }
    }

@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """处理查询请求"""
    try:
        # 临时更新配置
        if request.top_k:
            rag_system.config.TOP_K_RETRIEVAL = request.top_k
        if request.top_n:
            rag_system.config.TOP_N_RERANK = request.top_n
        
        # 处理查询
        result = rag_system.query(request.question)
        
        if "error" in result:
            return QueryResponse(
                answer="",
                sources=[],
                context_used=False,
                confidence=0.0,
                timing={"total": 0},
                error=result["error"]
            )
        
        # 构建响应
        response = QueryResponse(
            answer=result.get("answer", ""),
            sources=result.get("sources", []),
            context_used=result.get("context_used", False),
            confidence=result.get("confidence", 0.0),
            timing=result.get("timing", {})
        )
        
        return response
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/status", response_model=SystemStatus)
async def get_status():
    """获取系统状态"""
    try:
        from vector_retriever import VectorRetriever
        from config import Config
        
        config = Config()
        retriever = VectorRetriever(config)
        stats = retriever.get_collection_stats()
        
        return SystemStatus(
            initialized=rag_system.is_initialized,
            document_count=stats.get("total_documents", 0),
            collection_name=stats.get("collection_name", ""),
            config={
                "embedding_model": config.EMBEDDING_MODEL,
                "reranker_model": config.RERANKER_MODEL,
                "llm_model": config.LLM_MODEL,
                "chunk_size": config.CHUNK_SIZE,
                "top_k": config.TOP_K_RETRIEVAL,
                "top_n": config.TOP_N_RERANK
            }
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """健康检查"""
    try:
        # 检查Ollama服务
        import requests
        ollama_response = requests.get(
            f"{rag_system.config.OLLAMA_BASE_URL}/api/tags", 
            timeout=5
        )
        ollama_ok = ollama_response.status_code == 200
        
        return {
            "status": "healthy" if rag_system.is_initialized and ollama_ok else "degraded",
            "rag_system_initialized": rag_system.is_initialized,
            "ollama_service": "running" if ollama_ok else "unavailable",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
    except:
        return {
            "status": "unhealthy",
            "rag_system_initialized": rag_system.is_initialized,
            "ollama_service": "unavailable",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

@app.get("/config")
async def get_config():
    """获取当前配置"""
    from config import Config
    config = Config()
    
    return {
        "embedding_model": config.EMBEDDING_MODEL,
        "reranker_model": config.RERANKER_MODEL,
        "llm_model": config.LLM_MODEL,
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
        "top_k_retrieval": config.TOP_K_RETRIEVAL,
        "top_n_rerank": config.TOP_N_RERANK,
        "document_path": config.DOCUMENT_PATH,
        "chroma_persist_dir": config.CHROMA_PERSIST_DIR
    }

def main():
    """Start the API service"""
    print("Starting the Qwen3-RAG API service...")
    print("Service address: http://0.0.0.0:8000")
    print("API docs: http://0.0.0.0:8000/docs")
    print("\nPress Ctrl+C to stop\n")
    
    uvicorn.run(app, host="0.0.0.0", port=8000)

if __name__ == "__main__":
    main()

9. Configuration Files

requirements.txt

ollama>=0.1.0
chromadb>=0.4.0
langchain>=0.1.0
sentence-transformers>=2.2.0
pypdf>=3.0.0
python-dotenv>=1.0.0
fastapi>=0.104.0
uvicorn>=0.24.0
requests>=2.31.0
numpy>=1.24.0

.env file

# Model configuration
EMBEDDING_MODEL=dengcao/Qwen3-Embedding-0.6B:Q4_K_M
RERANKER_MODEL=dengcao/Qwen3-Reranker-0.6B:Q4_K_M
LLM_MODEL=qwen2.5:3b

# Vector database configuration
CHROMA_PERSIST_DIR=./chroma_db
COLLECTION_NAME=knowledge_base

# Text processing configuration
CHUNK_SIZE=500
CHUNK_OVERLAP=50
TOP_K_RETRIEVAL=10
TOP_N_RERANK=3

# Ollama API configuration
OLLAMA_BASE_URL=http://localhost:11434

# File paths
DOCUMENT_PATH=./documents

10. Usage

10.1 Installation and Setup

# 1. Clone or create the project
mkdir qwen-rag-system && cd qwen-rag-system

# 2. Create a virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or
venv\Scripts\activate     # Windows

# 3. Install the dependencies
pip install -r requirements.txt

# 4. Start the Ollama service
ollama serve &

# 5. Pull the models (do this before first use; the HTTP API does not pull missing models automatically)
ollama pull dengcao/Qwen3-Embedding-0.6B:Q4_K_M
ollama pull dengcao/Qwen3-Reranker-0.6B:Q4_K_M
ollama pull qwen2.5:3b
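
Before ingesting documents it is worth confirming that Ollama and the embedding model respond. A minimal sketch against the same /api/embeddings endpoint that vector_retriever.py uses (the dimension check assumes the 0.6B embedding model):

import requests

# Ask Ollama for one embedding; /api/embeddings takes the text under "prompt"
resp = requests.post(
    "http://localhost:11434/api/embeddings",
    json={"model": "dengcao/Qwen3-Embedding-0.6B:Q4_K_M", "prompt": "hello"},
    timeout=60,
)
resp.raise_for_status()
vec = resp.json().get("embedding", [])
print(f"embedding dimension: {len(vec)}")  # expect 1024 for Qwen3-Embedding-0.6B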

10.2 Preparing Documents

# Put your documents into the documents directory
mkdir documents
# Copy PDF, TXT or MD files into documents/

10.3 Running the System

Option 1: interactive mode (default)
python main.py
Option 2: test mode
# Run all the tests
python test_rag.py

# Or use the main program's test command
python main.py test
Option 3: API service (see the client sketch after this list)
python main.py serve
# or run the API service directly
python api_server.py
Option 4: initialize the knowledge base only
python main.py init
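
Once the API service is running, it can be exercised with a small client. A sketch assuming the default host and port configured above (the question text is just an example):

import requests

payload = {"question": "What is artificial intelligence?", "top_k": 10, "top_n": 3}
resp = requests.post("http://localhost:8000/query", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()
print(data["answer"])
print("sources:", data["sources"])
print("total time:", data["timing"].get("total"))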

10.4 Testing Modules Independently

# Test the configuration module
python -c "from config import main; main()"

# Test the document processing module
python -c "from document_processor import main; main()"

# Test the vector retrieval module
python -c "from vector_retriever import main; main()"

# Test the reranking module
python -c "from reranker import main; main()"

# Test the answer generation module
python -c "from answer_generator import main; main()"

11. Project Highlights

Modular design

  • Every module has its own main() for standalone testing and debugging
  • Clear dependencies that are easy to understand and maintain
  • Centralized configuration for easy parameter tuning

Easy to extend

  • Supports multiple document formats (PDF, TXT, MD)
  • Qwen3 models of different sizes can be swapped in
  • Other vector databases or LLMs can be plugged in

Thorough testing

  • Every module ships its own test routine
  • A complete system test suite is provided
  • Includes performance tests and health checks

Multiple deployment options

  • Interactive command-line interface
  • RESTful API service
  • Embeddable in other applications (see the sketch below)
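
As a sketch of the embedding path, the pipeline can be driven directly from Python using the classes defined above (the question text is only an example):

from main import RAGSystem

rag = RAGSystem()
rag.initialize_knowledge_base()   # builds ./chroma_db or reuses an existing one

result = rag.query("What is a RAG system?")
if "error" not in result:
    print(result["answer"])
    print("sources:", result["sources"])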

12. Troubleshooting

Common issues

  1. Ollama connection failure

    Error: cannot connect to the Ollama service
    Fix: make sure Ollama is running
          ollama serve &
    
  2. Model not found

    Error: model does not exist
    Fix: pull the model first
          ollama pull dengcao/Qwen3-Embedding-0.6B:Q4_K_M
    
  3. Out of memory

    Error: memory allocation failed
    Fix: use a smaller quantized variant
          change the model settings in .env
    
  4. Document loading failure

    Error: cannot load PDF documents
    Fix: install the right dependency
          pip install pypdf
    

Debugging tips

  • Run python test_rag.py first to check each module
  • Watch the log output to follow processing progress
  • Use python main.py init to rebuild the knowledge base

Summary

This project provides a complete, independently testable Qwen3-RAG lightweight system with the following strengths:

  1. Complete RAG pipeline: document processing, vector retrieval, reranking and answer generation
  2. Easy to use: interactive interface plus an API service
  3. Modular design: every module can be tested and debugged on its own
  4. Flexible configuration: models of different sizes and tunable parameters
  5. Good extensibility: new features or replacement components are easy to add

The system is a particularly good fit for:

  • Personal knowledge base management
  • Document retrieval in small and medium-sized organizations
  • Learning and researching RAG techniques
  • Prototyping and proof-of-concept work

With this system you can quickly try out Qwen3-Embedding and Qwen3-Reranker and then tune the pipeline to your own needs.
