从零到一搭建 AI Agent 记忆系统:九种策略全景实战(含注释代码)

这篇文章展示的是记忆系统,聚焦“可复用、可落地”的记忆架构。你将看到九种常见记忆策略的原理与权衡,并附带带注释的关键实现代码,便于直接复刻。


0. 统一接口与 Agent 框架(所有策略的地基)

在构建任何记忆策略之前,先定义统一接口与 Agent 调度流程,这样策略可以“即插即用”。

import abc

# 记忆策略的统一接口(抽象基类)
class BaseMemoryStrategy(abc.ABC):
    @abc.abstractmethod
    def add_message(self, user_input: str, ai_response: str):
        """将一轮对话写入记忆"""
        pass

    @abc.abstractmethod
    def get_context(self, query: str) -> str:
        """根据当前问题提取上下文"""
        pass

    @abc.abstractmethod
    def clear(self):
        """清空记忆"""
        pass


class AIAgent:
    """统一的 Agent 逻辑:取记忆 -> 构造提示词 -> 调用 LLM -> 更新记忆"""
    def __init__(self, memory_strategy: BaseMemoryStrategy, system_prompt: str = "You are a helpful AI assistant."):
        self.memory = memory_strategy
        self.system_prompt = system_prompt

    def chat(self, user_input: str) -> str:
        # 1) 获取记忆上下文
        context = self.memory.get_context(query=user_input)
        # 2) 拼接提示词
        full_user_prompt = f"### MEMORY CONTEXT\n{context}\n\n### CURRENT REQUEST\n{user_input}"
        # 3) 调用 LLM(generate_text 来自你的工具函数)
        ai_response = generate_text(self.system_prompt, full_user_prompt)
        # 4) 写回记忆
        self.memory.add_message(user_input, ai_response)
        return ai_response

1) 顺序记忆(Sequential / Keep-It-All)

特点:把所有对话完整保存并拼接。优点是“记得全”,缺点是上下文无限增长。

class SequentialMemory(BaseMemoryStrategy):
    def __init__(self):
        # 使用列表存所有对话
        self.history = []

    def add_message(self, user_input: str, ai_response: str):
        # 依次写入用户与助手消息
        self.history.append({"role": "user", "content": user_input})
        self.history.append({"role": "assistant", "content": ai_response})

    def get_context(self, query: str) -> str:
        # 将历史对话拼接成一段文本
        return "\n".join([f"{t['role'].capitalize()}: {t['content']}" for t in self.history])

    def clear(self):
        # 清空历史
        self.history = []

2) 滑动窗口记忆(Sliding Window)

特点:只保留最近 N 轮对话,成本稳定但会遗忘。

from collections import deque

class SlidingWindowMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 4):
        # deque 自动维护长度上限
        self.history = deque(maxlen=window_size)

    def add_message(self, user_input: str, ai_response: str):
        # 一轮对话作为一个“turn”写入
        self.history.append([
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": ai_response}
        ])

    def get_context(self, query: str) -> str:
        # 展开 deque 生成上下文
        ctx = []
        for turn in self.history:
            for msg in turn:
                ctx.append(f"{msg['role'].capitalize()}: {msg['content']}")
        return "\n".join(ctx)

    def clear(self):
        self.history.clear()

3) 总结记忆(Summarization)

特点:对话到阈值后,让 LLM 生成摘要并合并。适合长对话。

class SummarizationMemory(BaseMemoryStrategy):
    def __init__(self, summary_threshold: int = 4):
        self.running_summary = ""
        self.buffer = []
        self.summary_threshold = summary_threshold

    def add_message(self, user_input: str, ai_response: str):
        # 先把消息放进缓冲区
        self.buffer.append({"role": "user", "content": user_input})
        self.buffer.append({"role": "assistant", "content": ai_response})
        # 达到阈值则触发总结
        if len(self.buffer) >= self.summary_threshold:
            self._consolidate_memory()

    def _consolidate_memory(self):
        # 将缓冲区文本拼接
        buffer_text = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer])
        # 构造总结提示词
        prompt = (
            "You are a summarization expert.\n"
            f"### Previous Summary:\n{self.running_summary}\n\n"
            f"### New Conversation:\n{buffer_text}\n\n"
            "### Updated Summary:"
        )
        # 调用 LLM 生成摘要
        self.running_summary = generate_text("You are a summarization engine.", prompt)
        # 清空缓冲区
        self.buffer = []

    def get_context(self, query: str) -> str:
        buffer_text = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer])
        return f"### Summary:\n{self.running_summary}\n\n### Recent:\n{buffer_text}"

    def clear(self):
        self.running_summary = ""
        self.buffer = []

4) 检索记忆(Retrieval / RAG)

特点:用向量检索长程相关信息,最常用的“长期记忆”方案。

import numpy as np
import faiss

class RetrievalMemory(BaseMemoryStrategy):
    def __init__(self, k: int = 2, embedding_dim: int | None = None):
        self.k = k
        self.embedding_dim = embedding_dim
        self.documents = []
        # embedding_dim 未知时,先不初始化 index
        self.index = faiss.IndexFlatL2(embedding_dim) if embedding_dim else None

    def _ensure_index(self, embedding: list):
        # 首次写入时用向量长度确定维度
        if self.embedding_dim is None:
            self.embedding_dim = len(embedding)
            self.index = faiss.IndexFlatL2(self.embedding_dim)
        # 若维度不匹配,直接抛错
        elif len(embedding) != self.embedding_dim:
            raise ValueError(f"Embedding dim {len(embedding)} != index dim {self.embedding_dim}")

    def add_message(self, user_input: str, ai_response: str):
        docs = [f"User said: {user_input}", f"AI responded: {ai_response}"]
        for doc in docs:
            emb = get_embedding(doc)
            if emb:
                self._ensure_index(emb)
                self.documents.append(doc)
                self.index.add(np.array([emb], dtype="float32"))

    def get_context(self, query: str) -> str:
        if self.index is None or self.index.ntotal == 0:
            return "No information in memory yet."
        q = get_embedding(query)
        if not q:
            return "Could not process query for retrieval."
        if len(q) != self.embedding_dim:
            return "Query embedding dimension mismatch with index."
        D, I = self.index.search(np.array([q], dtype="float32"), self.k)
        retrieved = [self.documents[i] for i in I[0] if i != -1]
        return "### Retrieved:\n" + "\n---\n".join(retrieved)

    def clear(self):
        self.documents = []
        if self.index is not None:
            self.index.reset()

5) 记忆增强(Memory-Augmented Simulation)

特点:让 LLM 识别“关键事实”,生成长期“记忆 token”。

class MemoryAugmentedMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 2):
        self.recent_memory = SlidingWindowMemory(window_size=window_size)
        self.memory_tokens = []

    def add_message(self, user_input: str, ai_response: str):
        # 先写入短期记忆
        self.recent_memory.add_message(user_input, ai_response)
        # 让 LLM 抽取关键事实
        prompt = (
            "Analyze the following turn and extract any long-term fact.\n"
            f"User: {user_input}\nAI: {ai_response}\n"
            "If none, reply 'No important fact.'"
        )
        fact = generate_text("You are a fact-extraction expert.", prompt)
        if "no important fact" not in fact.lower():
            self.memory_tokens.append(fact)

    def get_context(self, query: str) -> str:
        recent = self.recent_memory.get_context(query)
        tokens = "\n".join([f"- {t}" for t in self.memory_tokens])
        return f"### Memory Tokens:\n{tokens}\n\n### Recent:\n{recent}"

    def clear(self):
        self.recent_memory.clear()
        self.memory_tokens = []

6) 分层记忆(Hierarchical)

特点:短期用滑窗,长期用检索,触发关键词时晋升。

class HierarchicalMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 2, k: int = 2, embedding_dim: int = 4096):
        self.working_memory = SlidingWindowMemory(window_size=window_size)
        self.long_term_memory = RetrievalMemory(k=k, embedding_dim=embedding_dim)
        self.promotion_keywords = ["remember", "rule", "preference", "always", "never", "allergic"]

    def add_message(self, user_input: str, ai_response: str):
        self.working_memory.add_message(user_input, ai_response)
        # 触发关键词则进入长期记忆
        if any(k in user_input.lower() for k in self.promotion_keywords):
            self.long_term_memory.add_message(user_input, ai_response)

    def get_context(self, query: str) -> str:
        working = self.working_memory.get_context(query)
        long_term = self.long_term_memory.get_context(query)
        return f"### Long-Term:\n{long_term}\n\n### Working:\n{working}"

    def clear(self):
        self.working_memory.clear()
        self.long_term_memory.clear()

7) 图谱记忆(Graph Memory)

特点:抽取三元组构建知识图谱,适合关系推理。

import networkx as nx
import re

class GraphMemory(BaseMemoryStrategy):
    def __init__(self):
        self.graph = nx.DiGraph()

    def _extract_triples(self, text: str):
        prompt = (
            "Extract Subject-Relation-Object triples as Python tuples.\n"
            f"Text:\n{text}"
        )
        response = generate_text("You are a KG extractor.", prompt)
        return re.findall(r"\(['\"](.*?)['\"],\s*['\"](.*?)['\"],\s*['\"](.*?)['\"]\)", response)

    def add_message(self, user_input: str, ai_response: str):
        triples = self._extract_triples(f"User: {user_input}\nAI: {ai_response}")
        for s, r, o in triples:
            self.graph.add_edge(s.strip(), o.strip(), relation=r.strip())

    def get_context(self, query: str) -> str:
        if not self.graph.nodes:
            return "The knowledge graph is empty."
        entities = [w.capitalize() for w in query.replace("?", "").split() if w.capitalize() in self.graph.nodes]
        if not entities:
            return "No relevant entities from your query were found in the knowledge graph."
        facts = []
        for e in set(entities):
            for u, v, d in self.graph.out_edges(e, data=True):
                facts.append(f"{u} --[{d['relation']}]--> {v}")
        return "### Facts Retrieved from Knowledge Graph:\n" + "\n".join(sorted(set(facts)))

    def clear(self):
        self.graph.clear()

8) 压缩记忆(Compression)

特点:把每轮对话压缩为“极简事实”,超省 token。

class CompressionMemory(BaseMemoryStrategy):
    def __init__(self):
        self.compressed_facts = []

    def add_message(self, user_input: str, ai_response: str):
        prompt = (
            "Compress the following into its most essential factual statement.\n"
            f"User: {user_input}\nAI: {ai_response}"
        )
        fact = generate_text("You are a data compressor.", prompt)
        self.compressed_facts.append(fact)

    def get_context(self, query: str) -> str:
        if not self.compressed_facts:
            return "No compressed facts in memory."
        return "### Compressed Facts:\n- " + "\n- ".join(self.compressed_facts)

    def clear(self):
        self.compressed_facts = []

9) OS 类记忆(OS-Like)

特点:模拟“内存/硬盘”分页,按需调入旧信息。

class OSMemory(BaseMemoryStrategy):
    def __init__(self, ram_size: int = 2):
        self.ram_size = ram_size
        self.active_memory = deque()
        self.passive_memory = {}
        self.turn_count = 0

    def add_message(self, user_input: str, ai_response: str):
        turn_id = self.turn_count
        turn_data = f"User: {user_input}\nAI: {ai_response}"
        # RAM 满则页面换出
        if len(self.active_memory) >= self.ram_size:
            lru_id, lru_data = self.active_memory.popleft()
            self.passive_memory[lru_id] = lru_data
        # 新页面写入 RAM
        self.active_memory.append((turn_id, turn_data))
        self.turn_count += 1

    def get_context(self, query: str) -> str:
        active = "\n".join([d for _, d in self.active_memory])
        # 简化版“缺页”逻辑:关键词命中则调入
        paged_in = ""
        for tid, data in self.passive_memory.items():
            if any(w in data.lower() for w in query.lower().split() if len(w) > 3):
                paged_in += f"\n(Paged in Turn {tid}): {data}"
        return f"### RAM:\n{active}\n\n### Disk:\n{paged_in}"

    def clear(self):
        self.active_memory.clear()
        self.passive_memory = {}
        self.turn_count = 0

结语:记忆策略不是“选一个”,而是“组合搭配”

  • 短对话:顺序或滑窗即可。
  • 长对话:总结/压缩能有效控制成本。
  • 长期记忆:检索是主流做法。
  • 复杂关系:图谱能做结构化推理。
  • 大规模系统:分层或 OS 化管理更稳。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐