不知道你有没有遇到这样的情况,AI客服每天要回答几千个问题,其中至少有三分之一是重复的——什么"年假怎么算"“差旅费怎么报销”“公积金比例是多少”……这些问题的答案其实都写在公司制度里,几个月都不会变一次。

但问题来了:每次有人问,AI都要重新去文档库里翻一遍。

就像你明明已经把家里钥匙放哪儿记得清清楚楚,但每次出门还是要把整个房间翻一遍才能找到。这不是浪费时间吗?

今天这篇文章,我会用实际代码带你完整实现从传统RAG到CAG的演进过程。每一步都有可运行的代码,让你真正理解这个技术是怎么work的。

在这里插入图片描述

在这里插入图片描述

一、RAG很好,但它有个"健忘症"

说到这里,得先聊聊现在最流行的RAG技术。

RAG全称是"检索增强生成",听起来挺学术的,但原理很直白:让AI在回答问题之前,先去知识库里查一查相关资料,然后基于这些资料来生成答案。

这个方法确实解决了AI"瞎编"的问题。但它有个天生的缺陷——没记性。

1.1 什么是RAG?

RAG(检索增强生成)的工作流程很简单:

  1. 用户提问
  2. 系统去知识库检索相关文档
  3. 把检索结果和问题一起给AI
  4. AI基于检索内容生成答案

听起来很美好,但问题在于:每次都要检索。

1.2 传统RAG的完整实现

让我们先实现一个标准的RAG系统,用企业HR知识库作为例子:

import numpy as np
from typing import List, Dict
import time
from datetime import datetime
# 模拟向量数据库
class SimpleVectorDB:
    """简单的向量数据库实现"""
    
    def __init__(self):
        self.documents = []
        self.embeddings = []
        self.metadata = []
        
    def add_document(self, text: str, metadata: Dict = None):
        """添加文档到数据库"""
        # 这里用简单的词频向量模拟embedding
        embedding = self._text_to_vector(text)
        self.documents.append(text)
        self.embeddings.append(embedding)
        self.metadata.append(metadata or {})
        
    def _text_to_vector(self, text: str) -> np.ndarray:
        """将文本转换为向量(简化版)"""
        # 实际应该用OpenAI/HuggingFace的embedding模型
        # 这里简化处理:基于字符出现频率
        vector = np.zeros(100)
        for i, char in enumerate(text[:100]):
            vector[i] = ord(char) / 1000
        return vector
    
    def search(self, query: str, top_k: int = 3) -> List[Dict]:
        """检索最相关的文档"""
        query_vector = self._text_to_vector(query)
        
        # 计算余弦相似度
        similarities = []
        for i, doc_vector in enumerate(self.embeddings):
            similarity = np.dot(query_vector, doc_vector) / (
                np.linalg.norm(query_vector) * np.linalg.norm(doc_vector) + 1e-10
            )
            similarities.append({
                'index': i,
                'score': similarity,
                'text': self.documents[i],
                'metadata': self.metadata[i]
            })
        
        # 返回top_k结果
        similarities.sort(key=lambda x: x['score'], reverse=True)
        return similarities[:top_k]
class TraditionalRAG:
    """传统RAG系统"""
    
    def __init__(self):
        self.vector_db = SimpleVectorDB()
        self.search_count = 0  # 统计检索次数
        self.search_times = []  # 记录每次检索耗时
        
    def add_knowledge(self, text: str, metadata: Dict = None):
        """添加知识到系统"""
        self.vector_db.add_document(text, metadata)
        
    def query(self, question: str) -> Dict:
        """处理查询"""
        start_time = time.time()
        
        # 每次都要检索
        search_results = self.vector_db.search(question, top_k=2)
        
        search_time = time.time() - start_time
        self.search_count += 1
        self.search_times.append(search_time)
        
        # 组装上下文
        context = "\n\n".join([r['text'] for r in search_results])
        
        # 模拟LLM生成答案(实际应调用GPT/Claude API)
        answer = self._generate_answer(question, context)
        
        return {
            'question': question,
            'answer': answer,
            'context': context,
            'search_time': search_time,
            'total_searches': self.search_count
        }
    
    def _generate_answer(self, question: str, context: str) -> str:
        """模拟LLM生成答案"""
        # 实际应该调用OpenAI API或其他LLM
        return f"基于知识库:{context[:100]}... 回答:[模拟答案]"
    
    def get_statistics(self) -> Dict:
        """获取性能统计"""
        return {
            'total_searches': self.search_count,
            'avg_search_time': np.mean(self.search_times) if self.search_times else 0,
            'total_time': sum(self.search_times)
        }
# 使用示例
def demo_traditional_rag():
    """演示传统RAG的问题"""
    print("=" * 60)
    print("传统RAG系统演示")
    print("=" * 60)
    
    # 创建RAG系统
    rag = TraditionalRAG()
    
    # 添加企业知识(这些都是稳定的制度文档)
    knowledge_base = [
        {
            "text": "公司年假政策:入职满1年员工享有5天年假,满3年享有10天,满5年享有15天。年假必须在当年使用,不可跨年累积。",
            "metadata": {"category": "HR政策", "update_date": "2024-01-01"}
        },
        {
            "text": "差旅费报销标准:国内出差每天补贴200元,住宿费实报实销上限500元/天。需提供发票和出差申请单。",
            "metadata": {"category": "财务制度", "update_date": "2024-01-01"}
        },
        {
            "text": "公积金缴纳比例:公司和个人各缴纳12%,基数为上年度月平均工资。每年7月调整一次。",
            "metadata": {"category": "薪酬福利", "update_date": "2024-01-01"}
        },
        {
            "text": "病假规定:员工因病需请假,需提供医院证明。病假工资按基本工资的80%发放,每年累计不超过30天。",
            "metadata": {"category": "HR政策", "update_date": "2024-01-01"}
        }
    ]
    
    for kb in knowledge_base:
        rag.add_knowledge(kb['text'], kb['metadata'])
    
    print(f"\n已加载 {len(knowledge_base)} 条企业知识\n")
    
    # 模拟重复查询(这是关键问题所在)
    repeated_questions = [
        "年假怎么算?",
        "年假政策是什么?",
        "我能休几天年假?",
        "差旅费怎么报销?",
        "出差补贴标准是多少?",
        "年假能累积吗?",  # 又问年假
        "公积金比例是多少?",
        "年假政策详细说明",  # 再问年假
    ]
    
    print("开始处理查询...\n")
    for i, question in enumerate(repeated_questions, 1):
        result = rag.query(question)
        print(f"查询 {i}: {question}")
        print(f"  检索耗时: {result['search_time']*1000:.2f}ms")
        print(f"  累计检索次数: {result['total_searches']}")
        print()
    
    # 显示统计信息
    stats = rag.get_statistics()
    print("=" * 60)
    print("性能统计")
    print("=" * 60)
    print(f"总检索次数: {stats['total_searches']}")
    print(f"平均检索耗时: {stats['avg_search_time']*1000:.2f}ms")
    print(f"总耗时: {stats['total_time']*1000:.2f}ms")
    print()
    print("⚠️  问题分析:")
    print("  - 关于'年假'的问题被问了4次,但每次都重新检索")
    print("  - 这些制度文档几个月都不会变,却要反复访问数据库")
    print("  - 随着查询量增加,成本和延迟线性上升")
    print()
# 运行演示
demo_traditional_rag()

1.3 问题暴露:成本与延迟

运行上面的代码,你会看到:

  • 关于"年假"的问题问了4次,系统检索了4次
  • 每次检索都要访问向量数据库
  • 累计检索次数随查询量线性增长

实际生产环境的影响:

  • 成本:向量数据库调用费用(如Pinecone按查询次数收费)
  • 延迟:网络往返+相似度计算,通常50-200ms
  • 资源:数据库连接数、CPU占用

通过上面的例子可以很清楚发现,就算是同样的问题问一百遍,AI还是会乖乖地去检索一百遍。访问数据库、匹配文档、提取信息……这一套流程走下来,既耗时又烧钱。

尤其是对于那些几乎不会变的知识,比如公司规章制度、产品说明书、法律条文……每次都重新检索,实在是有点"杀鸡用牛刀"的感觉。

二、CAG:给AI装上"内存条"

节节这个问题,有个新思路,叫做缓存增强生成(CAG)。

简单说,就是给AI装个"内存"——把那些稳定不变的知识,直接存到模型内部的记忆库里。下次再遇到相关问题,就不用去外面翻箱倒柜了,直接从"脑子里"调出来就行。

这就好比你把常用的工具放在手边,而不是每次都跑到仓库去找。

效果立竿见影:

  • 速度更快:不用反复访问数据库,响应时间能缩短一大半
  • 成本更低:检索次数少了,服务器压力小了,钱自然省下来了
  • 回答更稳定:对于固定知识的表述更一致,不会今天说A明天说B

2.1 CAG的核心思想

CAG(缓存增强生成)要做的事情很简单:

  1. 识别哪些知识是"静态的"(长期不变)
  2. 把这些知识直接缓存到内存
  3. 查询时先查缓存,命中就不用检索了

那是不是所有知识都该塞进缓存呢?

当然不是。如果什么都往里装,很快就会把AI的"脑容量"撑爆。

2.2 CAG系统的完整代码实现

import hashlib
from typing import Optional, Tuple
import json
class KnowledgeCache:
    """知识缓存管理器"""
    
    def __init__(self, max_size: int = 100):
        self.cache = {}  # 缓存存储
        self.max_size = max_size
        self.hit_count = 0  # 命中次数
        self.miss_count = 0  # 未命中次数
        self.access_log = []  # 访问日志
        
    def _generate_key(self, query: str) -> str:
        """生成查询的缓存键"""
        # 使用语义哈希(这里简化为文本哈希)
        # 实际应该用embedding的相似度匹配
        normalized = query.lower().strip()
        return hashlib.md5(normalized.encode()).hexdigest()[:16]
    
    def get(self, query: str, similarity_threshold: float = 0.85) -> Optional[Dict]:
        """从缓存获取答案"""
        # 简化版:精确匹配
        # 实际应该用语义相似度匹配
        query_key = query.lower().strip()
        
        # 查找语义相似的缓存项
        for cached_query, cached_data in self.cache.items():
            if self._is_similar(query_key, cached_query):
                self.hit_count += 1
                self.access_log.append({
                    'query': query,
                    'result': 'HIT',
                    'timestamp': datetime.now().isoformat()
                })
                return cached_data
        
        self.miss_count += 1
        self.access_log.append({
            'query': query,
            'result': 'MISS',
            'timestamp': datetime.now().isoformat()
        })
        return None
    
    def _is_similar(self, query1: str, query2: str) -> bool:
        """判断两个查询是否相似"""
        # 简化版:包含关键词就算相似
        # 实际应该用向量相似度
        keywords1 = set(query1.split())
        keywords2 = set(query2.split())
        
        if not keywords1 or not keywords2:
            return False
            
        intersection = keywords1 & keywords2
        union = keywords1 | keywords2
        similarity = len(intersection) / len(union)
        
        return similarity > 0.5
    
    def set(self, query: str, context: str, answer: str, metadata: Dict = None):
        """设置缓存"""
        query_key = query.lower().strip()
        
        # 检查容量限制
        if len(self.cache) >= self.max_size:
            # 简单的LRU:删除最旧的项
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[query_key] = {
            'query': query,
            'context': context,
            'answer': answer,
            'metadata': metadata or {},
            'cached_at': datetime.now().isoformat()
        }
    
    def get_statistics(self) -> Dict:
        """获取缓存统计"""
        total_access = self.hit_count + self.miss_count
        hit_rate = self.hit_count / total_access if total_access > 0 else 0
        
        return {
            'hit_count': self.hit_count,
            'miss_count': self.miss_count,
            'total_access': total_access,
            'hit_rate': hit_rate,
            'cache_size': len(self.cache)
        }
class CAGSystem:
    """CAG(缓存增强生成)系统"""
    
    def __init__(self, cache_size: int = 100):
        self.vector_db = SimpleVectorDB()
        self.cache = KnowledgeCache(max_size=cache_size)
        self.search_count = 0
        self.search_times = []
        self.cache_hit_times = []
        
    def add_knowledge(self, text: str, metadata: Dict = None, cacheable: bool = False):
        """添加知识"""
        self.vector_db.add_document(text, metadata)
        
        # 如果标记为可缓存,预先生成常见问题的缓存
        if cacheable and metadata and 'common_questions' in metadata:
            for question in metadata['common_questions']:
                # 预先缓存答案
                answer = f"基于缓存:{text[:100]}..."
                self.cache.set(question, text, answer, metadata)
    
    def query(self, question: str) -> Dict:
        """处理查询(带缓存)"""
        start_time = time.time()
        
        # 先查缓存
        cached_result = self.cache.get(question)
        
        if cached_result:
            # 缓存命中!
            cache_time = time.time() - start_time
            self.cache_hit_times.append(cache_time)
            
            return {
                'question': question,
                'answer': cached_result['answer'],
                'context': cached_result['context'],
                'source': 'CACHE',
                'response_time': cache_time,
                'total_searches': self.search_count
            }
        
        # 缓存未命中,执行检索
        search_results = self.vector_db.search(question, top_k=2)
        search_time = time.time() - start_time
        self.search_count += 1
        self.search_times.append(search_time)
        
        # 组装上下文
        context = "\n\n".join([r['text'] for r in search_results])
        answer = self._generate_answer(question, context)
        
        # 存入缓存(如果是静态知识)
        if search_results and self._is_cacheable(search_results[0]):
            self.cache.set(question, context, answer, 
                          search_results[0].get('metadata', {}))
        
        return {
            'question': question,
            'answer': answer,
            'context': context,
            'source': 'RETRIEVAL',
            'response_time': search_time,
            'total_searches': self.search_count
        }
    
    def _is_cacheable(self, search_result: Dict) -> bool:
        """判断检索结果是否应该缓存"""
        metadata = search_result.get('metadata', {})
        # 如果有更新日期且超过30天未更新,认为是静态知识
        update_date = metadata.get('update_date')
        if update_date:
            # 简化判断:只要有update_date就认为是静态的
            return True
        return False
    
    def _generate_answer(self, question: str, context: str) -> str:
        """模拟LLM生成答案"""
        return f"基于知识库:{context[:100]}... 回答:[模拟答案]"
    
    def get_statistics(self) -> Dict:
        """获取完整统计信息"""
        cache_stats = self.cache.get_statistics()
        
        return {
            'retrieval': {
                'total_searches': self.search_count,
                'avg_search_time': np.mean(self.search_times) if self.search_times else 0,
                'total_time': sum(self.search_times)
            },
            'cache': {
                'hit_count': cache_stats['hit_count'],
                'miss_count': cache_stats['miss_count'],
                'hit_rate': cache_stats['hit_rate'],
                'avg_hit_time': np.mean(self.cache_hit_times) if self.cache_hit_times else 0,
                'cache_size': cache_stats['cache_size']
            },
            'overall': {
                'total_queries': cache_stats['total_access'],
                'searches_saved': cache_stats['hit_count'],
                'cost_reduction': f"{cache_stats['hit_rate']*100:.1f}%"
            }
        }
# 使用示例
def demo_cag_system():
    """演示CAG系统的优势"""
    print("=" * 60)
    print("CAG系统演示(带缓存优化)")
    print("=" * 60)
    
    # 创建CAG系统
    cag = CAGSystem(cache_size=50)
    
    # 添加知识(标记静态知识为可缓存)
    knowledge_base = [
        {
            "text": "公司年假政策:入职满1年员工享有5天年假,满3年享有10天,满5年享有15天。年假必须在当年使用,不可跨年累积。",
            "metadata": {
                "category": "HR政策",
                "update_date": "2024-01-01",
                "common_questions": [
                    "年假怎么算",
                    "年假政策是什么",
                    "我能休几天年假",
                    "年假能累积吗"
                ]
            },
            "cacheable": True
        },
        {
            "text": "差旅费报销标准:国内出差每天补贴200元,住宿费实报实销上限500元/天。需提供发票和出差申请单。",
            "metadata": {
                "category": "财务制度",
                "update_date": "2024-01-01",
                "common_questions": [
                    "差旅费怎么报销",
                    "出差补贴标准是多少",
                    "出差住宿费报销"
                ]
            },
            "cacheable": True
        },
        {
            "text": "公积金缴纳比例:公司和个人各缴纳12%,基数为上年度月平均工资。每年7月调整一次。",
            "metadata": {
                "category": "薪酬福利",
                "update_date": "2024-01-01",
                "common_questions": [
                    "公积金比例是多少",
                    "公积金怎么缴纳"
                ]
            },
            "cacheable": True
        }
    ]
    
    for kb in knowledge_base:
        cag.add_knowledge(kb['text'], kb['metadata'], kb['cacheable'])
    
    print(f"\n已加载 {len(knowledge_base)} 条企业知识(已预缓存常见问题)\n")
    
    # 模拟重复查询
    test_questions = [
        "年假怎么算?",          # 第1次:缓存命中
        "年假政策是什么?",      # 第2次:缓存命中
        "我能休几天年假?",      # 第3次:缓存命中
        "差旅费怎么报销?",      # 第1次:缓存命中
        "出差补贴标准是多少?",  # 第2次:缓存命中
        "年假能累积吗?",        # 第4次:缓存命中
        "公积金比例是多少?",    # 第1次:缓存命中
        "年假政策详细说明",      # 第5次:缓存命中
    ]
    
    print("开始处理查询...\n")
    for i, question in enumerate(test_questions, 1):
        result = cag.query(question)
        
        # 显示结果
        source_icon = "⚡ [缓存]" if result['source'] == 'CACHE' else "🔍 [检索]"
        print(f"查询 {i}: {question}")
        print(f"  数据源: {source_icon}")
        print(f"  响应时间: {result['response_time']*1000:.2f}ms")
        print(f"  累计检索次数: {result['total_searches']}")
        print()
    
    # 显示详细统计
    stats = cag.get_statistics()
    print("=" * 60)
    print("性能统计对比")
    print("=" * 60)
    
    print("\n【检索统计】")
    print(f"  实际检索次数: {stats['retrieval']['total_searches']}")
    print(f"  平均检索耗时: {stats['retrieval']['avg_search_time']*1000:.2f}ms")
    
    print("\n【缓存统计】")
    print(f"  缓存命中次数: {stats['cache']['hit_count']}")
    print(f"  缓存未命中: {stats['cache']['miss_count']}")
    print(f"  缓存命中率: {stats['cache']['hit_rate']*100:.1f}%")
    print(f"  平均缓存响应: {stats['cache']['avg_hit_time']*1000:.2f}ms")
    print(f"  当前缓存大小: {stats['cache']['cache_size']}")
    
    print("\n【整体优化】")
    print(f"  总查询次数: {stats['overall']['total_queries']}")
    print(f"  节省检索次数: {stats['overall']['searches_saved']}")
    print(f"  成本降低: {stats['overall']['cost_reduction']}")
    
    print("\n✅ 优势总结:")
    print("  - 重复问题直接从缓存返回,无需检索")
    print("  - 响应时间从 50-200ms 降低到 <5ms")
    print("  - 数据库访问次数大幅减少,成本显著降低")
    print()
# 运行CAG演示
demo_cag_system()

2.3 CAG与RAG的性能对比

让我们直接对比两个系统:

def compare_rag_vs_cag():
    """直接对比RAG和CAG的性能"""
    print("=" * 60)
    print("RAG vs CAG 性能对比实验")
    print("=" * 60)
    
    # 准备测试数据
    knowledge = {
        "text": "公司年假政策:入职满1年员工享有5天年假,满3年享有10天,满5年享有15天。",
        "metadata": {
            "category": "HR政策",
            "common_questions": ["年假怎么算", "年假政策", "休几天年假"]
        }
    }
    
    # 重复查询100次
    questions = ["年假怎么算?"] * 100
    
    # 测试传统RAG
    print("\n【测试1:传统RAG】")
    rag = TraditionalRAG()
    rag.add_knowledge(knowledge['text'], knowledge['metadata'])
    
    rag_start = time.time()
    for q in questions:
        rag.query(q)
    rag_total_time = time.time() - rag_start
    rag_stats = rag.get_statistics()
    
    print(f"总耗时: {rag_total_time*1000:.2f}ms")
    print(f"检索次数: {rag_stats['total_searches']}")
    print(f"平均延迟: {rag_stats['avg_search_time']*1000:.2f}ms")
    
    # 测试CAG
    print("\n【测试2:CAG系统】")
    cag = CAGSystem()
    cag.add_knowledge(knowledge['text'], knowledge['metadata'], cacheable=True)
    
    cag_start = time.time()
    for q in questions:
        cag.query(q)
    cag_total_time = time.time() - cag_start
    cag_stats = cag.get_statistics()
    
    print(f"总耗时: {cag_total_time*1000:.2f}ms")
    print(f"检索次数: {cag_stats['retrieval']['total_searches']}")
    print(f"缓存命中率: {cag_stats['cache']['hit_rate']*100:.1f}%")
    print(f"平均延迟: {cag_stats['cache']['avg_hit_time']*1000:.2f}ms")
    
    # 性能提升计算
    print("\n" + "=" * 60)
    print("性能提升")
    print("=" * 60)
    speedup = rag_total_time / cag_total_time
    search_reduction = (rag_stats['total_searches'] - cag_stats['retrieval']['total_searches']) / rag_stats['total_searches']
    
    print(f"速度提升: {speedup:.1f}x")
    print(f"检索次数减少: {search_reduction*100:.1f}%")
    print(f"成本节约: ~{search_reduction*100:.1f}%")
    
# 运行对比测试
compare_rag_vs_cag()

三、RAG+CAG融合架构

3.1 为什么需要融合?

就像你的大脑:九九乘法表、家庭住址这些早就记住了,但今天午饭吃什么、明天天气怎么样,还是得现查。

这种"内存+外脑"的双引擎模式,才是未来知识型AI的标配。

class HybridRAGCAG:
    """混合RAG+CAG系统"""
    
    def __init__(self, cache_size: int = 100):
        # 静态知识缓存
        self.static_cache = KnowledgeCache(max_size=cache_size)
        
        # 动态知识向量库
        self.dynamic_db = SimpleVectorDB()
        
        # 静态知识向量库(用于缓存未命中时的后备)
        self.static_db = SimpleVectorDB()
        
        # 统计信息
        self.stats = {
            'static_cache_hits': 0,
            'static_db_queries': 0,
            'dynamic_db_queries': 0,
            'total_queries': 0
        }
        
    def add_static_knowledge(self, text: str, metadata: Dict = None, 
                            common_questions: List[str] = None):
        """添加静态知识(长期不变)"""
        # 添加到静态数据库
        self.static_db.add_document(text, metadata)
        
        # 预缓存常见问题
        if common_questions:
            for question in common_questions:
                answer = f"[静态知识] {text}"
                self.static_cache.set(question, text, answer, metadata)
    
    def add_dynamic_knowledge(self, text: str, metadata: Dict = None):
        """添加动态知识(经常更新)"""
        # 只添加到动态数据库,不缓存
        self.dynamic_db.add_document(text, metadata)
    
    def query(self, question: str, require_realtime: bool = False) -> Dict:
        """智能查询:自动判断用缓存还是检索
        
        Args:
            question: 用户问题
            require_realtime: 是否强制要求实时数据
        """
        self.stats['total_queries'] += 1
        start_time = time.time()
        
        # 如果不要求实时数据,先查静态缓存
        if not require_realtime:
            cached_result = self.static_cache.get(question)
            if cached_result:
                self.stats['static_cache_hits'] += 1
                return {
                    'question': question,
                    'answer': cached_result['answer'],
                    'source': 'STATIC_CACHE',
                    'response_time': time.time() - start_time,
                    'confidence': 'high'
                }
        
        # 判断问题类型:需要动态数据还是静态数据?
        question_type = self._classify_question(question)
        
        if question_type == 'dynamic' or require_realtime:
            # 查询动态数据库
            results = self.dynamic_db.search(question, top_k=2)
            self.stats['dynamic_db_queries'] += 1
            source = 'DYNAMIC_RETRIEVAL'
        else:
            # 查询静态数据库
            results = self.static_db.search(question, top_k=2)
            self.stats['static_db_queries'] += 1
            source = 'STATIC_RETRIEVAL'
            
            # 将结果缓存起来,下次直接用
            if results:
                context = results[0]['text']
                answer = f"[静态知识] {context}"
                self.static_cache.set(question, context, answer, 
                                     results[0].get('metadata', {}))
        
        # 生成答案
        context = "\n".join([r['text'] for r in results]) if results else ""
        answer = self._generate_answer(question, context, source)
        
        return {
            'question': question,
            'answer': answer,
            'source': source,
            'response_time': time.time() - start_time,
            'confidence': 'high' if results else 'low'
        }
    
    def _classify_question(self, question: str) -> str:
        """判断问题需要动态数据还是静态数据"""
        # 简化版:通过关键词判断
        dynamic_keywords = ['今天', '最新', '现在', '当前', '实时', '昨天', '最近']
        static_keywords = ['政策', '制度', '规定', '标准', '流程', '怎么', '如何']
        
        question_lower = question.lower()
        
        # 包含动态关键词,返回dynamic
        for keyword in dynamic_keywords:
            if keyword in question_lower:
                return 'dynamic'
        
        # 包含静态关键词,返回static
        for keyword in static_keywords:
            if keyword in question_lower:
                return 'static'
        
        # 默认当作静态
        return 'static'
    
    def _generate_answer(self, question: str, context: str, source: str) -> str:
        """生成答案"""
        if not context:
            return "抱歉,没有找到相关信息。"
        return f"基于{source}:{context[:150]}..."
    
    def get_statistics(self) -> Dict:
        """获取详细统计"""
        cache_stats = self.static_cache.get_statistics()
        total_queries = self.stats['total_queries']
        
        return {
            'total_queries': total_queries,
            'static_cache_hits': self.stats['static_cache_hits'],
            'static_db_queries': self.stats['static_db_queries'],
            'dynamic_db_queries': self.stats['dynamic_db_queries'],
            'cache_hit_rate': cache_stats['hit_rate'],
            'db_access_rate': (self.stats['static_db_queries'] + self.stats['dynamic_db_queries']) / total_queries if total_queries > 0 else 0
        }
def demo_hybrid_system():
    """演示混合系统"""
    print("=" * 60)
    print("混合RAG+CAG系统演示")
    print("=" * 60)
    
    # 创建混合系统
    hybrid = HybridRAGCAG(cache_size=50)
    
    # 添加静态知识(制度文档)
    print("\n【加载静态知识】")
    static_knowledge = [
        {
            "text": "公司年假政策:入职满1年员工享有5天年假,满3年享有10天,满5年享有15天。年假必须在当年使用,不可跨年累积。",
            "common_questions": ["年假怎么算", "年假政策", "休几天年假", "年假能累积吗"]
        },
        {
            "text": "报销流程:提交申请单→部门主管审批→财务审核→财务打款。处理时间约3-5个工作日。",
            "common_questions": ["怎么报销", "报销流程", "报销要多久"]
        }
    ]
    
    for kb in static_knowledge:
        hybrid.add_static_knowledge(
            kb['text'], 
            {'type': 'static', 'category': 'policy'},
            kb['common_questions']
        )
    print(f"已加载 {len(static_knowledge)} 条静态知识(已预缓存)")
    
    # 添加动态知识(实时数据)
    print("\n【加载动态知识】")
    dynamic_knowledge = [
        {
            "text": "今天公司食堂菜单:午餐有红烧肉、清蒸鱼、麻婆豆腐。晚餐有宫保鸡丁、酸菜鱼、素炒时蔬。",
            "metadata": {'type': 'dynamic', 'date': '2025-11-05'}
        },
        {
            "text": "本周会议通知:周三下午3点全体会议,周五上午10点部门例会。请提前准备材料。",
            "metadata": {'type': 'dynamic', 'date': '2025-11-05'}
        }
    ]
    
    for kb in dynamic_knowledge:
        hybrid.add_dynamic_knowledge(kb['text'], kb['metadata'])
    print(f"已加载 {len(dynamic_knowledge)} 条动态知识")
    
    # 测试不同类型的查询
    print("\n" + "=" * 60)
    print("开始测试查询")
    print("=" * 60)
    
    test_cases = [
        {"q": "年假怎么算?", "type": "静态问题(应命中缓存)"},
        {"q": "年假政策是什么?", "type": "静态问题(应命中缓存)"},
        {"q": "报销流程是什么?", "type": "静态问题(应命中缓存)"},
        {"q": "今天食堂吃什么?", "type": "动态问题(应查询动态库)"},
        {"q": "本周有什么会议?", "type": "动态问题(应查询动态库)"},
        {"q": "年假能累积吗?", "type": "静态问题(应命中缓存)"},
        {"q": "今天食堂有什么菜?", "type": "动态问题(应查询动态库)"},
    ]
    
    for i, test in enumerate(test_cases, 1):
        result = hybrid.query(test['q'])
        
        # 根据source显示不同图标
        if result['source'] == 'STATIC_CACHE':
            icon = "⚡"
            color = "缓存"
        elif result['source'] == 'STATIC_RETRIEVAL':
            icon = "📚"
            color = "静态库"
        else:
            icon = "🔄"
            color = "动态库"
        
        print(f"\n查询 {i}: {test['q']}")
        print(f"  类型: {test['type']}")
        print(f"  数据源: {icon} {color}")
        print(f"  响应时间: {result['response_time']*1000:.2f}ms")
    
    # 显示统计
    print("\n" + "=" * 60)
    print("系统统计")
    print("=" * 60)
    stats = hybrid.get_statistics()
    
    print(f"\n总查询次数: {stats['total_queries']}")
    print(f"  ├─ 静态缓存命中: {stats['static_cache_hits']} ({stats['cache_hit_rate']*100:.1f}%)")
    print(f"  ├─ 静态库查询: {stats['static_db_queries']}")
    print(f"  └─ 动态库查询: {stats['dynamic_db_queries']}")
    print(f"\n数据库访问率: {stats['db_access_rate']*100:.1f}%")
    print(f"成本节约: ~{(1-stats['db_access_rate'])*100:.1f}%")
    
    print("\n✅ 混合架构优势:")
    print("  - 静态知识走缓存,响应极快")
    print("  - 动态知识走检索,保证实时性")
    print("  - 自动判断问题类型,智能路由")
    print("  - 兼顾速度、成本和准确性")
# 运行混合系统演示
demo_hybrid_system()

四、怎么判断该缓存什么?

4.1 为什么需要选择性缓存?

不是所有知识都该缓存。如果乱缓存,会遇到两个问题:

  1. 内存爆炸:缓存太多,占用大量内存
  2. 命中率低:缓存了不常用的内容,浪费空间

所以需要一套智能缓存策略。

4.2 基于访问频率的智能缓存

class SmartCache:
    """智能缓存系统(基于LFU+LRU)"""
    
    def __init__(self, max_size: int = 100, min_access_count: int = 3):
        self.cache = {}
        self.access_count = {}  # 访问计数
        self.last_access = {}   # 最后访问时间
        self.max_size = max_size
        self.min_access_count = min_access_count  # 最小访问次数才缓存
        
        # 候选池:访问次数不够的暂存这里
        self.candidate_pool = {}
        self.candidate_access = {}
        
    def should_cache(self, key: str) -> bool:
        """判断是否应该缓存"""
        # 如果已经在候选池
        if key in self.candidate_access:
            self.candidate_access[key] += 1
            
            # 访问次数达到阈值,提升到正式缓存
            if self.candidate_access[key] >= self.min_access_count:
                return True
        else:
            # 首次访问,加入候选池
            self.candidate_access[key] = 1
        
        return False
    
    def set(self, key: str, value: Dict):
        """设置缓存(只缓存热数据)"""
        if not self.should_cache(key):
            # 暂存到候选池
            self.candidate_pool[key] = value
            return False
        
        # 达到缓存条件,正式缓存
        if len(self.cache) >= self.max_size:
            # 淘汰策略:LFU + LRU
            self._evict()
        
        self.cache[key] = value
        self.access_count[key] = self.candidate_access.get(key, 1)
        self.last_access[key] = time.time()
        
        # 从候选池移除
        if key in self.candidate_pool:
            del self.candidate_pool[key]
        
        return True
    
    def get(self, key: str) -> Optional[Dict]:
        """获取缓存"""
        if key in self.cache:
            # 更新访问统计
            self.access_count[key] += 1
            self.last_access[key] = time.time()
            return self.cache[key]
        
        # 检查候选池
        if key in self.candidate_pool:
            self.candidate_access[key] += 1
            # 如果访问够多了,提升到正式缓存
            if self.candidate_access[key] >= self.min_access_count:
                self.set(key, self.candidate_pool[key])
                return self.cache[key]
            return self.candidate_pool[key]
        
        return None
    
    def _evict(self):
        """淘汰缓存项(LFU + LRU组合)"""
        if not self.cache:
            return
        
        # 找出访问次数最少的项
        min_count = min(self.access_count.values())
        candidates = [k for k, v in self.access_count.items() if v == min_count]
        
        # 如果有多个,选最久未访问的
        if len(candidates) > 1:
            evict_key = min(candidates, key=lambda k: self.last_access[k])
        else:
            evict_key = candidates[0]
        
        # 删除
        del self.cache[evict_key]
        del self.access_count[evict_key]
        del self.last_access[evict_key]
    
    def get_statistics(self) -> Dict:
        """获取统计信息"""
        return {
            'cache_size': len(self.cache),
            'candidate_size': len(self.candidate_pool),
            'total_size': len(self.cache) + len(self.candidate_pool),
            'avg_access_count': np.mean(list(self.access_count.values())) if self.access_count else 0,
            'hot_items': sorted(
                [(k, v) for k, v in self.access_count.items()],
                key=lambda x: x[1],
                reverse=True
            )[:5]  # 前5个热门项
        }
class SmartCachingSystem:
    """带智能缓存的完整系统"""
    
    def __init__(self, cache_size: int = 50, min_access: int = 3):
        self.vector_db = SimpleVectorDB()
        self.smart_cache = SmartCache(max_size=cache_size, min_access_count=min_access)
        
        self.stats = {
            'total_queries': 0,
            'cache_hits': 0,
            'db_queries': 0,
            'promoted_to_cache': 0  # 从候选池提升到正式缓存的次数
        }
    
    def add_knowledge(self, text: str, metadata: Dict = None):
        """添加知识"""
        self.vector_db.add_document(text, metadata)
    
    def query(self, question: str) -> Dict:
        """查询"""
        self.stats['total_queries'] += 1
        start_time = time.time()
        
        # 查缓存
        cached = self.smart_cache.get(question)
        if cached and question in self.smart_cache.cache:  # 正式缓存命中
            self.stats['cache_hits'] += 1
            return {
                'question': question,
                'answer': cached['answer'],
                'source': 'CACHE',
                'response_time': time.time() - start_time
            }
        
        # 检索
        results = self.vector_db.search(question, top_k=2)
        self.stats['db_queries'] += 1
        
        context = "\n".join([r['text'] for r in results]) if results else ""
        answer = f"基于检索: {context[:100]}..."
        
        # 尝试缓存(智能判断)
        cached_result = self.smart_cache.set(question, {
            'answer': answer,
            'context': context,
            'metadata': results[0].get('metadata', {}) if results else {}
        })
        
        if cached_result:
            self.stats['promoted_to_cache'] += 1
        
        return {
            'question': question,
            'answer': answer,
            'source': 'RETRIEVAL',
            'response_time': time.time() - start_time,
            'will_cache': cached_result
        }
    
    def get_statistics(self) -> Dict:
        """获取统计"""
        cache_stats = self.smart_cache.get_statistics()
        
        return {
            'queries': self.stats,
            'cache': cache_stats,
            'cache_hit_rate': self.stats['cache_hits'] / self.stats['total_queries'] if self.stats['total_queries'] > 0 else 0
        }
def demo_smart_caching():
    """演示智能缓存"""
    print("=" * 60)
    print("智能缓存系统演示")
    print("=" * 60)
    
    # 创建系统
    system = SmartCachingSystem(cache_size=10, min_access=3)
    
    # 添加知识
    knowledge = [
        "年假政策:入职满1年5天,满3年10天,满5年15天",
        "报销流程:提交申请→审批→财务审核→打款",
        "公积金比例:公司和个人各12%",
        "加班政策:工作日1.5倍,周末2倍,节假日3倍",
        "社保缴纳:养老8%医疗2%失业0.5%"
    ]
    
    for kb in knowledge:
        system.add_knowledge(kb)
    
    print(f"\n已加载 {len(knowledge)} 条知识\n")
    
    # 模拟真实查询分布(符合二八定律)
    print("模拟真实查询场景(80%查询集中在20%的问题)\n")
    
    # 热门问题(会被频繁查询)
    hot_questions = [
        "年假怎么算",
        "怎么报销",
        "公积金比例"
    ]
    
    # 冷门问题(偶尔查一次)
    cold_questions = [
        "加班怎么算",
        "社保比例",
        "病假政策",
        "迟到扣款",
        "离职流程"
    ]
    
    # 生成查询序列(80/20分布)
    query_sequence = []
    for _ in range(50):
        if np.random.random() < 0.8:  # 80%概率查热门问题
            query_sequence.append(np.random.choice(hot_questions))
        else:  # 20%概率查冷门问题
            query_sequence.append(np.random.choice(cold_questions))
    
    # 执行查询
    print("开始处理50次查询...\n")
    cache_hits_timeline = []
    
    for i, question in enumerate(query_sequence, 1):
        result = system.query(question)
        
        if i <= 10 or i % 10 == 0:  # 只显示部分结果
            source_icon = "⚡" if result['source'] == 'CACHE' else "🔍"
            cached_tag = " [已提升到缓存]" if result.get('will_cache') else ""
            print(f"查询{i:2d}: {question:15s} {source_icon} {result['source']}{cached_tag}")
        
        # 记录命中率变化
        stats = system.get_statistics()
        cache_hits_timeline.append(stats['cache_hit_rate'])
    
    # 最终统计
    print("\n" + "=" * 60)
    print("最终统计")
    print("=" * 60)
    
    final_stats = system.get_statistics()
    
    print(f"\n【查询统计】")
    print(f"  总查询次数: {final_stats['queries']['total_queries']}")
    print(f"  缓存命中: {final_stats['queries']['cache_hits']}")
    print(f"  数据库查询: {final_stats['queries']['db_queries']}")
    print(f"  提升到缓存: {final_stats['queries']['promoted_to_cache']}")
    
    print(f"\n【缓存统计】")
    print(f"  正式缓存: {final_stats['cache']['cache_size']} 项")
    print(f"  候选池: {final_stats['cache']['candidate_size']} 项")
    print(f"  总存储: {final_stats['cache']['total_size']} 项")
    print(f"  缓存命中率: {final_stats['cache_hit_rate']*100:.1f}%")
    print(f"  平均访问次数: {final_stats['cache']['avg_access_count']:.1f}")
    
    print(f"\n【热门问题Top5】")
    for i, (question, count) in enumerate(final_stats['cache']['hot_items'], 1):
        print(f"  {i}. {question} - 访问{count}次")
    
    print("\n✅ 智能缓存特点:")
    print("  - 只缓存被多次访问的热门问题(访问≥3次)")
    print("  - 冷门问题不占用宝贵的缓存空间")
    print("  - 自动淘汰不常用的缓存项")
    print("  - 符合真实业务场景的访问分布")
    
    # 显示命中率趋势
    print(f"\n【命中率趋势】前20次查询:")
    for i in range(0, min(20, len(cache_hits_timeline)), 5):
        rate = cache_hits_timeline[i]
        bar = "█" * int(rate * 50)
        print(f"  查询{i+1:2d}: {bar} {rate*100:.1f}%")
# 运行智能缓存演示
demo_smart_caching()

五、缓存更新策略

5.1 如何处理知识更新?

静态知识也会更新,比如:

  • 公司政策调整
  • 产品信息变更
  • 法律法规修订

这时需要缓存失效机制。

5.2 完整的缓存更新实现

from datetime import datetime, timedelta
class CacheWithTTL:
    """带过期时间的缓存"""
    
    def __init__(self, max_size: int = 100, default_ttl: int = 86400):
        """
        Args:
            max_size: 最大缓存数量
            default_ttl: 默认过期时间(秒),默认24小时
        """
        self.cache = {}
        self.max_size = max_size
        self.default_ttl = default_ttl
        
        self.stats = {
            'hits': 0,
            'misses': 0,
            'expires': 0,
            'invalidations': 0
        }
    
    def set(self, key: str, value: Dict, ttl: Optional[int] = None):
        """设置缓存项
        
        Args:
            key: 缓存键
            value: 缓存值
            ttl: 过期时间(秒),None则使用默认值
        """
        if len(self.cache) >= self.max_size:
            self._evict_oldest()
        
        expire_at = time.time() + (ttl if ttl is not None else self.default_ttl)
        
        self.cache[key] = {
            'value': value,
            'expire_at': expire_at,
            'created_at': time.time(),
            'version': value.get('metadata', {}).get('version', 1)
        }
    
    def get(self, key: str) -> Optional[Dict]:
        """获取缓存项"""
        if key not in self.cache:
            self.stats['misses'] += 1
            return None
        
        item = self.cache[key]
        
        # 检查是否过期
        if time.time() > item['expire_at']:
            self.stats['expires'] += 1
            del self.cache[key]
            return None
        
        self.stats['hits'] += 1
        return item['value']
    
    def invalidate(self, key: str):
        """主动失效某个缓存"""
        if key in self.cache:
            del self.cache[key]
            self.stats['invalidations'] += 1
    
    def invalidate_by_pattern(self, pattern: str):
        """按模式批量失效"""
        keys_to_delete = [k for k in self.cache.keys() if pattern in k]
        for key in keys_to_delete:
            self.invalidate(key)
    
    def update_version(self, key: str, new_version: int):
        """更新版本号(触发重新缓存)"""
        if key in self.cache:
            current_version = self.cache[key]['version']
            if new_version > current_version:
                # 版本更新,失效旧缓存
                self.invalidate(key)
    
    def _evict_oldest(self):
        """淘汰最旧的项"""
        if not self.cache:
            return
        oldest_key = min(self.cache.keys(), 
                        key=lambda k: self.cache[k]['created_at'])
        del self.cache[oldest_key]
    
    def get_statistics(self) -> Dict:
        """获取统计"""
        total = self.stats['hits'] + self.stats['misses']
        return {
            **self.stats,
            'hit_rate': self.stats['hits'] / total if total > 0 else 0,
            'cache_size': len(self.cache)
        }
class VersionedKnowledgeBase:
    """带版本控制的知识库"""
    
    def __init__(self):
        self.documents = {}  # key: doc_id, value: {content, version, metadata}
        self.cache = CacheWithTTL(max_size=50, default_ttl=3600)  # 1小时TTL
        self.vector_db = SimpleVectorDB()
        
    def add_or_update_document(self, doc_id: str, content: str, 
                                metadata: Dict = None, version: int = 1):
        """添加或更新文档"""
        # 检查是否是更新
        is_update = doc_id in self.documents
        old_version = self.documents[doc_id]['version'] if is_update else 0
        
        # 保存文档
        self.documents[doc_id] = {
            'content': content,
            'version': version,
            'metadata': metadata or {},
            'updated_at': datetime.now().isoformat()
        }
        
        # 更新向量数据库
        self.vector_db.add_document(content, {
            'doc_id': doc_id,
            'version': version,
            **(metadata or {})
        })
        
        # 如果是更新,失效相关缓存
        if is_update and version > old_version:
            print(f"📝 文档 {doc_id} 更新: v{old_version} -> v{version}")
            self.cache.invalidate_by_pattern(doc_id)
            return True
        
        return False
    
    def query(self, question: str, doc_id: Optional[str] = None) -> Dict:
        """查询(支持版本检查)"""
        # 构建缓存键
        cache_key = f"{doc_id}:{question}" if doc_id else question
        
        # 查缓存
        cached = self.cache.get(cache_key)
        if cached:
            return {
                'question': question,
                'answer': cached['answer'],
                'source': 'CACHE',
                'version': cached.get('version', 'unknown')
            }
        
        # 检索
        results = self.vector_db.search(question, top_k=2)
        if not results:
            return {'question': question, 'answer': '未找到相关信息', 'source': 'NONE'}
        
        # 生成答案
        context = results[0]['text']
        result_doc_id = results[0]['metadata'].get('doc_id', 'unknown')
        result_version = results[0]['metadata'].get('version', 1)
        answer = f"[v{result_version}] {context}"
        
        # 缓存结果
        self.cache.set(cache_key, {
            'answer': answer,
            'context': context,
            'metadata': {
                'doc_id': result_doc_id,
                'version': result_version
            }
        }, ttl=3600)  # 1小时过期
        
        return {
            'question': question,
            'answer': answer,
            'source': 'RETRIEVAL',
            'version': result_version,
            'doc_id': result_doc_id
        }
def demo_cache_update():
    """演示缓存更新机制"""
    print("=" * 60)
    print("缓存更新与版本控制演示")
    print("=" * 60)
    
    kb = VersionedKnowledgeBase()
    
    # 场景1:初始知识
    print("\n【场景1:初始加载知识】")
    kb.add_or_update_document(
        doc_id="policy_annual_leave",
        content="年假政策v1:入职满1年5天,满3年10天,满5年15天",
        metadata={'category': 'HR'},
        version=1
    )
    print("✅ 已添加:年假政策 v1")
    
    # 第一次查询
    print("\n第1次查询:年假怎么算?")
    result1 = kb.query("年假怎么算")
    print(f"  来源: {result1['source']}")
    print(f"  版本: {result1['version']}")
    print(f"  答案: {result1['answer'][:50]}...")
    
    # 第二次查询(应该命中缓存)
    print("\n第2次查询:年假怎么算?")
    result2 = kb.query("年假怎么算")
    print(f"  来源: {result2['source']} ⚡")
    print(f"  版本: {result2['version']}")
    
    # 场景2:政策更新
    print("\n" + "=" * 60)
    print("【场景2:政策更新】")
    print("公司调整年假政策...")
    
    kb.add_or_update_document(
        doc_id="policy_annual_leave",
        content="年假政策v2:入职满1年7天,满3年12天,满5年20天。新增:满10年25天",
        metadata={'category': 'HR'},
        version=2
    )
    
    # 再次查询(缓存已失效,应该返回新版本)
    print("\n第3次查询:年假怎么算?")
    result3 = kb.query("年假怎么算")
    print(f"  来源: {result3['source']}")
    print(f"  版本: {result3['version']}")
    print(f"  答案: {result3['answer'][:60]}...")
    
    # 场景3:缓存重建
    print("\n第4次查询:年假怎么算?")
    result4 = kb.query("年假怎么算")
    print(f"  来源: {result4['source']} ⚡ (新版本已缓存)")
    print(f"  版本: {result4['version']}")
    
    # 统计
    print("\n" + "=" * 60)
    print("缓存统计")
    print("=" * 60)
    stats = kb.cache.get_statistics()
    print(f"缓存命中: {stats['hits']}")
    print(f"缓存未命中: {stats['misses']}")
    print(f"缓存失效: {stats['invalidations']}")
    print(f"命中率: {stats['hit_rate']*100:.1f}%")
    
    print("\n✅ 更新机制总结:")
    print("  - 文档更新时自动失效相关缓存")
    print("  - 版本号控制确保数据一致性")
    print("  - 支持TTL自动过期")
    print("  - 下次查询会获取最新版本并重新缓存")
# 运行缓存更新演示
demo_cache_update()

5.3 三种缓存失效策略对比

def compare_invalidation_strategies():
    """对比不同的缓存失效策略"""
    print("=" * 60)
    print("三种缓存失效策略对比")
    print("=" * 60)
    
    print("\n【策略1:固定TTL(Time To Live)】")
    print("特点:设置固定过期时间")
    print("优点:实现简单,自动清理")
    print("缺点:可能返回过期数据")
    print("适用:可以容忍短期延迟的场景")
    print("\n示例代码:")
    print("""
    cache.set('question', answer, ttl=3600)  # 1小时后过期
    """)
    
    print("\n【策略2:版本号控制】")
    print("特点:每次更新增加版本号")
    print("优点:精确控制,不会返回旧数据")
    print("缺点:需要维护版本号系统")
    print("适用:数据一致性要求高的场景")
    print("\n示例代码:")
    print("""
    # 更新文档时
    doc.version += 1
    cache.invalidate_by_version(doc.id, doc.version)
    """)
    
    print("\n【策略3:主动推送失效】")
    print("特点:内容更新时主动通知缓存失效")
    print("优点:实时性最好")
    print("缺点:需要额外的通知机制")
    print("适用:分布式系统、多节点部署")
    print("\n示例代码:")
    print("""
    # 发布更新事件
    event_bus.publish('document_updated', doc_id='policy_123')
    
    # 监听器失效缓存
    @event_bus.subscribe('document_updated')
    def on_document_updated(doc_id):
        cache.invalidate_by_pattern(doc_id)
    """)
    
    # 实际测试对比
    print("\n" + "=" * 60)
    print("实际场景测试")
    print("=" * 60)
    
    # 模拟:文档每小时更新一次,查询每分钟一次
    ttl_configs = [
        {'name': 'TTL=10分钟', 'ttl': 600, 'update_interval': 3600},
        {'name': 'TTL=30分钟', 'ttl': 1800, 'update_interval': 3600},
        {'name': 'TTL=60分钟', 'ttl': 3600, 'update_interval': 3600},
    ]
    
    print("\n假设:文档每小时更新,查询每分钟一次(共120次查询)")
    print("\n不同TTL配置的效果:\n")
    
    for config in ttl_configs:
        ttl = config['ttl']
        update_interval = config['update_interval']
        
        # 计算可能返回过期数据的次数
        stale_responses = max(0, (update_interval - ttl) / 60)  # 分钟
        freshness_rate = (60 - stale_responses) / 60 * 100
        
        print(f"{config['name']}:")
        print(f"  可能过期的响应: ~{int(stale_responses)}次")
        print(f"  数据新鲜度: {freshness_rate:.1f}%")
        print()
    
    print("💡 建议:")
    print("  - 制度类文档:TTL = 24小时 + 版本控制")
    print("  - 产品信息:TTL = 1小时 + 版本控制")
    print("  - 实时数据:不缓存或TTL < 5分钟")
# 运行对比
compare_invalidation_strategies()

六、生产级实现与最佳实践

6.1 完整的生产级CAG系统

import logging
from typing import Callable
from dataclasses import dataclass
from enum import Enum
class CacheStrategy(Enum):
    """缓存策略"""
    ALWAYS = "always"  # 总是缓存
    SMART = "smart"    # 智能判断
    NEVER = "never"    # 从不缓存
@dataclass
class CacheConfig:
    """缓存配置"""
    max_size: int = 100
    default_ttl: int = 3600
    min_access_count: int = 3
    strategy: CacheStrategy = CacheStrategy.SMART
    enable_metrics: bool = True
class ProductionCAGSystem:
    """生产级CAG系统"""
    def __init__(self, config: CacheConfig = None):
        self.config = config or CacheConfig()
        # 核心组件
        self.static_cache = CacheWithTTL(
            max_size=self.config.max_size,
            default_ttl=self.config.default_ttl
        )
        self.smart_cache = SmartCache(
            max_size=self.config.max_size,
            min_access_count=self.config.min_access_count
        )
        self.vector_db = SimpleVectorDB()
        # 监控指标
        self.metrics = {
            'total_queries': 0,
            'cache_hits': 0,
            'db_queries': 0,
            'avg_response_time': [],
            'errors': 0
        }
        # 日志
        self.logger = self._setup_logger()
    def _setup_logger(self):
        """设置日志"""
        logger = logging.getLogger('CAGSystem')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger
    def add_knowledge(self, 
                     text: str, 
                     doc_id: str,
                     metadata: Dict = None,
                     cache_strategy: CacheStrategy = None,
                     ttl: int = None):
        """添加知识"""
        try:
            full_metadata = {
                'doc_id': doc_id,
                'cache_strategy': (cache_strategy or self.config.strategy).value,
                'ttl': ttl or self.config.default_ttl,
                **(metadata or {})
            }
            self.vector_db.add_document(text, full_metadata)
            self.logger.info(f"Added document: {doc_id}")
        except Exception as e:
            self.logger.error(f"Error adding document {doc_id}: {str(e)}")
            self.metrics['errors'] += 1
            raise
    def query(self, question: str, force_refresh: bool = False) -> Dict:
        """查询"""
        start_time = time.time()
        self.metrics['total_queries'] += 1
        try:
            # 强制刷新则跳过缓存
            if not force_refresh:
                # 先查静态缓存
                cached = self.static_cache.get(question)
                if cached:
                    self.metrics['cache_hits'] += 1
                    response_time = time.time() - start_time
                    self.metrics['avg_response_time'].append(response_time)
                    self.logger.info(f"Cache hit: {question[:30]}...")
                    return {
                        'question': question,
                        'answer': cached['answer'],
                        'source': 'STATIC_CACHE',
                        'response_time': response_time,
                        'cached': True
                    }
                # 再查智能缓存
                smart_cached = self.smart_cache.get(question)
                if smart_cached and question in self.smart_cache.cache:
                    self.metrics['cache_hits'] += 1
                    response_time = time.time() - start_time
                    self.metrics['avg_response_time'].append(response_time)
                    self.logger.info(f"Smart cache hit: {question[:30]}...")
                    return {
                        'question': question,
                        'answer': smart_cached['answer'],
                        'source': 'SMART_CACHE',
                        'response_time': response_time,
                        'cached': True
                    }
            # 缓存未命中,执行检索
            self.logger.info(f"Cache miss, retrieving: {question[:30]}...")
            results = self.vector_db.search(question, top_k=2)
            self.metrics['db_queries'] += 1
            if not results:
                response_time = time.time() - start_time
                self.metrics['avg_response_time'].append(response_time)
                return {
                    'question': question,
                    'answer': '未找到相关信息',
                    'source': 'NONE',
                    'response_time': response_time,
                    'cached': False
                }
            # 生成答案
            context = results[0]['text']
            metadata = results[0].get('metadata', {})
            answer = f"基于检索: {context[:100]}..."
            # 根据策略决定是否缓存
            strategy = CacheStrategy(metadata.get('cache_strategy', 'smart'))
            ttl = metadata.get('ttl', self.config.default_ttl)
            if strategy == CacheStrategy.ALWAYS:
                # 直接缓存到静态缓存
                self.static_cache.set(question, {
                    'answer': answer,
                    'context': context,
                    'metadata': metadata
                }, ttl=ttl)
                self.logger.info(f"Cached to static (ALWAYS): {question[:30]}...")
            elif strategy == CacheStrategy.SMART:
                # 让智能缓存决定
                cached = self.smart_cache.set(question, {
                    'answer': answer,
                    'context': context,
                    'metadata': metadata
                })
                if cached:
                    self.logger.info(f"Promoted to smart cache: {question[:30]}...")
            response_time = time.time() - start_time
            self.metrics['avg_response_time'].append(response_time)
            return {
                'question': question,
                'answer': answer,
                'source': 'RETRIEVAL',
                'response_time': response_time,
                'cached': False,
                'cache_strategy': strategy.value
            }
        except Exception as e:
            self.logger.error(f"Error processing query '{question}': {str(e)}")
            self.metrics['errors'] += 1
            raise
    def invalidate_cache(self, pattern: str = None, doc_id: str = None):
        """失效缓存"""
        try:
            if pattern:
                self.static_cache.invalidate_by_pattern(pattern)
                self.logger.info(f"Invalidated cache by pattern: {pattern}")
            if doc_id:
                self.static_cache.invalidate_by_pattern(doc_id)
                self.logger.info(f"Invalidated cache for doc: {doc_id}")
        except Exception as e:
            self.logger.error(f"Error invalidating cache: {str(e)}")
            raise
    def get_health_status(self) -> Dict:
        """获取系统健康状态"""
        total_queries = self.metrics['total_queries']
        cache_hit_rate = self.metrics['cache_hits'] / total_queries if total_queries > 0 else 0
        avg_response = np.mean(self.metrics['avg_response_time']) if self.metrics['avg_response_time'] else 0
        # 健康评分
        health_score = 100
        if cache_hit_rate < 0.3:
            health_score -= 20  # 命中率低
        if avg_response > 0.1:
            health_score -= 15  # 响应慢
        if self.metrics['errors'] > 0:
            health_score -= 30  # 有错误
        status = 'healthy' if health_score >= 80 else 'degraded' if health_score >= 50 else 'unhealthy'
        return {
            'status': status,
            'health_score': health_score,
            'metrics': {
                'total_queries': total_queries,
                'cache_hit_rate': f"{cache_hit_rate*100:.1f}%",
                'avg_response_time': f"{avg_response*1000:.2f}ms",
                'db_queries': self.metrics['db_queries'],
                'errors': self.metrics['errors']
            },
            'cache_info': {
                'static_cache_size': self.static_cache.get_statistics()['cache_size'],
                'smart_cache_size': self.smart_cache.get_statistics()['cache_size']
            }
        }
    def export_metrics(self) -> Dict:
        """导出指标(用于监控系统)"""
        static_stats = self.static_cache.get_statistics()
        smart_stats = self.smart_cache.get_statistics()
        return {
            'timestamp': datetime.now().isoformat(),
            'queries': {
                'total': self.metrics['total_queries'],
                'cache_hits': self.metrics['cache_hits'],
                'db_queries': self.metrics['db_queries'],
                'errors': self.metrics['errors']
            },
            'performance': {
                'cache_hit_rate': self.metrics['cache_hits'] / self.metrics['total_queries'] if self.metrics['total_queries'] > 0 else 0,
                'avg_response_time': np.mean(self.metrics['avg_response_time']) if self.metrics['avg_response_time'] else 0,
                'p95_response_time': np.percentile(self.metrics['avg_response_time'], 95) if len(self.metrics['avg_response_time']) > 0 else 0
            },
            'cache': {
                'static': static_stats,
                'smart': smart_stats
            }
        }
def demo_production_system():
    """演示生产级系统"""
    print("=" * 60)
    print("生产级CAG系统演示")
    print("=" * 60)
    # 创建系统(不同配置)
    config = CacheConfig(
        max_size=50,
        default_ttl=3600,
        min_access_count=2,
        strategy=CacheStrategy.SMART,
        enable_metrics=True
    )
    system = ProductionCAGSystem(config)
    print(f"\n系统配置:")
    print(f"  缓存大小: {config.max_size}")
    print(f"  默认TTL: {config.default_ttl}秒")
    print(f"  最小访问次数: {config.min_access_count}")
    print(f"  缓存策略: {config.strategy.value}")
    # 添加不同类型的知识
    print("\n" + "=" * 60)
    print("添加知识")
    print("=" * 60)
    # 1. 静态知识(总是缓存)
    system.add_knowledge(
        text="公司年假政策:入职满1年5天,满3年10天,满5年15天",
        doc_id="policy_001",
        metadata={'category': '制度', 'type': 'static'},
        cache_strategy=CacheStrategy.ALWAYS,
        ttl=86400  # 24小时
    )
    print("✅ 添加静态知识: 年假政策(ALWAYS缓存)")
    # 2. 半静态知识(智能缓存)
    system.add_knowledge(
        text="产品价格表:基础版99元/月,专业版199元/月,企业版499元/月",
        doc_id="product_002",
        metadata={'category': '产品', 'type': 'semi-static'},
        cache_strategy=CacheStrategy.SMART,
        ttl=3600  # 1小时
    )
    print("✅ 添加半静态知识: 产品价格(SMART缓存)")
    # 3. 动态知识(不缓存)
    system.add_knowledge(
        text="今日促销:所有产品8折优惠,仅限今天!",
        doc_id="promo_003",
        metadata={'category': '促销', 'type': 'dynamic'},
        cache_strategy=CacheStrategy.NEVER,
        ttl=300  # 5分钟
    )
    print("✅ 添加动态知识: 促销信息(NEVER缓存)")
    # 模拟真实查询场景
    print("\n" + "=" * 60)
    print("模拟真实查询")
    print("=" * 60)
    queries = [
        # 静态问题(高频)
        ("年假怎么算", 5),
        ("年假政策", 3),
        # 半静态问题(中频)
        ("产品价格", 3),
        ("多少钱", 2),
        # 动态问题(低频)
        ("今天有优惠吗", 2),
        ("促销活动", 1)
    ]
    print("\n执行查询...")
    for question, count in queries:
        for i in range(count):
            result = system.query(question)
            if i == 0:  # 只显示首次查询
                cache_tag = "⚡" if result['cached'] else "🔍"
                print(f"  {cache_tag} {question}: {result['source']} ({result['response_time']*1000:.2f}ms)")
    # 显示健康状态
    print("\n" + "=" * 60)
    print("系统健康状态")
    print("=" * 60)
    health = system.get_health_status()
    status_icon = "✅" if health['status'] == 'healthy' else "⚠️" if health['status'] == 'degraded' else "❌"
    print(f"\n状态: {status_icon} {health['status'].upper()}")
    print(f"健康评分: {health['health_score']}/100")
    print(f"\n指标:")
    for key, value in health['metrics'].items():
        print(f"  {key}: {value}")
    print(f"\n缓存信息:")
    for key, value in health['cache_info'].items():
        print(f"  {key}: {value}")
    # 导出指标
    print("\n" + "=" * 60)
    print("性能指标(可接入Prometheus/Grafana)")
    print("=" * 60)
    metrics = system.export_metrics()
    print(f"\n时间戳: {metrics['timestamp']}")
    print(f"\n查询统计:")
    print(f"  总查询: {metrics['queries']['total']}")
    print(f"  缓存命中: {metrics['queries']['cache_hits']}")
    print(f"  数据库查询: {metrics['queries']['db_queries']}")
    print(f"  错误数: {metrics['queries']['errors']}")
    print(f"\n性能指标:")
    print(f"  缓存命中率: {metrics['performance']['cache_hit_rate']*100:.1f}%")
    print(f"  平均响应时间: {metrics['performance']['avg_response_time']*1000:.2f}ms")
    print(f"  P95响应时间: {metrics['performance']['p95_response_time']*1000:.2f}ms")
    # 最佳实践总结
    print("\n" + "=" * 60)
    print("生产环境最佳实践")
    print("=" * 60)
    print("""
1. 【分层缓存策略】
   - 静态知识:ALWAYS + 长TTL(24小时)
   - 半静态知识:SMART + 中TTL(1小时)
   - 动态知识:NEVER 或 短TTL(5分钟)
2. 【监控指标】
   - 缓存命中率:目标 >50%
   - 平均响应时间:目标 <50ms
   - P95响应时间:目标 <100ms
   - 错误率:目标 <0.1%
3. 【容量规划】
   - 缓存大小 = 日查询量 × 0.2(二八定律)
   - 预留20%扩展空间
   - 设置告警阈值:命中率<30%、响应>100ms
4. 【失效策略】
   - 定时失效:使用TTL
   - 主动失效:文档更新时触发
   - 批量失效:支持按模式匹配
5. 【高可用保障】
   - 缓存失败降级到检索
   - 异常捕获和日志记录
   - 健康检查接口
   - 指标导出到监控系统
    """)
# 运行生产系统演示
demo_production_system()

6.2完整的代码示例

现在让我们把所有代码整合到一起,提供一个完整可运行的demo:

def run_complete_demo():
    """运行完整演示"""
    print("\n\n")
    print("="* 80)
    print(" " * 20 + "CAG完整演示:从RAG到生产级CAG")
    print("=" * 80)
    
    print("\n这个演示将展示:")
    print("  1. 传统RAG的性能问题")
    print("  2. CAG如何解决这些问题")
    print("  3. RAG+CAG混合架构")
    print("  4. 智能缓存策略")
    print("  5. 缓存更新机制")
    print("  6. 生产级系统实现")
    
    print("\n" + "=" * 80)
    input("按回车键开始演示...")
    
    # 依次运行各个演示
    demos = [
        ("传统RAG系统", demo_traditional_rag),
        ("CAG系统", demo_cag_system),
        ("RAG vs CAG性能对比", compare_rag_vs_cag),
        ("混合RAG+CAG系统", demo_hybrid_system),
        ("智能缓存", demo_smart_caching),
        ("缓存更新机制", demo_cache_update),
        ("缓存失效策略对比", compare_invalidation_strategies),
        ("生产级系统", demo_production_system),
    ]
    
    for i, (name, demo_func) in enumerate(demos, 1):
        print(f"\n\n{'='*80}")
        print(f"演示 {i}/{len(demos)}: {name}")
        print("="*80)
        input("按回车继续...")
        demo_func()
        print("\n演示完成!")
        if i < len(demos):
            input("按回车进入下一个演示...")
    
    print("\n\n" + "="*80)
    print(" " * 30 + "所有演示完成!")
    print("="*80)
    print("\n📚 你已经学会了:")
    print("  ✅ RAG的基本原理和问题")
    print("  ✅ CAG如何通过缓存提升性能")
    print("  ✅ 如何设计混合架构")
    print("  ✅ 智能缓存策略的实现")
    print("  ✅ 缓存更新和失效机制")
    print("  ✅ 生产级系统的完整实现")
    print("\n💡 下一步:")
    print("  - 在自己的项目中应用这些技术")
    print("  - 根据实际场景调整缓存策略")
    print("  - 接入监控系统持续优化")
    print("  - 考虑分布式缓存(Redis等)")
# 如果直接运行此文件,执行完整演示
if __name__ == "__main__":
    run_complete_demo()

七、总结

以前的AI是"现学现卖",每次都要临时抱佛脚。

而CAG让AI有了真正的"记忆力",能把核心知识牢牢记住,需要的时候随时调取。

这不是简单的技术升级,而是让AI从"查询工具"向"智能助手"的本质跃迁。

想象一下:未来你的AI助手不仅知道去哪儿查信息,更重要的是,它能记住你的习惯、你的偏好、你们之间的每一次对话……

那时候,它才真正成为了你的"数字分身"。

而这一切的起点,就是从让AI学会"记忆"开始

最后

选择AI大模型就是选择未来!最近两年,大家都可以看到AI的发展有多快,时代在瞬息万变,我们又为何不给自己多一个选择,多一个出路,多一个可能呢?

与其在传统行业里停滞不前,不如尝试一下新兴行业,而AI大模型恰恰是这两年的大风口,人才需求急为紧迫!

由于文章篇幅有限,在这里我就不一一向大家展示了,学习AI大模型是一项系统工程,需要时间和持续的努力。但随着技术的发展和在线资源的丰富,零基础的小白也有很好的机会逐步学习和掌握。

【2025最新】AI大模型全套学习籽料(可无偿送):LLM面试题+AI大模型学习路线+大模型PDF书籍+640套AI大模型报告等等,从入门到进阶再到精通,超全面存下吧!

获取方式:有需要的小伙伴,可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费】
包括:AI大模型学习路线、LLM面试宝典、0基础教学视频、大模型PDF书籍/笔记、大模型实战案例合集、AI产品经理合集等等

在这里插入图片描述

AI大模型学习之路,道阻且长,但只要你坚持下去,就一定会有收获。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐