【Agent开发】第七阶段:RAG 自动化评估体系构建 (RAG Evaluation Framework) —— 从“凭感觉调优”到“数据驱动决策” – pd的AI Agent开发笔记

文章目录


前置环境:当前环境是基于WSL2 + Ubuntu 24.04 + Docker Desktop构建的云原生开发平台,所有服务(MySQL、Redis、Qwen)均以独立容器形式运行并通过Docker Compose统一编排。如何配置请参考我的博客 WSL2 + Ubuntu 24.04 + Docker Desktop 配置双内核环境 并且补充了milvus相关的配置,如何配置请参考我的博客 【Agent开发】第三阶段:RAG 实战 —— 赋予 Agent “外脑”。 并且引入了ES检索,并且配置了ES服务,ES部分的配置请查看我的博客 【Agent开发】第五阶段:RAG 深度优化实战 —— 从“可用”到“卓越”

0. 🐘 使用Postgresql存储测试集

为什么选它?(可能最后项目重构,我也会弃用milvus等只保留pgSQL)

  • JSONB 类型:PostgreSQL 提供 JSONB(二进制 JSON)类型,不仅存储 JSON,还进行了预解析和索引优化。查询速度极快,支持 GIN 索引,可以高效查询 JSON 内部的字段。
  • 混合模型:你可以在同一张表中既使用严格的关系型列(如 user_id, created_at),又使用灵活的 JSONB 列存储动态属性。这避免了为了几个动态字段就引入一个新的 NoSQL 数据库。
  • 生态统一:只需维护一个数据库,降低了运维复杂度(不用同时管 MySQL + MongoDB)。
  • 2026 趋势:随着 AI 应用的普及,PostgreSQL 凭借 pgvector(向量搜索)和强大的 JSON 处理能力,正成为“单一事实来源”的首选。

0.1. 配置前环境说明

0.1.1 基础环境
  • 操作系统/终端环境:Ubuntu(当前目录 /home/pdnbplus/ai-stack
  • 编排方式:Docker Compose(单机开发编排)
  • 统一网络:ai-net

0.1.2 现有服务基线

  • 缓存:Redis(redis-local
  • 关系型数据库:MySQL 8.0(mysql-local,端口 3307
  • 大模型服务:vLLM Qwen(qwen-local,端口 7575
  • 向量检索:Milvus(按需启动)
  • 搜索:Elasticsearch(按需启动)

0.1.3 本次 PostgreSQL 配置目标

  • 新增 PostgreSQL 16 + pgvector(容器:postgres-local
  • 主机访问端口:5433(容器内仍为 5432
  • 默认数据库:agent_dev
  • 默认用户:agent_user
  • 初始化扩展:vectoruuid-ossppgcrypto
  • 初始化示例表:agent_memory(用于智能体记忆/向量字段示例)

0.2. 配置步骤

步骤 1:编排中加入 PostgreSQL 服务

docker-compose.yml 中确认存在 postgres 服务,关键项如下:

  postgres:
    image: pgvector/pgvector:pg16
    container_name: postgres-local
    environment:
      POSTGRES_USER: agent_user
      POSTGRES_PASSWORD: xxxxx
      POSTGRES_DB: agent_dev
      TZ: Asia/Shanghai
      PGTZ: Asia/Shanghai
    ports:
      - "5433:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./postgres/init:/docker-entrypoint-initdb.d
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U agent_user -d agent_dev"]
      interval: 15s
      timeout: 5s
      retries: 10
    restart: unless-stopped
    networks:
      - ai-net

同时在 volumes 下声明:

volumes:
  ...[之前的mysql]

  postgres_data:
    driver: local

步骤 2:准备初始化 SQL

目录:postgres/init

  • 001_extensions.sql:创建 vector/uuid-ossp/pgcrypto 扩展
CREATE EXTENSION IF NOT EXISTS vector;
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE EXTENSION IF NOT EXISTS pgcrypto;
  • 002_agent_schema.sql:创建 agent_memory 示例表与索引
CREATE TABLE IF NOT EXISTS agent_memory (
  id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  user_id TEXT NOT NULL,
  content TEXT NOT NULL,
  embedding VECTOR(512),
  metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
  created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

CREATE INDEX IF NOT EXISTS idx_agent_memory_user_id ON agent_memory(user_id);

说明:首次初始化(空数据卷)时会自动执行;若数据卷已存在,不会再次自动执行。

步骤 3:启动并验证 PostgreSQL

docker compose up -d postgres
docker compose ps postgres
docker exec -it postgres-local psql -U agent_user -d agent_dev -c "SELECT extname FROM pg_extension;"

预期:

  • postgres 状态为 healthy(或短暂 starting 后转 healthy
  • 扩展列表中可看到 vector

步骤 4:应用连接配置

  • 容器内连接:
    • postgresql://agent_user:<password>@postgres-local:5432/agent_dev
  • 宿主机连接:
    • postgresql://agent_user:<password>@127.0.0.1:5433/agent_dev

建议应用统一使用 DATABASE_URL 环境变量。

步骤 5:备份配置

已提供脚本:scripts/postgres_backup.sh

#!/usr/bin/env bash
set -Eeuo pipefail

CONTAINER_NAME="${CONTAINER_NAME:-postgres-local}"
BACKUP_DIR="${BACKUP_DIR:-/home/pdnbplus/ai-stack/backups/postgres}"
RETENTION_DAYS="${RETENTION_DAYS:-14}"
PG_USER="${PG_USER:-agent_user}"
PG_PASSWORD="${PG_PASSWORD:-postgresql_2002}"
PG_DATABASE="${PG_DATABASE:-agent_dev}"

timestamp="$(date '+%Y%m%d_%H%M%S')"
mkdir -p "$BACKUP_DIR"

lock_file="$BACKUP_DIR/.backup.lock"
exec 9>"$lock_file"
if ! flock -n 9; then
  echo "[backup] another backup is running, aborting"
  exit 1
fi

if ! docker inspect "$CONTAINER_NAME" >/dev/null 2>&1; then
  echo "[backup] container not found: $CONTAINER_NAME"
  exit 1
fi

if [[ "$(docker inspect -f '{{.State.Running}}' "$CONTAINER_NAME" 2>/dev/null)" != "true" ]]; then
  echo "[backup] container is not running: $CONTAINER_NAME"
  exit 1
fi

tmp_file="$BACKUP_DIR/postgres_${timestamp}.sql.gz.tmp"
final_file="$BACKUP_DIR/postgres_${timestamp}.sql.gz"
cleanup_tmp() {
  rm -f "$tmp_file"
}
trap cleanup_tmp ERR INT TERM

echo "[backup] started at $(date '+%F %T %Z')"
docker exec -e PGPASSWORD="$PG_PASSWORD" "$CONTAINER_NAME" \
  pg_dump -U "$PG_USER" -d "$PG_DATABASE" \
  | gzip -9 >"$tmp_file"

gzip -t "$tmp_file"
mv "$tmp_file" "$final_file"

deleted_count="$(
  find "$BACKUP_DIR" -maxdepth 1 -type f -name 'postgres_*.sql.gz' -mtime +"$RETENTION_DAYS" -print -delete | wc -l
)"

echo "[backup] finished at $(date '+%F %T %Z')"
echo "[backup] file: $final_file"
echo "[backup] size: $(du -h "$final_file" | awk '{print $1}')"
echo "[backup] retention deleted: $deleted_count file(s)"

手动执行:

/home/pdnbplus/ai-stack/scripts/postgres_backup.sh

定时任务示例(每天 22:10):

10 22 * * * /home/pdnbplus/ai-stack/scripts/postgres_backup.sh >> /home/pdnbplus/ai-stack/backups/postgres/backup.log 2>&1

0.3. 注意事项

0.3.1 初始化脚本执行时机

  • docker-entrypoint-initdb.d 只在“数据库目录为空”时执行。
  • 如果修改了初始化 SQL,需要清空并重建 postgres_data 后才会重新执行(会丢数据,谨慎)。

0.3.2 密码与权限

  • 当前编排中 POSTGRES_PASSWORD 已配置,建议后续改为更强密码。
  • 开发环境可用单用户;进入联调/共享环境建议拆分只读账号与写账号。

0.3.3 端口与网络

  • 宿主机访问用 5433;容器互访用服务名 postgres-local:5432
  • 避免应用里混用 localhost 与容器服务名导致连接错误。

0.3.4 向量维度一致性

  • 示例表使用 VECTOR(512),需与实际 embedding 模型维度一致。
  • 若模型维度不是 512,需调整建表语句后再初始化。

0.3.5 备份恢复演练

  • 建议每周至少做一次恢复演练,避免只备份不验证。
  • 恢复命令示例:
gunzip -c /path/to/postgres_xxx.sql.gz | docker exec -i -e PGPASSWORD=<password> postgres-local psql -U agent_user -d agent_dev

核心目标:建立一套可量化、可自动化、可视化的评估体系,精准定位 RAG 系统的瓶颈(是检索不准?还是生成幻觉?),并量化每一次代码变更(如引入 HyDE、Sentence-Child)的真实收益。
核心理念:No Evaluation, No Optimization.

第 1 讲:评估维度与指标体系设计

核心格言:“如果你不能衡量它,你就不能改进它。” (Lord Kelvin)
本讲目标:建立一套分层评估模型,能够精准定位问题是出在检索层(没找到)还是生成层(乱说话)。

1. 为什么传统 NLP 指标失效了?

在传统的机器翻译或摘要任务中,我们常用 BLEU (Bilingual Evaluation Understudy) 或 ROUGE(Recall-Oriented Understudy for Gisting Evaluation) 分数。它们通过计算候选答案标准答案之间的字面重叠率来打分。

简单来说:做翻译看 BLEU,做摘要看 ROUGE。
❌ 在 RAG 中,这完全行不通:

  • 用户问:“奇葩星球公司禁止报销哪种咖啡?”
  • 标准答案:“星巴克。”
  • 模型回答 A:“公司明确规定禁止报销星巴克咖啡。” (完美,但 BLEU 分数低,因为重叠词不多)
  • 模型回答 B:“星巴克咖啡是被禁止报销的。” (完美,BLEU 分数中等)
  • 模型回答 C:“星巴克。” (完美,BLEU 分数高)
  • 模型回答 D:“星巴克星巴克星巴克。” (废话,但 BLEU 分数极高!)
  • 结论:RAG 需要的是语义理解和逻辑推理的评估,而不是简单的字符串匹配。我们需要引入 “LLM-as-a-Judge” (让更强的 LLM 当裁判) 的理念。

补充知识

BLEU (Bilingual Evaluation Understudy) 和 ROUGE (Recall-Oriented Understudy for Gisting Evaluation) 是自然语言处理(NLP)领域中用于自动评估机器生成文本质量的两个最经典的指标。

它们的核心思想都是通过计算机算法将机器生成的文本(候选文本)与人类编写的参考文本(标准答案)进行对比,计算重合度来打分,从而避免昂贵且耗时的人工评估。

以下是两者的详细区别和应用场景:

1. BLEU (Bilingual Evaluation Understudy)

  • 主要用途:主要用于机器翻译(Machine Translation)任务的评估。
  • 核心逻辑:基于精确率 (Precision)。它计算候选文本中有多少 n-gram(连续的 n 个词)出现在了参考文本中。
    • 如果机器翻译出的词语在参考译文中出现过,就得高分。
    • 为了防止机器通过重复单词来刷分(例如参考是“猫在垫子上”,机器输出“猫 猫 猫…”),BLEU 引入了** brevity penalty (简短惩罚)**:如果生成的句子比参考句子短太多,分数会被大幅降低。
  • 计算公式特点
    • 通常计算 1-gram 到 4-gram 的精确率,然后取几何平均。
    • 分数范围是 0 到 1(有时表示为 0 到 100),越接近 1 表示质量越高。
  • 缺点
    • 对同义词不敏感(如果意思对但用词不同,分数会低)。
    • 只关注“机器输出的词是否在参考里”,不关心“参考里的词是否都被机器覆盖了”(即召回率低也没关系)。

2. ROUGE (Recall-Oriented Understudy for Gisting Evaluation)

  • 主要用途:主要用于自动摘要(Text Summarization)任务的评估,也可用于对话生成等。
  • 核心逻辑:基于召回率 (Recall)。它计算参考文本中有多少 n-gram 出现在了候选文本中。
    • 它的视角是:“人类写的摘要里的关键点,机器生成的摘要覆盖了多少?”
    • 这对于摘要任务很重要,因为摘要要求涵盖原文的核心信息,漏掉关键信息比多说几句废话更严重。
  • 常见变体
    • ROUGE-N:基于 n-gram 的重合度(如 ROUGE-1 看单词,ROUGE-2 看双词组合)。
    • ROUGE-L:基于最长公共子序列 (LCS)。它不要求词是连续的,只要顺序一致即可。这能更好地捕捉句子的结构相似性。
    • ROUGE-W:加权的最长公共子序列,给连续的匹配更高的权重。
  • 缺点
    • 同样无法理解语义,只匹配字面。
    • 如果机器生成了很长的文本覆盖了所有参考词,但包含大量无关废话,ROUGE 的召回率可能很高,但精确率很低(虽然 ROUGE 主要看召回,但通常也会结合 F1 值来看)。

总结对比表

特性 BLEU ROUGE
全称 Bilingual Evaluation Understudy Recall-Oriented Understudy for Gisting Evaluation
核心指标 精确率 (Precision) 为主 召回率 (Recall) 为主
典型应用 机器翻译 文本摘要、对话生成
关注点 机器生成的内容有多少是“对”的(在参考里) 参考内容有多少被机器“覆盖”了
惩罚机制 有简短惩罚 (Brevity Penalty),防止生成过短 通常无特定的长度惩罚,主要看覆盖度
常用变体 BLEU-4 (最常用) ROUGE-1, ROUGE-2, ROUGE-L

重要提示

虽然这两个指标被广泛使用,但它们并不完美

  1. 缺乏语义理解:它们只是机械地匹配字词。如果机器说“高兴”,参考说“开心”,BLEU/ROUGE 可能认为完全不匹配,尽管意思一样。
  2. 多参考译文问题:对于同一个输入,人类可能有多种正确的表达方式。如果只提供一个参考文本,分数可能会偏低。
  3. 新趋势:近年来,基于大模型(LLM)的评估方法(如使用 BERTScore 或让 LLM 直接打分)正在逐渐补充甚至替代传统的 BLEU/ROUGE,因为它们能更好地理解语义。

结论:RAG 需要的是语义理解逻辑推理的评估,而不是简单的字符串匹配。我们需要引入 “LLM-as-a-Judge” (让更强的 LLM 当裁判) 的理念。

2. RAG 评估的“黄金三角”模型

我们将 RAG 系统拆解为两个核心阶段,分别评估,最后再看整体。这就是业界标准的 RAG Triad

📐 维度一:检索层评估 (Retrieval Evaluation)

核心问题“我们找到的资料是否相关且完整?”
注意:这一层只评估 Context (检索到的片段),不关心 LLM 生成的答案。

指标名称 英文 含义 计算公式逻辑 优化方向
上下文召回率 Context Recall 查全率。标准答案里的信息,有多少包含在检索到的上下文中? 检索到的相关语句数 标准答案中的总语句数 \frac{\text{检索到的相关语句数}}{\text{标准答案中的总语句数}} 标准答案中的总语句数检索到的相关语句数 分块策略 (Sentence-Child)、混合检索、HyDE
上下文精度 Context Precision 查准率 + 排名。相关信息是否排在前面?(越靠前分数越高) 加权平均排名倒数 重排序 (Rerank)、RRF 融合策略
命中率 Hit Rate @K 前 K 个结果里有没有至少一个相关的? 命中次数 总问题数 \frac{\text{命中次数}}{\text{总问题数}} 总问题数命中次数 基础检索能力

💡 场景解读

  • 如果 Recall 低:说明文档库里根本没有答案,或者分块切碎了导致信息丢失。对策:检查分块大小,尝试 HyDE。
  • 如果 Precision 低:说明答案找到了,但被淹没在一堆垃圾信息里,排得很靠后。对策:加强重排序 (Rerank)。

📐 维度二:生成层评估 (Generation Evaluation)

核心问题“LLM 是否基于资料说了人话?”
注意:这一层评估 Answer,依赖检索到的 Context

指标名称 英文 含义 评判标准 (LLM Judge) 优化方向
忠实度 Faithfulness 防幻觉。答案里的每一句话都能从上下文中找到依据吗? “答案中的所有事实陈述是否都源自提供的上下文?” 提示词工程 (Prompt Engineering)、限制生成长度
答案相关性 Answer Relevance 不啰嗦。答案是否直接回答了用户的问题,没有无关废话? “答案是否解决了用户的疑问,且没有冗余信息?” 生成 Prompt 优化、精简 Context

💡 场景解读

  • 如果 Faithfulness 低:LLM 在“一本正经地胡说八道”。对策:在 Prompt 中强调“仅依据上下文回答,不知道就说不知道”。
  • 如果 Relevance 低:LLM 在“答非所问”或“长篇大论”。对策:优化 System Prompt,要求“简明扼要”。

📐 维度三:端到端评估 (End-to-End Evaluation)

核心问题“最终答案对不对?”
前提:你需要有 Ground Truth (标准答案)

指标名称 英文 含义 适用场景
答案正确性 Answer Correctness 模型答案与标准答案的语义相似度。 有明确标准答案的场景 (如考试、事实查询)
人工满意度 Human Feedback 用户点赞/点踩。 开放域对话,无标准答案

3. 核心诊断矩阵:如何定位瓶颈?

这是本讲最重要的实战工具。通过组合 Context Recall (检索质量) 和 Faithfulness (生成质量),我们可以将问题归类到四个象限:

高 Faithfulness (不幻觉) 低 Faithfulness (爱幻觉)
高 Context Recall (检索好) **✅ 理想区 (Perfect)**资料全且准,回答也靠谱。👉 保持现状 **⚠️ 生成瓶颈 (Generator Issue)**资料都给了,LLM 还是乱编。👉 优化生成 Prompt,换更强模型,或限制温度。
低 Context Recall (检索差) **🤐 保守区 (Conservative)**没找到资料,LLM 很诚实说“不知道”。👉 优化检索策略 (HyDE, 分块),扩大 Top-K。 **🔥 危险区 (Critical Failure)**没找到资料,LLM 还在瞎编。👉 最危险! 需同时优化检索 + 强制 Prompt 约束 (“不知为不知”)。

🧪 案例分析:你的“奇葩星球”系统

假设我们运行了一次评估,得到以下典型 Bad Case:

案例 A

  • 问题:“摸鱼 Transformer 有多少参数?”
  • 检索结果:[Chunk 1: 介绍公司文化], [Chunk 2: 报销流程]… (完全没有提到参数)
  • 模型回答:“摸鱼 Transformer 拥有 42000 个参数,非常高效。” (其实文档里写的是 42M,或者根本没写)
  • 指标:Context Recall = 0.0, Faithfulness = 0.2 (因为它瞎编了)
  • 诊断:🔥 危险区
  • 对策:检索完全失败。检查是否开启了 HyDE?检查分块是否把“42000”这个数字切碎了?

案例 B

  • 问题:“步行出差补贴多少?”
  • 检索结果:[Chunk 1: …规定每公里补贴 50 元…], [Chunk 2: …其他交通方式…] (答案就在 Chunk 1)
  • 模型回答:“根据规定,步行出差是非常光荣的行为。关于补贴,公司鼓励员工多走路,具体金额请参考财务手册。” (没说具体数字)
  • 指标:Context Recall = 1.0 (资料里有), Faithfulness = 1.0 (没瞎编), Answer Relevance = 0.4 (太低了,没回答问题)
  • 诊断:⚠️ 生成瓶颈 (相关性低)
  • 对策:修改生成 Prompt:“请直接提取具体的金额数字,不要说废话”。

4. 无参考评估 (Reference-Free) vs 有参考评估

在实际企业中,90% 的问题是没有标准答案 (Ground Truth) 的

  • 有参考 (With Ground Truth)

    • 需要人工编写 (Question, Answer) 对。
    • 指标:Answer Correctness, Context Recall (基于标准答案计算)。
    • 优点:准确。缺点:成本极高,难以覆盖长尾问题。
  • 无参考 (Reference-Free) 👈 我们将重点采用此模式

    • 不需要标准答案。
    • 指标:Faithfulness, Answer Relevance, Context Precision
    • 原理:利用 LLM 自我反思。
      • Faithfulness 判断逻辑:LLM 读取 ContextAnswer,问自己:“Answer 里的这句话能从 Context 推导出来吗?”
      • Relevance 判断逻辑:LLM 读取 QuestionAnswer,问自己:“Answer 直接回答了 Question 吗?”
    • 优点:成本低,可大规模自动化运行。

5. 📝 本讲总结与行动清单

核心知识点

  1. BLEU/ROUGE 已死,RAG 需要语义级评估。
  2. 分层评估:检索层 (Recall/Precision) vs 生成层 (Faithfulness/Relevance)。
  3. 诊断矩阵:利用 Recall 和 Faithfulness 的四象限定位瓶颈。
  4. 无参考评估:利用 LLM-as-a-Judge 在没有标准答案的情况下也能评估。

第 2 讲:构建黄金测试集 (Golden Dataset)

本讲目标:掌握利用 LLM 从现有文档中自动提取 (Question, Answer, Context) 三元组的技术,快速构建覆盖不同难度、不同场景的黄金测试集。

1. 为什么我们需要“黄金测试集”?

在 RAG 评估中,我们需要三种核心数据:

  1. Query (问题):用户可能会问什么?
  2. Ground Truth Context (标准上下文):回答这个问题必须参考的文档片段(用于计算 Recall)。
  3. Ground Truth Answer (标准答案):一个完美、准确的答案(用于计算 Correctness,可选)。

❌ 传统做法的痛点

  • 人工手写:太慢!写 50 道题要半天,写 500 道题要一周。
  • 覆盖不全:人工容易只想到简单问题,忽略“多跳推理”、“否定查询”等边缘情况。
  • 主观偏差:不同的人写的“标准答案”风格不一,影响评估一致性。

✅ 我们的做法:LLM 自举 (Self-Bootstrapping)

利用 LLM 强大的阅读理解能力,让它阅读你的文档切片,然后自己出题、自己答题、自己标注依据

  • 效率:10 分钟生成 100+ 高质量题目。
  • 多样性:通过 Prompt 控制,强制生成“简单”、“复杂”、“多跳”等不同类型题目。
  • 对齐:生成的答案天然基于文档,保证了 ContextAnswer 的逻辑一致性。

2. 黄金测试集的数据结构

我们将测试集保存为 JSONL 格式(每行一个 JSON 对象),方便流式处理和增量更新。

标准 Schema (golden_dataset.jsonl):

{
  "id": "q_001",
  "category": "policy", 
  "difficulty": "easy",
  "query": "奇葩星球公司对于迟到超过 30 分钟的处罚是什么?",
  "ground_truth_context": [
    "员工手册第三章:考勤管理。迟到 30 分钟以内扣除当日餐补;迟到 30 分钟至 2 小时,扣除半日工资并通报批评;迟到 2 小时以上视为旷工。"
  ],
  "ground_truth_answer": "迟到 30 分钟至 2 小时,扣除半日工资并通报批评;迟到 2 小时以上视为旷工。",
  "source_document": "employee_handbook.pdf",
  "metadata": {
    "generated_by": "gpt-4o-mini",
    "generation_date": "2026-03-06"
  }
}
  • query: 模拟用户提问。
  • ground_truth_context: 最关键字段。这是计算 Context Recall 的基准。必须是文档中的原话或高度浓缩的摘要。
  • ground_truth_answer: 用于计算 Answer Correctness(如果有)。如果是纯无参考评估,此字段可省略,但建议保留用于人工抽检。
  • difficulty: 标记难度 (easy, medium, hard),便于后续分层分析。

src\core\models.py 中新增一个模型RagEvalSample

class RagEvalSample(Base):
    __tablename__ = "rag_eval_samples"

    id = Column(String(64), primary_key=True)
    category = Column(String(64), nullable=False, default="general")
    difficulty = Column(String(16), nullable=False)
    query = Column(Text, nullable=False)
    ground_truth_context = Column(JSON, nullable=False)  # 标准 Schema: string[]
    ground_truth_answer = Column(Text, nullable=False)
    source_document = Column(String(255), nullable=True)
    meta = Column("metadata", JSON, nullable=False, default=dict)  # 标准 Schema: object
    # 示例数据
    # {
    #     "id": "q_001",
    #     "category": "policy", 
    #     "difficulty": "easy",
    #     "query": "奇葩星球公司对于迟到超过 30 分钟的处罚是什么?",
    #     "ground_truth_context": [
    #         "员工手册第三章:考勤管理。迟到 30 分钟以内扣除当日餐补;迟到 30 分钟至 2 小时,扣除半日工资并通报批评;迟到 2 小时以上视为旷工。"
    #     ],
    #     "ground_truth_answer": "迟到 30 分钟至 2 小时,扣除半日工资并通报批评;迟到 2 小时以上视为旷工。",
    #     "source_document": "employee_handbook.pdf",
    #     "metadata": {
    #         "generated_by": "gpt-4o-mini",
    #         "generation_date": "2026-03-06"
    #     }
    # }

    # 运行时追踪字段(用于回溯生成来源)
    source_chunk_index = Column(Integer, nullable=False)
    source_backend = Column(String(32), nullable=False, default="milvus")
    created_at = Column(BigInteger, nullable=False)

    __table_args__ = (
        Index("idx_rag_eval_samples_created_at", "created_at"),
        Index("idx_rag_eval_samples_difficulty", "difficulty"),
        Index("idx_rag_eval_samples_category", "category"),
    )

3. 实战:编写“出题专家”模块

我们将编写一个模块 src/augmented/,用于读取分块后的 Chunk,调用 LLM 生成测试题。

以下是几个每日免费调用LLM的网站:

  1. (Modelscope(魔塔社区))[https://www.modelscope.cn/my/access/token], 每日2000次
  2. (Google Genmini)[https://link.zhihu.com/?target=https%3A//ai.google.dev/gemini-api/docs/rate-limits]
  3. (智谱 AI (BigModel))[https://link.zhihu.com/?target=https%3A//bigmodel.cn/usercenter/proj-mgmt/rate-limits], Flash 系列模型(如 GLM-4.5-Flash)完全免费。主要限制并发数(Concurrency)。

🛠️ 代码实现

1. 🌅 新增PostgreSQL client

安装pgsql依赖

pip install psycopg2

延续之前的mysql方案,在postgresql中依然使用ORM框架,先src\core\postgres_client.py在创建一个PostgreSQL client。

import logging

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker

from src.core.config import settings

logger = logging.getLogger(__name__)


class PostgresClient:
    def __init__(self, database_url: str | None = None) -> None:
        # 统一走 core 配置:像 MySQL 一样用一行 URL 管理 PostgreSQL 连接
        self.url = database_url or settings.db.postgres_database_url
        # 连接池参数复用全局 db 配置,避免在多处重复维护。
        self.engine: Engine = create_engine(
            self.url,
            pool_pre_ping=settings.db.pool_pre_ping,
            pool_size=settings.db.pool_size,
            max_overflow=settings.db.max_overflow,
            echo=settings.db.echo,
            future=True,
        )
        self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False, future=True)

    def get_session(self) -> Session:
        # 仅提供会话,不绑定具体业务逻辑;具体建表/写入由上层组件实现。
        return self.SessionLocal()


_postgres_client_instance: PostgresClient | None = None


def get_postgres_client(database_url: str | None = None) -> PostgresClient:
    """
    获取 PostgreSQL 客户端:
    - 未传 database_url: 返回全局单例(默认用于主流程)
    - 传入 database_url: 创建独立实例(用于特定功能隔离)
    """
    global _postgres_client_instance
    if database_url:
        return PostgresClient(database_url=database_url)
    if _postgres_client_instance is None:
        _postgres_client_instance = PostgresClient()
    return _postgres_client_instance

并在 src\core\config.py 中补充设置

    # PostgreSQL (评估数据等场景使用,一行 URL 配置即可) 
    postgres_database_url: str = "postgresql+psycopg2://agent_user:xxxx@127.0.0.1:5433/agent_dev"

src\augmented\sinks.py中编写实际client的功能。

from typing import Any, Dict, List

from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert

from src.core.postgres_client import get_postgres_client
from src.core.models import RagEvalSample

class PostgresSink:
    def __init__(self) -> None:
        self.client = get_postgres_client()
        self._test_connection()
        self._ensure_table()

    def _test_connection(self) -> None:
        with self.client.engine.connect() as conn:
            conn.execute(text("SELECT 1"))

    def _ensure_table(self) -> None:
        RagEvalSample.__table__.create(bind=self.client.engine, checkfirst=True)

    def save(self, rows: List[Dict[str, Any]]) -> None:
        if not rows:
            return

        payload: List[Dict[str, Any]] = []
        for row in rows:
            payload.append(
                {
                    "id": row["id"],
                    "category": row.get("category", "general"),
                    "difficulty": row["difficulty"],
                    "query": row["query"],
                    "ground_truth_context": row["ground_truth_context"],
                    "ground_truth_answer": row["ground_truth_answer"],
                    "source_document": row.get("source_document"),
                    "metadata": row.get("metadata", {}),
                    "source_chunk_index": row["source_chunk_index"],
                    "source_backend": row.get("source_backend", "milvus"),
                    "created_at": row["created_at"],
                }
            )

        upsert_stmt = insert(RagEvalSample.__table__).values(payload)
        upsert_stmt = upsert_stmt.on_conflict_do_update(
            index_elements=[RagEvalSample.id],
            set_={
                "category": upsert_stmt.excluded.category,
                "difficulty": upsert_stmt.excluded.difficulty,
                "query": upsert_stmt.excluded.query,
                "ground_truth_context": upsert_stmt.excluded.ground_truth_context,
                "ground_truth_answer": upsert_stmt.excluded.ground_truth_answer,
                "source_document": upsert_stmt.excluded.source_document,
                "metadata": upsert_stmt.excluded.metadata,
                "source_chunk_index": upsert_stmt.excluded.source_chunk_index,
                "source_backend": upsert_stmt.excluded.source_backend,
                "created_at": upsert_stmt.excluded.created_at,
            },
        )

        with self.client.get_session() as session:
            session.execute(upsert_stmt)
            session.commit()

src\augmented\sources.py中编写实际milvus的功能

# 数据源模块:当前仅从 Milvus 扫描 chunk 作为评估题生成输入。
from typing import Any, Dict, List

from src.core.milvus_client import get_milvus_client


class MilvusSource:
    def __init__(self) -> None:
        self.client = get_milvus_client()

    def load_chunks(self, limit: int) -> List[Dict[str, Any]]:
        # 使用 offset 分页拉取,直到达到 limit 或无更多数据。
        offset = 0
        batch_size = min(limit, 200)
        chunks: List[Dict[str, Any]] = []
        while len(chunks) < limit:
            rows = self.client.scan_collection(limit=batch_size, offset=offset)
            if not rows:
                break
            for row in rows:
                text = row.get("text", "")
                metadata = row.get("metadata", {}) or {}
                if text:
                    chunks.append({"text": text, "metadata": metadata})
                    if len(chunks) >= limit:
                        break
            if len(rows) < batch_size:
                break
            offset += batch_size
        return chunks
2. 🏭 构建LLM工厂

为了提高我们qwen小模型的能力,我们要使用更好的模型来生成参考答案,所以有必要向友商借用一些免费的LLM额度。

创建 src\augmented\llm_endpoints.json 以保存不同模型的配置信息。

{
  "llms": [
    {
      "url": "https://api-inference.modelscope.cn/v1/",
      "model": "Ling-2.5-1T",
      "api_key": "xxx",
      "temperature": 0.7
    },
    {
      "url": "https://open.bigmodel.cn/api/paas/v4",
      "model": "GLM-4.6",
      "api_key": "xxx",
      "temperature": 0.7
    },
    {
      "url": "http://localhost:7575/v1",
      "model": "qwen-3-4b",
      "api_key": "not-needed",
      "temperature": 0.7
    }
  ]
}

src\augmented\config.py 中定义评测配置。

import os
from dataclasses import dataclass

from src.core.config import settings


@dataclass
class GeneratorConfig:
    prompt_profile: str = "default"
    # chunks_limit 用来控制“这次最多从 Milvus 取多少个 chunk 来出题”。
    # 实际效果:
        # 限制本次评测集生成规模(处理速度和成本可控)
        # 间接控制最终样本量(最终样本数 ≈ 有效chunk数 × num_questions_per_chunk)
        # 开发调试时可设小(如 10),正式批量可调大(如 200+)
    chunks_limit: int = 10
    min_chunk_length: int = 50
    num_questions_per_chunk: int = 2
    llm_json_path: str = "src/augmented/llm_endpoints.json"
    default_model_name: str = "unknown"
    max_retries_per_chunk: int = 2


def build_default_config() -> GeneratorConfig:
    return GeneratorConfig(
        prompt_profile=os.getenv("EVAL_PROMPT_PROFILE", "default"),
        chunks_limit=int(os.getenv("EVAL_CHUNKS_LIMIT", "10")),
        min_chunk_length=int(os.getenv("EVAL_MIN_CHUNK_LENGTH", "50")),
        num_questions_per_chunk=int(os.getenv("EVAL_NUM_QUESTIONS", "2")),
        llm_json_path=os.getenv("EVAL_LLM_JSON_PATH", "src/augmented/llm_endpoints.json"),
        default_model_name=settings.llm.model_name,
    )

定义 LLM 路由 src\augmented\llm_router.py

# LLM 路由模块:
# 从一个 JSON 文件加载多个 endpoint,并按顺序调用(失败则降级到下一个)。
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

from src.augmented.config import GeneratorConfig
from src.utils.xml_parser import remove_think_and_n

logger = logging.getLogger(__name__)


@dataclass
class LLMEndpoint:
    # 对应 JSON 里的单个 LLM 节点配置。
    url: str
    model: str
    api_key: str
    temperature: float = 0.7


class LLMRouter:
    def __init__(self, config: GeneratorConfig) -> None:
        self.config = config
        self.endpoints = self._load_endpoints(config.llm_json_path)
        # 实例级降级游标:
        # 一旦第 k 个失败,后续调用将从第 k+1 个开始,不再尝试前 k 个。
        self._degrade_start_idx = 0

    def _load_endpoints(self, json_path: str) -> List[LLMEndpoint]:
        # 支持两种 JSON 形态:
        # 1) {"llms": [...]} 2) [...]
        with open(json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)

        records = payload.get("llms", payload) if isinstance(payload, dict) else payload
        if not isinstance(records, list) or not records:
            raise ValueError(f"LLM JSON 配置无效或为空: {json_path}")

        endpoints: List[LLMEndpoint] = []
        for item in records:
            endpoints.append(
                LLMEndpoint(
                    url=item["url"],
                    model=item["model"],
                    api_key=item.get("api_key", "not-needed"),
                    temperature=float(item.get("temperature", 0.7)),
                )
            )
        return endpoints

    async def invoke(self, prompt: ChatPromptTemplate, payload: Dict[str, Any]) -> Tuple[str, Optional[str]]:
        # 顺序降级:第一个 endpoint 失败后切到下一个。
        last_error: Optional[Exception] = None
        for idx in range(self._degrade_start_idx, len(self.endpoints)):
            ep = self.endpoints[idx]
            try:
                llm = ChatOpenAI(
                    model=ep.model,
                    base_url=ep.url,
                    api_key=ep.api_key if ep.api_key else "not-needed",
                    temperature=ep.temperature,
                )
                chain = prompt | llm
                res = await chain.ainvoke(payload)
                text = remove_think_and_n(getattr(res, "content", "") or "")
                if text:
                    return text, ep.model
            except Exception as e:
                last_error = e
                # 触发实例级熔断:失败的当前节点及其之前节点都不再尝试。
                self._degrade_start_idx = max(self._degrade_start_idx, idx + 1)
                logger.warning("⚠️ LLM 调用失败,降级到下一个 endpoint。model=%s err=%s", ep.model, e)
                continue

        raise RuntimeError(f"所有 LLM endpoint 调用失败: {last_error}")
3. 🎡 构建提示词工厂

src\augmented\prompts.py 中定义提示词工厂类:

DEFAULT_PROMPT = """
你是一个专业的 RAG 评估数据集构建专家。
请阅读以下【文档片段】,并生成 {num_questions} 个高质量的评估测试题。

【要求】:
1. 题目需要覆盖 easy / medium / hard。
2. `ground_truth_context` 必须是字符串数组,数组元素直接摘录原文,不允许杜撰。
3. 每个问题必须可独立回答。
4. 严格输出 JSON 数组,数组元素字段固定为:
   category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata
5. `category` 可用值示例: policy/product/process/general。
6. `metadata` 至少包含 generated_by 和 generation_date(格式 YYYY-MM-DD)。
7. 不要输出 markdown,不要输出解释,只输出 JSON。

【文档片段】:
{context_chunk}
"""


class PromptRegistry:
    def __init__(self) -> None:
        self._prompts = {
            "default": DEFAULT_PROMPT,
            "strict_negative_first": """
你是 RAG 评估数据集专家。请先生成一个偏难问题,再补齐其余问题。
输出必须是 JSON 数组,每个元素字段为:
category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata。
其中 ground_truth_context 必须是字符串数组,metadata 必须包含 generated_by 和 generation_date。

文档片段:
{context_chunk}
需要生成数量:{num_questions}
""",
        }

    def get(self, profile: str) -> str:
        return self._prompts.get(profile, self._prompts["default"])
4. ⏫ 生成测试集

src\augmented\data_generator.py 中添加以下代码, 用于生成测试集:

import json
import asyncio
import time
import logging
from datetime import date
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.prompts import ChatPromptTemplate
from pydantic import ValidationError

from src.augmented.config import GeneratorConfig, build_default_config
from src.augmented.llm_router import LLMRouter
from src.augmented.models import EvalSample
from src.augmented.prompts import PromptRegistry
from src.augmented.sinks import PostgresSink
from src.augmented.sources import MilvusSource

logger = logging.getLogger(__name__)


class DatasetGenerator:
    def __init__(self, config: Optional[GeneratorConfig] = None):
        self.config = config or build_default_config()
        self.prompt_registry = PromptRegistry()
        self.prompt = ChatPromptTemplate.from_template(self.prompt_registry.get(self.config.prompt_profile))

        self.router = LLMRouter(config=self.config)

        self.source = MilvusSource()
        self.sink = PostgresSink()

    @staticmethod
    def _safe_parse_json(content: str) -> List[Dict[str, Any]]:
        text = content.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        data = json.loads(text)
        if isinstance(data, dict):
            data = [data]
        return data if isinstance(data, list) else []

    async def generate_from_chunk(
        self, chunk_text: str, num_questions: Optional[int] = None
    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        question_num = num_questions or self.config.num_questions_per_chunk
        for attempt in range(self.config.max_retries_per_chunk + 1):
            try:
                content, model_used = await self.router.invoke(
                    prompt=self.prompt,
                    payload={"context_chunk": chunk_text, "num_questions": question_num},
                )
                raw_samples = self._safe_parse_json(content)
                valid_samples = []
                for raw in raw_samples:
                    try:
                        # 兼容模型偶发返回 string 的情况,统一转为 list[str]
                        if isinstance(raw.get("ground_truth_context"), str):
                            raw["ground_truth_context"] = [raw["ground_truth_context"]]
                        valid_samples.append(EvalSample(**raw).model_dump())
                    except ValidationError:
                        continue
                if valid_samples:
                    return valid_samples, model_used
            except Exception as e:
                logger.error("生成失败 (attempt=%s): %s", attempt + 1, e)
            await asyncio.sleep(0.2)
        return [], None

    async def generate(self) -> List[Dict[str, Any]]:
        chunks = self.source.load_chunks(limit=self.config.chunks_limit)
        if not chunks:
            logger.warning("⚠️ 未加载到任何 chunk,请检查 source 配置。")
            return []

        tasks: List[Tuple[int, str, Dict[str, Any]]] = []
        for idx, item in enumerate(chunks):
            text = item.get("text", "")
            if len(text.strip()) < self.config.min_chunk_length:
                continue
            tasks.append((idx, text, item.get("metadata", {}) or {}))

        coroutines = [self.generate_from_chunk(text) for _, text, _ in tasks]
        results = await asyncio.gather(*coroutines)

        all_samples: List[Dict[str, Any]] = []
        now_ts = int(time.time())
        today = date.today().isoformat()
        for bucket_idx, result_item in enumerate(results):
            source_idx, _, chunk_metadata = tasks[bucket_idx]
            samples, model_used = result_item
            source_document = chunk_metadata.get("source")
            for sample_idx, sample in enumerate(samples):
                sample_meta = sample.get("metadata", {}) or {}
                sample_meta.setdefault("generated_by", model_used or self.config.default_model_name)
                sample_meta.setdefault("generation_date", today)
                all_samples.append(
                    {
                        "id": f"gen_{now_ts}_{source_idx}_{sample_idx}",
                        "category": sample.get("category", chunk_metadata.get("category", "general")),
                        "difficulty": sample["difficulty"],
                        "query": sample["query"],
                        "ground_truth_context": sample["ground_truth_context"],
                        "ground_truth_answer": sample["ground_truth_answer"],
                        "source_document": sample.get("source_document") or source_document,
                        "metadata": sample_meta,
                        "source_chunk_index": source_idx,
                        "source_backend": "milvus",
                        "created_at": now_ts,
                    }
                )

        self.sink.save(all_samples)
        logger.info("✅ 数据集生成完成,样本数=%s,已写入 PostgreSQL", len(all_samples))
        return all_samples


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    generator = DatasetGenerator()
    asyncio.run(generator.generate())

最后在 __init__.py 中写入以下内容:

# augmented 包对外只暴露 DatasetGenerator,其他模块作为内部实现细节使用。
from .data_generator import DatasetGenerator

__all__ = ["DatasetGenerator"]

运行生成测试集

运行前要做配置的修改,就是设置milvus使用的数据集为之前recursive模式生成的数据集。

python -m src.augmented.data_generator

在这里插入图片描述

📊 4. 提升测试集质量的 3 个技巧

🌟 技巧 1:强制难度分布

在 Prompt 中明确要求:“生成的 3 个问题中,必须包含 1 个 Easy,1 个 Medium,1 个 Hard”。

  • Hard 的定义:需要在 Prompt 里举例,比如“需要结合两段不同文档的信息才能回答”。

🌟 技巧 2:引入“对抗性”问题 (Adversarial Examples)

专门生成一些文档中没有答案的问题。

  • 目的:测试系统的 Faithfulness。如果系统对着空气编出了答案,Faithfulness 分数应大幅降低。
  • Prompt 指令:“请生成一个看似相关但文档中完全未提及的问题,标准答案标记为‘根据提供的上下文无法回答’。”

🌟 技巧 3:人工清洗 (Human-in-the-Loop)

LLM 生成的题目可能有 10%~20% 的瑕疵(如问题歧义、上下文截断)。

  • 流程:生成 -> 存入 JSONL -> 快速人工浏览修正 (可用简单的 Streamlit 工具) -> 正式评估。
  • 不要追求 100% 自动化:花 30 分钟清洗 50 道题,比跑 1000 道垃圾题更有价值。

🎢 重构提示词工厂

现在要实现三个策略

  1. 标准策略(当前模式)
  2. 对抗策略(文档无答案,固定答案“根据提供的上下文无法回答”)
  3. 随机拼接策略(两段拼接,强制 3 题:Easy/Medium/Hard,Hard 需跨片段推理)
# Prompt 注册中心:集中管理不同出题风格模板,按 profile 选择。
STANDARD_PROMPT = """
你是一个专业的 RAG 评估数据集构建专家。
请阅读以下【文档片段】,并生成 {num_questions} 个高质量的评估测试题。

【要求】:
1. 题目需要覆盖 easy / medium / hard。
2. `ground_truth_context` 必须是字符串数组,数组元素直接摘录原文,不允许杜撰。
3. 每个问题必须可独立回答。
4. 严格输出 JSON 数组,数组元素字段固定为:
   category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata
5. `category` 可用值示例: policy/product/process/general。
6. `metadata` 至少包含 generated_by 和 generation_date(格式 YYYY-MM-DD)。
7. 不要输出 markdown,不要输出解释,只输出 JSON。

【文档片段】:
{context_chunk}
"""

ADVERSARIAL_PROMPT = """
你是一个 RAG 鲁棒性评估专家。请基于以下文档片段,生成 {num_questions} 个“对抗性问题”。

【目标】:
1. 问题必须“看似相关”,但文档中完全未提及该问题答案。
2. 标准答案必须固定写为:根据提供的上下文无法回答
3. ground_truth_context 必须为 [](空数组),表示上下文中无可支持证据。
4. difficulty 固定为 hard。
5. 严格输出 JSON 数组,字段固定为:
   category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata
6. 不要输出 markdown,不要解释,只输出 JSON。

【文档片段】:
{context_chunk}
"""

MIXED_PAIR_PROMPT = """
你是一个多跳检索评估专家。下面有两个文档片段(A 与 B),请基于它们生成 {num_questions} 个问题。

【强约束】:
1. 必须且仅生成 3 个问题:1 个 Easy、1 个 Medium、1 个 Hard。
2. Hard 的定义:必须结合片段 A 和片段 B 的信息才能回答;如果只看其中一个片段无法完整回答。
3. 每个问题都要提供 ground_truth_context(字符串数组,摘录原文)与 ground_truth_answer。
4. 严格输出 JSON 数组,字段固定为:
   category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata
5. 不要输出 markdown,不要解释,只输出 JSON。

【片段 A】:
{chunk_a}

【片段 B】:
{chunk_b}
"""


class PromptRegistry:
    def __init__(self) -> None:
        # 可按需扩展更多 profile,调用侧只传 profile 名称。
        self._prompts = {
            "default": STANDARD_PROMPT,
            "standard": STANDARD_PROMPT,
            "adversarial": ADVERSARIAL_PROMPT,
            "mixed_pair": MIXED_PAIR_PROMPT,
            "strict_negative_first": """
你是 RAG 评估数据集专家。请先生成一个偏难问题,再补齐其余问题。
输出必须是 JSON 数组,每个元素字段为:
category, difficulty, query, ground_truth_context, ground_truth_answer, source_document, metadata。
其中 ground_truth_context 必须是字符串数组,metadata 必须包含 generated_by 和 generation_date。

文档片段:
{context_chunk}
需要生成数量:{num_questions}
""",
        }

    def get(self, profile: str) -> str:
        # 未命中时回退到 default,保证流程可运行。
        return self._prompts.get(profile, self._prompts["default"])

新增策略模式(src\augmented\strategies.py)

# 策略模式:将“如何从 chunk 组织出题任务”从主流程剥离,便于扩展新策略。
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class StrategyTask:
    strategy_name: str
    prompt_profile: str
    payload: Dict[str, Any]
    source_chunk_indices: List[int]
    source_metadata: Dict[str, Any] = field(default_factory=dict)
    num_questions: int = 1


class BaseGenerationStrategy(ABC):
    name: str

    @abstractmethod
    def build_tasks(self, chunks: List[Dict[str, Any]]) -> List[StrategyTask]:
        raise NotImplementedError

    def postprocess_samples(self, samples: List[Dict[str, Any]], _task: StrategyTask) -> List[Dict[str, Any]]:
        return samples


class StandardChunkStrategy(BaseGenerationStrategy):
    name = "standard"
    default_num_questions = 2

    def __init__(self, num_questions: int | None = None):
        self.num_questions = num_questions if num_questions is not None else self.default_num_questions

    def build_tasks(self, chunks: List[Dict[str, Any]]) -> List[StrategyTask]:
        tasks: List[StrategyTask] = []
        for idx, item in enumerate(chunks):
            tasks.append(
                StrategyTask(
                    strategy_name=self.name,
                    prompt_profile="standard",
                    payload={
                        "context_chunk": item.get("text", ""),
                        "num_questions": self.num_questions,
                    },
                    source_chunk_indices=[idx],
                    source_metadata=item.get("metadata", {}) or {},
                    num_questions=self.num_questions,
                )
            )
        return tasks


class AdversarialStrategy(BaseGenerationStrategy):
    name = "adversarial"
    fixed_answer = "根据提供的上下文无法回答"
    default_num_questions = 1

    def __init__(self, num_questions: int | None = None):
        self.num_questions = num_questions if num_questions is not None else self.default_num_questions

    def build_tasks(self, chunks: List[Dict[str, Any]]) -> List[StrategyTask]:
        tasks: List[StrategyTask] = []
        for idx, item in enumerate(chunks):
            tasks.append(
                StrategyTask(
                    strategy_name=self.name,
                    prompt_profile="adversarial",
                    payload={
                        "context_chunk": item.get("text", ""),
                        "num_questions": self.num_questions,
                    },
                    source_chunk_indices=[idx],
                    source_metadata=item.get("metadata", {}) or {},
                    num_questions=self.num_questions,
                )
            )
        return tasks

    def postprocess_samples(self, samples: List[Dict[str, Any]], _task: StrategyTask) -> List[Dict[str, Any]]:
        # 对抗题目的标准答案由程序强制注入,避免模型偏离指令。
        for s in samples:
            s["ground_truth_answer"] = self.fixed_answer
            s["difficulty"] = "hard"
            s["ground_truth_context"] = []
        return samples


class MixedPairStrategy(BaseGenerationStrategy):
    name = "mixed_pair"
    default_pair_count = 2
    default_num_questions = 3

    def __init__(self, pair_count: int | None = None, num_questions: int | None = None, seed: int | None = None):
        self.pair_count = pair_count if pair_count is not None else self.default_pair_count
        # 业务约束:该策略必须生成 3 题(E/M/H),若传入非法值则强制回到 3。
        n = num_questions if num_questions is not None else self.default_num_questions
        self.num_questions = 3 if n != 3 else n
        self.seed = seed

    def build_tasks(self, chunks: List[Dict[str, Any]]) -> List[StrategyTask]:
        if len(chunks) < 2:
            return []

        indices = list(range(len(chunks)))
        if self.seed is not None:
            random.seed(self.seed)
        random.shuffle(indices)
        max_pairs = max(0, min(self.pair_count, len(indices) // 2))

        tasks: List[StrategyTask] = []
        for i in range(max_pairs):
            a_idx = indices[2 * i]
            b_idx = indices[2 * i + 1]
            a = chunks[a_idx]
            b = chunks[b_idx]
            tasks.append(
                StrategyTask(
                    strategy_name=self.name,
                    prompt_profile="mixed_pair",
                    payload={
                        "chunk_a": a.get("text", ""),
                        "chunk_b": b.get("text", ""),
                        "num_questions": self.num_questions,
                    },
                    source_chunk_indices=[a_idx, b_idx],
                    source_metadata={
                        "chunk_a_metadata": a.get("metadata", {}) or {},
                        "chunk_b_metadata": b.get("metadata", {}) or {},
                    },
                    num_questions=self.num_questions,
                )
            )
        return tasks


def build_strategies(enabled_strategies: str, strategy_params: Dict[str, Dict[str, Any]] | None = None) -> List[BaseGenerationStrategy]:
    enabled = {x.strip().lower() for x in enabled_strategies.split(",") if x.strip()}
    params = strategy_params or {}

    strategy_map = {
        "standard": StandardChunkStrategy(**params.get("standard", {})),
        "adversarial": AdversarialStrategy(**params.get("adversarial", {})),
        "mixed_pair": MixedPairStrategy(**params.get("mixed_pair", {})),
    }
    return [strategy_map[name] for name in ("standard", "adversarial", "mixed_pair") if name in enabled]

📄 修改配置项(src\augmented\config.py)

# 该模块负责评估数据生成器的最小运行配置:
# 只保留任务规模、提示词配置和 LLM JSON 配置文件路径。
import os
from dataclasses import dataclass

from src.core.config import settings


@dataclass
class GeneratorConfig:
    # Prompt 模板档位
    prompt_profile: str = "default"
    # 本次最多处理多少个 chunk
    chunks_limit: int = 10
    # 过滤过短 chunk,避免无效出题
    min_chunk_length: int = 50
    # 每个 chunk 生成的问题数
    num_questions_per_chunk: int = 2
    # LLM 端点配置 JSON 路径
    llm_json_path: str = "src/Augmented/llm_endpoints.json"
    # 当生成失败或无返回模型信息时的兜底名
    default_model_name: str = "unknown"
    # 单个 chunk 的最大重试次数
    max_retries_per_chunk: int = 2
    # 启用策略列表(逗号分隔):standard,adversarial,mixed_pair
    enabled_strategies: str = "standard,adversarial,mixed_pair"
    # 策略参数透传(JSON 字符串),示例:
    # {"standard":{"num_questions":2},"adversarial":{"num_questions":1},"mixed_pair":{"pair_count":2,"num_questions":3}}
    strategy_params_json: str = "{}"


def build_default_config() -> GeneratorConfig:
    # 运行参数主要通过环境变量注入,便于本地和线上统一调参。
    return GeneratorConfig(
        prompt_profile=os.getenv("EVAL_PROMPT_PROFILE", "default"),
        chunks_limit=int(os.getenv("EVAL_CHUNKS_LIMIT", "10")),
        min_chunk_length=int(os.getenv("EVAL_MIN_CHUNK_LENGTH", "50")),
        num_questions_per_chunk=int(os.getenv("EVAL_NUM_QUESTIONS", "2")),
        llm_json_path=os.getenv("EVAL_LLM_JSON_PATH", "src/Augmented/llm_endpoints.json"),
        default_model_name=settings.llm.model_name,
        enabled_strategies=os.getenv("EVAL_ENABLED_STRATEGIES", "standard,adversarial,mixed_pair"),
        strategy_params_json=os.getenv("EVAL_STRATEGY_PARAMS_JSON", "{}"),
    )

🐽 修改主流程(src\augmented\data_generator.py)

# 主编排器(策略模式):
# Milvus 取 chunk -> 按策略构建任务 -> LLM 生成评估样本 -> PostgreSQL 持久化。
import ast
import json
import logging
import time
from datetime import date
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.prompts import ChatPromptTemplate
from pydantic import ValidationError

from src.augmented.config import GeneratorConfig, build_default_config
from src.augmented.llm_router import LLMRouter
from src.augmented.models import EvalSample
from src.augmented.prompts import PromptRegistry
from src.augmented.sinks import PostgresSink
from src.augmented.sources import MilvusSource
from src.augmented.strategies import StrategyTask, build_strategies

logger = logging.getLogger(__name__)


class DatasetGenerator:
    def __init__(self, config: Optional[GeneratorConfig] = None):
        # [初始化-1] 读取运行配置(环境变量 + 默认值)
        self.config = config or build_default_config()
        # [初始化-2] 准备 Prompt 注册器与缓存(不同策略使用不同模板)
        self.prompt_registry = PromptRegistry()
        self.prompt_cache: Dict[str, ChatPromptTemplate] = {}
        # [初始化-3] 解析策略参数(JSON 字符串 -> dict)
        self.strategy_params = self._safe_parse_strategy_params(self.config.strategy_params_json)
        # [初始化-4] 为 standard 策略补默认题目数(向后兼容旧配置)
        self.strategy_params.setdefault("standard", {})
        self.strategy_params["standard"].setdefault("num_questions", self.config.num_questions_per_chunk)

        # [初始化-5] 准备核心组件:LLM 路由、数据源、存储端、策略集合
        self.router = LLMRouter(config=self.config)
        self.source = MilvusSource()
        self.sink = PostgresSink()
        self.strategies = build_strategies(self.config.enabled_strategies, self.strategy_params)

    def _get_prompt(self, profile: str) -> ChatPromptTemplate:
        if profile not in self.prompt_cache:
            self.prompt_cache[profile] = ChatPromptTemplate.from_template(self.prompt_registry.get(profile))
        return self.prompt_cache[profile]

    @staticmethod
    def _safe_parse_strategy_params(raw: str) -> Dict[str, Dict[str, Any]]:
        try:
            parsed = json.loads(raw) if raw else {}
            return parsed if isinstance(parsed, dict) else {}
        except Exception:
            return {}

    @staticmethod
    def _extract_json_candidate(text: str) -> str:
        l_arr = text.find("[")
        r_arr = text.rfind("]")
        if l_arr != -1 and r_arr != -1 and r_arr > l_arr:
            return text[l_arr : r_arr + 1]

        l_obj = text.find("{")
        r_obj = text.rfind("}")
        if l_obj != -1 and r_obj != -1 and r_obj > l_obj:
            return text[l_obj : r_obj + 1]

        return text

    @staticmethod
    def _safe_parse_json(content: str) -> List[Dict[str, Any]]:
        text = content.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            candidate = DatasetGenerator._extract_json_candidate(text)
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                data = ast.literal_eval(candidate)

        if isinstance(data, dict):
            data = [data]
        return data if isinstance(data, list) else []

    def generate_from_task(self, task: StrategyTask) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        # [任务执行-1] 按策略名选择 Prompt 模板
        last_error_msg = ""
        prompt = self._get_prompt(task.prompt_profile)
        # [任务执行-2] 单任务重试循环(网络抖动/模型输出异常时重试)
        for attempt in range(self.config.max_retries_per_chunk + 1):
            try:
                # [任务执行-3] 调用 LLM(返回文本 + 实际使用模型名)
                content, model_used = self.router.invoke(prompt=prompt, payload=task.payload)
                # [任务执行-4] 解析 JSON(带容错)
                raw_samples = self._safe_parse_json(content)
                valid_samples = []
                validation_failed = 0
                for raw in raw_samples:
                    try:
                        # [任务执行-5] 字段清洗与结构校验(Pydantic)
                        if isinstance(raw.get("ground_truth_context"), str):
                            raw["ground_truth_context"] = [raw["ground_truth_context"]]
                        valid_samples.append(EvalSample(**raw).model_dump())
                    except ValidationError:
                        validation_failed += 1
                        continue
                if valid_samples:
                    return valid_samples, model_used
                last_error_msg = (
                    f"JSON解析成功但无有效样本(raw={len(raw_samples)}, validation_failed={validation_failed})"
                )
            except Exception as e:
                last_error_msg = str(e)
                logger.error(
                    "任务生成失败 (attempt=%s): %s | strategy=%s | content_preview=%s",
                    attempt + 1,
                    e,
                    task.strategy_name,
                    content[:200] if "content" in locals() else "",
                )
            time.sleep(0.2)

        logger.warning(
            "任务生成失败,已放弃。strategy=%s reason=%s | task_preview=%s",
            task.strategy_name,
            last_error_msg,
            str(task.payload)[:160],
        )
        return [], None

    def generate(self) -> List[Dict[str, Any]]:
        # [主流程-1] 从 Milvus 拉取候选 chunk
        chunks = self.source.load_chunks(limit=self.config.chunks_limit)
        logger.info(
            "开始生成评估集: chunks_limit=%s, min_chunk_length=%s, standard_num_questions=%s",
            self.config.chunks_limit,
            self.config.min_chunk_length,
            self.config.num_questions_per_chunk,
        )
        logger.info("启用策略=%s", [s.name for s in self.strategies])
        logger.info("策略参数=%s", self.strategy_params)
        logger.info("Milvus加载到chunk数量=%s", len(chunks))
        if not chunks:
            logger.warning("⚠️ 未加载到任何 chunk,请检查 source 配置。")
            return []

        # [主流程-2] 预过滤过短 chunk,避免低质量样本
        filtered_chunks: List[Dict[str, Any]] = []
        filtered_short = 0
        for item in chunks:
            text = item.get("text", "")
            if len(text.strip()) < self.config.min_chunk_length:
                filtered_short += 1
                continue
            filtered_chunks.append(item)
        logger.info("过滤过短chunk=%s, 可用chunk=%s", filtered_short, len(filtered_chunks))

        # [主流程-3] 各策略构建任务(standard / adversarial / mixed_pair)
        strategy_tasks: List[Tuple[StrategyTask, str]] = []
        strategy_map = {}
        for strategy in self.strategies:
            strategy_map[strategy.name] = strategy
            tasks = strategy.build_tasks(filtered_chunks)
            for t in tasks:
                strategy_tasks.append((t, strategy.name))
        logger.info("策略任务总数=%s", len(strategy_tasks))

        # [主流程-4] 逐任务执行 LLM + 策略后处理(如对抗题答案强制注入)
        raw_results: List[Tuple[StrategyTask, str, List[Dict[str, Any]], Optional[str]]] = []
        for task, strategy_name in strategy_tasks:
            samples, model_used = self.generate_from_task(task)
            samples = strategy_map[strategy_name].postprocess_samples(samples, task)
            raw_results.append((task, strategy_name, samples, model_used))

        success_tasks = sum(1 for _, _, samples, _ in raw_results if samples)
        failed_tasks = len(raw_results) - success_tasks
        total_generated = sum(len(samples) for _, _, samples, _ in raw_results)
        logger.info(
            "任务生成统计: success_tasks=%s, failed_tasks=%s, generated_samples=%s",
            success_tasks,
            failed_tasks,
            total_generated,
        )

        # [主流程-5] 统一组装落库数据(补 model_name、strategy、source 索引等追踪字段)
        all_samples: List[Dict[str, Any]] = []
        now_ts = int(time.time())
        today = date.today().isoformat()

        for task, strategy_name, samples, model_used in raw_results:
            first_idx = task.source_chunk_indices[0] if task.source_chunk_indices else -1
            if strategy_name == "mixed_pair":
                a_src = task.source_metadata.get("chunk_a_metadata", {}).get("source")
                b_src = task.source_metadata.get("chunk_b_metadata", {}).get("source")
                source_document = f"{a_src}|{b_src}"
            else:
                source_document = task.source_metadata.get("source")

            for sample_idx, sample in enumerate(samples):
                sample_meta = sample.get("metadata", {}) or {}
                sample_meta.setdefault("generated_by", model_used or self.config.default_model_name)
                sample_meta.setdefault("generation_date", today)
                sample_meta.setdefault("strategy", strategy_name)
                sample_meta.setdefault("source_chunk_indices", task.source_chunk_indices)

                all_samples.append(
                    {
                        "id": f"gen_{now_ts}_{strategy_name}_{first_idx}_{sample_idx}",
                        "category": sample.get("category", task.source_metadata.get("category", "general")),
                        "difficulty": sample["difficulty"],
                        "query": sample["query"],
                        "ground_truth_context": sample["ground_truth_context"],
                        "ground_truth_answer": sample["ground_truth_answer"],
                        "source_document": sample.get("source_document") or source_document,
                        "model_name": model_used or self.config.default_model_name,
                        "metadata": sample_meta,
                        "source_chunk_index": first_idx,
                        "source_backend": "milvus",
                        "created_at": now_ts,
                    }
                )

        # [主流程-6] 批量写入 PostgreSQL
        self.sink.save(all_samples)
        logger.info("✅ 数据集生成完成,样本数=%s,已写入 PostgreSQL", len(all_samples))
        return all_samples


if __name__ == "__main__":
    # python -m src.augmented.data_generator
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    generator = DatasetGenerator()
    generator.generate()

运行生成测试集脚本

python -m src.augmented.data_generator

🏗️ Feedback 模块设计

这个 Augmented 模块本质上是在做一条“自动造评测集”的流水线,目标是给 RAG 系统持续产出可评估数据,并直接落库到 PostgreSQL。

整体流程是:

  1. 从 Milvus 抽取 chunk

    • sources.py 里,MilvusSource.load_chunks() 会分页扫描向量库,把文本和 metadata 拉出来,作为出题原料。
  2. 用策略模式把“怎么出题”拆开

    • strategies.py 里,定义了 3 种策略:
      • standard:每个 chunk 出 N 个常规问题
      • adversarial:生成“文档里没有答案”的问题,专门测 Faithfulness
      • mixed_pair:随机拼两段 chunk,强制出 1 Easy + 1 Medium + 1 Hard(Hard 要跨段推理)
    • 每个策略先生成 StrategyTask,主流程再统一执行这些任务。
  3. 按顺序调用多个 LLM 端点(失败降级)

    • 在 llm_router.py 里,读取 llm_endpoints.jsonurl/model/api_key/temperature),按顺序尝试:
    • 第一个可用就用它
    • 当前失败就降级到下一个
    • 并且实例内有“降级游标”:前面失败过的端点后续不再回头尝试
  4. 对 LLM 输出做强容错解析和结构校验

    • 在 data_generator.py:
    • 先清理 markdown 代码块
    • JSON 解析失败时,尝试提取 JSON 片段再解析
    • 还失败就用 ast.literal_eval 兜底
    • 最后用 EvalSample(见 models.py)做字段校验,过滤无效样本
  5. 对策略结果做后处理
    例如 adversarial 策略会程序端强制:

    • ground_truth_answer = “根据提供的上下文无法回答”
    • ground_truth_context = []
    • difficulty = hard
      这样保证“无答案题”数据的一致性,不依赖模型是否听话。
  6. 统一补充追踪信息并入库

    • 主流程会给每条样本补:
      • model_name(实际生成模型)
      • metadata.generated_by / generation_date
      • metadata.strategy / source_chunk_indices
    • 然后交给 sinks.py 写 PostgreSQL(upsert)。
  7. PostgreSQL 存储是 ORM 方案

    • 底层模型在 src/core/models.pyRagEvalSample,连接由 src/core/postgres_client.py 提供,PostgresSink 负责建表检查与批量 upsert
  8. 配置方式

    • Augmented 的运行配置在 config.py,重点是:
      • chunks_limit / min_chunk_length / num_questions_per_chunk
      • enabled_strategies
      • strategy_params_json(策略专属参数透传)
      • llm_json_path

一句话总结:
这个模块就是“从知识库自动生成高质量评测样本(含常规题、对抗题、多跳题)并结构化入库”的自动化评估数据工厂。

📂 当前项目结构解析

./augmented/ 📂
├── __init__.py         📦 # 包导出入口,导出 DatasetGenerator
├── config.py           ⚙️ # 负责模块运行配置(chunk 数量、最小长度、策略开关、策略参数 JSON、LLM 配置文件路径)
├── data_generator.py   🏭 # 主编排入口:Milvus 抽取 -> 策略构任务 -> 调 LLM -> 解析校验 -> 组装样本 -> 写 PostgreSQL
├── llm_endpoints.json  🔌 # LLM 端点配置文件
├── llm_router.py       🚦 # LLM 调用路由层
├── models.py           🗃️ # Pydantic 输出结构定义
├── prompts.py          💬 # Prompt 模板中心
├── sinks.py            🪣 # 落库适配层 (数据汇)
├── sources.py          📥 # 数据源适配层 (数据源)
└── strategies.py       ♟️ # 策略模式核心

🌊 系统数据流向架构图

红色线为数据流向示意:

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐