Agentic RL实战:打造自主学习自主迭代的高性能 Agent

微软Agent-Lightning快速入门

1. 端到端Agentic RL解决方案:Agent-Lightning

​ Agent Lightning 是一个用于训练和优化 AI 智能体的框架,它采用强化学习、自动提示优化、监督式微调和其他算法。本页面概述了该系统的架构、核心概念和主要工作流程。

image-20251109182801068

​ Agent Lightning 通过检测智能体交互、收集遥测数据并应用学习算法来改进智能体行为,从而以最小的代码更改实现 AI 智能体的优化。该系统支持多种智能体框架(LangChain、OpenAI Agent SDK、AutoGen、CrewAI、Anthropic)和各种优化算法(RL、APO、SFT)。

主要特点:

特征 描述
与框架无关 通过通用工具与任何代理框架配合使用
最小侵入 几乎不需要对现有代理进行任何代码更改。
灵活部署 支持单进程、多进程和分布式执行
算法多样性 包括 RL、APO、SFT 和自定义算法支持
生产就绪 全面的 CI/CD、测试和部署基础设施
image-20251109182814172

Agent Lightning软件包由多个核心模块组成,分别处理训练和优化工作流程的不同方面。

image-20251109183204527

同时,Agent Lightning 采用解耦架构,其中LightningStore充当中央消息队列和数据库,协调算法和运行器之间的通信。

image-20251109183254902

而在实际运行过程中,Agent核心训练循环遵循生产者-消费者(producer-consumer)模式,其中算法生成任务,而跑者则消费这些任务:

image-20251109183347686

详细的 Agent RL 原理及入门实战代码 加入 赋范空间 免费领取

image-20251109193644253

2.基于LangGraph的SQL-Agent强化学习微调流程

2.1 整体架构概述:运行与训练的分离机制

​ 在进行 Agent 的强化学习微调之前,我们首先需要理解这一项目的整体架构。整个系统采用了**“运行与训练分离”**的设计思想,也就是将 Agent 的实际执行逻辑(即如何生成、执行与修正 SQL 语句)与强化学习训练过程(即如何计算奖励、更新模型参数)进行解耦。这种架构的设计灵感来源于工业级 RLHF(Reinforcement Learning from Human Feedback)系统:一部分负责与环境交互,另一部分负责模型优化,从而实现可扩展、可维护的训练流程。

具体而言,系统分为两大核心模块:

  1. 运行模块(Agent 运行脚本)
    该部分由 LangGraph 构建的 SQL-Agent 负责,实现从自然语言问题到 SQL 执行结果的完整推理流程。此模块重点在于如何让 Agent 具备可观测、可追踪的行为,从而为后续强化学习提供高质量的数据轨迹。
  2. 训练模块(Agent 训练脚本)
    这一部分主要依托 Agent-Lightning 框架与 veRL (Volcengine Reinforcement Learning) 训练系统完成。其作用是利用运行模块产生的行为轨迹,对底层基座模型进行基于 GRPO (Group Relative Policy Optimization) 算法的优化,从而实现策略提升与能力迁移。
image-20251109190604832

通过这种结构,系统实现了“前端执行、后端优化、双向联动”的设计。
运行模块专注于生成和记录,训练模块专注于分析与更新,两者之间通过轨迹数据(trajectory)和奖励信号(reward signal)进行衔接。换言之,运行脚本产出数据,训练脚本消化数据

2.2 运行模块:LangGraph Agent 的设计逻辑

​ LangGraph 是 LangChain 1.0 之后推出的新型工作流图框架,用于将复杂的 Agent 推理过程可视化、结构化。在本项目中,LangGraph 承担着整个 SQL Agent 的“行为控制”职责。其核心思想是:将 Agent 的推理过程抽象为一张有向图(Directed Graph),每一个节点代表 Agent 的一个关键步骤,而节点之间的边则代表决策路径或执行顺序。

在 SQL Agent 中,LangGraph 的节点主要包括:

  • write_query(生成 SQL 语句):模型根据输入的问题及数据库模式(schema)生成初步 SQL 查询;
  • execute_query(执行 SQL):调用 SQL 执行工具 (QuerySQLDatabaseTool) 在数据库上运行查询;
  • check_query(检查 SQL 正确性):模型根据执行结果判断 SQL 是否合理;
  • rewrite_query(重写 SQL):若上一步判断为错误,则重新生成更合理的 SQL;
  • END(结束节点):当 Agent 确认 SQL 无误时,停止流程。

​ 这种设计使 Agent 在一次运行中可能经历多个“生成-执行-反馈”循环,从而体现出“自我纠错”的特征。更重要的是,LangGraph 天然支持状态持久化与轨迹记录,这为强化学习提供了重要基础。每一次执行、判断与重写都会被记录为一条“状态-动作-反馈”数据,这正是强化学习算法所需要的经验数据(Experience Trajectory)。

因此,运行模块的核心目标不仅是“让 Agent 能跑起来”,更重要的是“让 Agent 的行为能被捕获、能被度量、能被优化”。

image-20251109190604832
2.3 Agent-Lightning的封装逻辑

​ 在原始的 LangGraph 系统中, Agent 虽然能运行,但无法直接与强化学习训练框架交互。为此, Agent-Lightning 框架在运行脚本上进行了“封装”操作,使得 LangGraph Agent 具备了强化学习所需的接口与记录能力。这种封装的主要目的有三点:

  1. 轨迹采集(Trajectory Collection)
    Agent-Lightning 在每次 Agent 执行过程中自动记录输入问题、生成的 SQL、执行结果、反馈内容及执行日志。这些轨迹数据被打包成 rollout 样本,供后续 RL 算法使用。
  2. 奖励信号传递(Reward Propagation)
    运行脚本在每次 Agent 完成一个任务后,会调用 evaluate_query 函数对结果进行评分。若生成的 SQL 与标准答案一致(或执行结果正确),则 reward = 1;否则 reward = 0。这一奖励信号是 GRPO 算法计算梯度的重要依据。
  3. 可扩展接口(Training Interface)
    Agent-Lightning 在运行层面提供了标准化接口,使得训练模块可以直接通过 rollout 调用 Agent 执行。例如,在 train 脚本中可以统一调用 agent.rollout(task, resources, rollout) 而不必关心 Agent 内部结构。这种解耦式设计使整个系统具有极强的可扩展性,可以轻松替换 Agent 或模型。

​ 总结来说, Agent-Lightning 的封装使 LangGraph Agent 从一个“执行体”转变为一个“可训练体”,从而真正具备了强化学习的可操作性。它相当于为 Agent 加上了一层“可观测外壳”,把原本封闭的推理过程开放为可追踪的训练数据流。

image-20251109183930491
2.4 训练模块:基于 veRL 的 GRPO 强化学习逻辑

​ 训练模块是整个系统的优化核心,其主要任务是根据 Agent 执行产生的轨迹与奖励,更新底层语言模型参数,使其行为策略(policy)趋向于生成更优的 SQL 语句。

​ veRL (Volcengine Reinforcement Learning) 是由 字节跳动 开源的强化学习训练框架,支持多种算法,包括 PPO、DPO、GRPO 等。其中 GRPO(Group Relative Policy Optimization) 是 DeepSeek 提出的一种改进型策略优化算法,能够在无需 critic 网络的情况下实现高效的策略更新,特别适合大语言模型(LLM)的 RL 训练。

在本项目中, train 脚本主要通过 veRL 调用 GRPO 算法来实现训练,其底层逻辑如下:

  1. Rollout 阶段
    训练器(Trainer) 会调度多个并行进程,每个进程加载一个 LangGraph Agent 实例,并分配不同的训练样本(自然语言问题)。每个 Agent 根据当前模型策略生成 SQL 、执行、反馈,形成 rollout 轨迹。
  2. Reward 计算
    训练器调用 evaluate_query 对每个 Agent 的输出进行评分。若执行结果正确,则给出正向奖励,否则为 0。所有样本的 reward 值将与模型生成的 log prob 概率一同送入 GRPO 优化器。
  3. Advantage 估计与策略更新
    GRPO 算法根据组内样本的相对表现计算 advantage(优势函数),不依赖额外 critic 网络。表现较优的样本获得更大权重,劣质样本权重降低,从而实现“优胜劣汰”的参数更新。
  4. 参数同步与保存
    训练器更新模型参数后,会定期保存检查点(checkpoint),供下一轮 rollout 使用。新参数会替换旧参数,使下一轮 Agent 行为更优。

整个过程形成一个典型的 on-policy 强化学习闭环:执行 → 反馈 → 优化 → 再执行,每个循环都会让 Agent 的策略更趋近理想状态。

image-20251109190604832
2.5 运行与训练的闭环:性能提升的路径

​ 理解完运行与训练模块后,我们可以从系统层面总结其运行逻辑。整个强化学习微调流程可以用以下闭环描述:

  1. LangGraph Agent 根据输入问题生成 SQL;
  2. 执行 SQL 并获取执行结果;
  3. Agent-Lightning 封装记录轨迹并计算奖励;
  4. veRL 收集轨迹与奖励,使用 GRPO 算法更新模型策略;
  5. 更新后的模型重新投入下一轮 rollout 执行。

​ 在此循环中,每一轮 Agent 的 SQL 生成能力都会得到提升。初期 Agent 可能频繁生成错误 SQL,经过数轮优化后,模型逐渐学会识别正确的字段、表结构和查询模式,生成更符合语义的 SQL 语句。这就是 Agentic RL 的强大之处——通过实际执行结果指导模型学习,使其具备真实世界任务的自适应能力。

image-20251109190604832

​ 更进一步,该架构的解耦特性意味着它并不局限于 SQL 任务。理论上,只要 LangGraph 定义了 Agent 的执行流程, Agent-Lightning 提供了 rollout 封装, veRL 即可用于强化学习训练。这为其他类型 Agent 的迁移与再利用提供了极高的灵活性。

基于Agent-Lightning的SQL Agent强化学习训练实战

1. 基础环境配置与相关库安装

  • 实验环境说明:本小结实验在Ubuntu 22.04、H800(80G)显卡服务器上运行,推荐使用CUDA 12.8,完整运行需要12个小时,如采用LoRA微调,则可以压缩至2小时完成。

  • 创建基础虚拟环境

conda create --name al python=3.12
conda init
conda activate al
# conda install jupyterlab
# conda install ipykernel
# python -m ipykernel install --user --name al --display-name "Python al"
  • 安装Agent Lightning库
pip install --upgrade agentlightning
image-20251109163142773 image-20251109190953542
  • 安装SQL-Agent强化学习训练基础库
# 注意需要安装openai 2.0以上版本,可以通过pip show openai进行版本查看
pip install openai
pip install agentlightning[apo]

# 注意需要CUDA 12.8版本
pip install torch==2.8.0 torchvision==0.23.0 --index-url https://mirrors.huaweicloud.com/repository/pypi/simple

pip install flash-attn --no-build-isolation

pip install vllm==0.10.2 --index-url https://mirrors.huaweicloud.com/repository/pypi/simple

pip install verl==0.5.0 --index-url https://mirrors.huaweicloud.com/repository/pypi/simple

# 手动完成verl安装后可以输入如下命令进行验证
# pip install agentlightning[verl]
image-20251109165922247 image-20251109165702059
  • 安装SQL Agent基础库
pip install "langgraph<1.0" "langchain[openai]<1.0" "langchain-community" "langchain-text-splitters<1.0" "sqlparse" "nltk" --index-url https://mirrors.huaweicloud.com/repository/pypi/simple
image-20251109172859206
  • Qwen 2.5 coder模型下载:https://www.modelscope.cn/models/Qwen/Qwen2.5-Coder-1.5B-Instruct
image-20251109171227212
pip install modelscope
# cd /root/autodl-tmp

mkdir ./qwen2.5-Coder

modelscope download --model Qwen/Qwen2.5-Coder-1.5B-Instruct --local_dir ./qwen2.5-Coder
image-20251109171508721 image-20251109173055848
  • SQL数据集准备

Spider 数据集是一个大规模跨域文本到SQL数据集,专门用于训练和评估自然语言到SQL查询的转换能力。这个数据集来自耶鲁大学的研究项目,可以从 Yale LILY Spider 官方网站获取.

image-20251109193956082 image-20251109194023818

其中本项目核心要用到的是三个 Parquet 文件:

  • train_spider.parquet: 训练数据集,包含约 8,000 个样本
  • test_dev_500.parquet: 验证数据集的子集,包含 500 个样本
  • test_dev.parquet: 完整的开发/测试数据集

每个 Parquet 文件包含以下字段:

  • question: 自然语言问题(例如:“Show all concert names and their singers”)
  • db_id: 数据库标识符(例如:“concert_singer”)
  • query: 标准答案 SQL 查询(ground truth)
  • db_path: 数据库文件的相对路径
image-20251109170846756

数据集领取: 加 小助理 免费领取

image-20251109193644253

总的来说,数据集包含约 200 个不同的 SQLite 数据库,每个数据库代表一个特定的业务场景。我们需要这些数据库文件存储在 data/database/ 目录下,按照数据库名称组织。例如:

  • data/database/concert_singer/concert_singer.db - 音乐会和歌手数据库
  • data/database/college_2/college_2.db - 大学信息数据库
  • data/database/flight_2/flight_2.db - 航班信息数据库

每个数据库都是一个完整的 SQLite 文件,包含多个表和真实的业务数据。需要注意的是,

  • Spider 数据集是跨域的,意味着它包含多个不同业务领域的数据库,这使得训练出的模型具有更好的泛化能力
  • 数据集中的每个问题都有唯一的标准答案 SQL 查询,但可能存在多种等价的 SQL 写法都能得到相同的结果
  • SQLite 数据库的使用使得整个系统零配置,不需要安装和配置独立的数据库服务器
  • 在 CI 测试中,数据集的准备是自动化的,确保了测试环境的一致性

查看数据集:

# 如果要运行如下代码,建议单独创建环境,以免部分库如numpy版本和agent lightning冲突
pip install -U "pandas==2.2.2" "pyarrow==17.0.0" "fastparquet==2024.5.0" "numpy==1.26.4"

然后在Jupyter中运行如下代码,即可查看数据集基本情况:

import pandas as pd

df = pd.read_parquet("train_spider.parquet")

print("共读取样本数:", len(df))

print("字段:", list(df.columns))

print(df.head(3))
image-20251109194339194
  • wandb安装流程

  在大规模模型训练中,我们往往需要监控和分析大量的训练数据,而WandB可以帮助我们实现这一目标。它提供了以下几个重要的功能:

实时可视化:WandB可以实时展示训练过程中关键指标的变化,如损失函数、学习率、训练时间等。通过这些可视化数据,我们能够直观地了解模型的训练进展,快速发现训练中的异常或瓶颈。

自动记录与日志管理:WandB会自动记录每次实验的参数、代码、输出结果,确保实验结果的可追溯性。无论是超参数的设置,还是模型的架构调整,WandB都能够帮助我们完整保留实验记录,方便后期对比与调优。

支持中断与恢复训练:在长时间的预训练任务中,系统中断或需要暂停是常见的情况。通过WandB的checkpoint功能,我们可以随时恢复训练,从上次中断的地方继续进行,避免数据和时间的浪费。

多实验对比:当我们尝试不同的模型配置或超参数时,WandB允许我们在多个实验之间轻松进行对比分析,帮助我们选择最优的模型配置。

团队协作:WandB还支持团队协作,多个成员可以共同查看实验结果,协同调试模型。这对研究和项目开发中团队的合作非常有帮助。

wandb官网:https://wandb.ai/site

image-20241023171805985
image-20241023171908743
image-20241023172111510
image-20241023172148960
image-20241023172226625

然后即可在令行中输入如下代码安装wandb:

pip install wandb
image-20241023172349194

接下来在unsloth微调前,我们即可设置wandb进行微调记录,并可在对应网站上观察到训练过程如下:

image-20251109194754268

2.项目创建与运行流程

首先创建基本项目结构如下:

SQL-Agent-RL/
├── data/         # 存放数据集
├── model/        # 存放下载的模型权重
└── spider/       # 存放核心脚本

data/主要放置 Spider 数据集相关文件,包括:

train_spider.parquet
test_dev_500.parquet
test_dev.parquet
database/            # 每个数据库一个子目录,内含 .sqlite 和 schema.sql
image-20251109200337650

数据集领取:加 小助理 免费领取

image-20251109200001274

确保路径与脚本中保持一致,比如在 train_sql_agent.py 中的配置:

"train_files": "data/train_spider.parquet",
"val_files": "data/test_dev_500.parquet",

以及 SQLAgent 调用时会读取:

original_db_path = os.path.join(self.spider_dir, "database", task["db_id"], task["db_id"] + ".sqlite")

所以 data/database/... 子目录必须存在,否则训练时会提示数据库路径错误。

model/则是本地下载的模型权重(Qwen2.5-Coder)。 推荐结构如下:

model/
└── Qwen2.5-Coder-1.5B-Instruct/
    ├── config.json
    ├── tokenizer.json
    ├── model.safetensors
    └── ...

然后在训练脚本 train_sql_agent.py 中,把配置改为本地路径:

"actor_rollout_ref": {
    "model": {
        "path": "/root/SQL-Agent-RL/model/Qwen2.5-Coder-1.5B-Instruct",
    },
}
image-20251109200427498 image-20251109200435817

这样就不会再联网从 Hugging Face 下载了。

image-20251109195942923 image-20251109200906059

然后 spider/则用于存放项目核心代码脚本,建议包含:

spider/
├── sql_agent.py          # 基于 LangGraph 的 SQL Agent 封装
├── train_sql_agent.py    # 训练脚本(veRL + Agent-Lightning)
├── spider_eval/          # 官方提供的 SQL 评估函数
│   └── exec_eval.py
│   └── 其他各项py文件
└── __init__.py

这种结构可以让 Python 识别 spider 作为包路径,避免导入错误(如 from spider.sql_agent import LitSQLAgent)。

其中Agent创建代码sql_agent.py解释如下:

# Copyright (c) Microsoft. All rights reserved.

"""Sample code that demonstrates an SQL agent using LangGraph and LangChain,
trainable with Agent-lightning.

Adapted from https://python.langchain.com/docs/tutorials/sql_qa/
as well as https://langchain-ai.github.io/langgraph/tutorials/sql-agent/
"""

from __future__ import annotations

import os
import re
import shutil
import tempfile
import time
from typing import Any, Dict, List, Literal, Optional, cast

import pandas as pd
import termcolor
from langchain.chat_models import init_chat_model
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from spider_eval.exec_eval import eval_exec_match

import agentlightning as agl

agl.configure_logger()

logger = agl.configure_logger(name=__name__)


WRITE_QUERY_PROMPT = ChatPromptTemplate(
    [
        (
            "system",
            """
You are an agent designed to interact with a SQL database.
     Given an input question, create a syntactically correct {dialect} query to run to help find the answer.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table.

## Table Schema ##

Only use the following tables:
{table_info}

## Output Format ##

Respond in the following format:

```{dialect}
GENERATED QUERY
```
""".strip(),
        ),
        ("user", "Question: {input}"),
    ]
)


CHECK_QUERY_PROMPT = ChatPromptTemplate(
    [
        (
            "system",
            """
You are a SQL expert with a strong attention to detail.
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Explicit query execution failures
- Clearly unreasoable query execution results

## Table Schema ##

{table_info}

## Output Format ##

If any mistakes from the list above are found, list each error clearly.
After listing mistakes (if any), conclude with **ONE** of the following exact phrases in all caps and without surrounding quotes:
- If mistakes are found: `THE QUERY IS INCORRECT.`
- If no mistakes are found: `THE QUERY IS CORRECT.`

DO NOT write the corrected query in the response. You only need to report the mistakes.
""".strip(),
        ),
        (
            "user",
            """Question: {input}

Query:

```{dialect}
{query}
```

Execution result:

```
{execution}
```""",
        ),
    ]
)


REWRITE_QUERY_PROMPT = ChatPromptTemplate(
    [
        (
            "system",
            """
You are an agent designed to interact with a SQL database.
Rewrite the previous {dialect} query to fix errors based on the provided feedback.
The goal is to answer the original question.
Make sure to address all points in the feedback.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table.

## Table Schema ##

Only use the following tables:
{table_info}

## Output Format ##

Respond in the following format:

```{dialect}
REWRITTEN QUERY
```
""".strip(),
        ),
        (
            "user",
            """Question: {input}

## Previous query ##

```{dialect}
{query}
```

## Previous execution result ##

```
{execution}
```

## Feedback ##

{feedback}

Please rewrite the query to address the feedback.""",
        ),
    ]
)


class State(MessagesState):
    question: str
    query: str
    execution: str
    answer: str
    feedback: str
    num_turns: int
    messages: list[AnyMessage]


class SQLAgent:

    def __init__(
        self,
        db: str,
        max_turns: int = 5,
        debug: bool = False,
        db_schema: str | None = None,
        endpoint: str | None = None,
        verl_replacement: Dict[str, Any] | None = None,
        table_info_truncate: int = 2048,
        execution_truncate: int = 2048,
    ):
        self.db = SQLDatabase.from_uri(db)  # type: ignore
        self.db_schema = db_schema
        self.debug = debug
        self.max_turns = max_turns
        self.table_info_truncate = table_info_truncate
        self.execution_truncate = execution_truncate
        if verl_replacement is not None:
            self.model_name: str = verl_replacement["model"]  # type: ignore
            assert endpoint is not None
            self.llm = init_chat_model(
                self.model_name,
                model_provider="openai",
                openai_api_base=endpoint,
                openai_api_key=os.environ.get("OPENAI_API_KEY", "dummy"),
                temperature=verl_replacement["temperature"],
                max_retries=0,
                max_tokens=2048,
            )
        else:
            self.model_name: str = os.environ.get("MODEL", "gpt-4.1-mini")
            self.llm = init_chat_model(
                self.model_name,
                model_provider="openai",
                openai_api_base=endpoint or os.environ["OPENAI_API_BASE"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                temperature=0,
                max_retries=1,
                max_tokens=2048,
            )

    def get_table_info(self) -> str:
        """Get the table information in a human-readable format."""
        try:
            table_info = self.db.get_table_info()
            if len(table_info) > self.table_info_truncate:
                table_info = table_info[: self.table_info_truncate] + "\n... (truncated)"
            return table_info
        except Exception as e:
            logger.error(f"Failed to get table info: {e}")
            if self.db_schema:
                if len(self.db_schema) > self.table_info_truncate:
                    return self.db_schema[: self.table_info_truncate] + "\n... (truncated)"
                return self.db_schema
            return "No schema available."

    def invoke_prompt(self, prompt: Any) -> AnyMessage:
        if self.debug:
            for message in prompt.messages:
                termcolor.cprint(message.pretty_repr(), "blue")

        try:
            result = self.llm.invoke(prompt)
        except Exception as e:
            logger.error(f"Failed to invoke prompt: {e}")
            # FIXME: fallback to create a random trajectory
            result = self.llm.invoke([HumanMessage(content="Please create a random SQL query as an example.")])

        if self.debug:
            termcolor.cprint(result.pretty_repr(), "green")

        return result  # type: ignore

    def truncate_execuion(self, execution: str) -> str:
        """Truncate the execution result to a reasonable length."""
        if len(execution) > self.execution_truncate:
            return execution[: self.execution_truncate] + "\n... (truncated)"
        return execution

    def parse_query(self, message: AnyMessage) -> str | None:
        result: str | None = None
        for match in re.finditer(r".*```\w*\n(.*?)\n```.*", message.content, re.DOTALL):  # type: ignore
            result = match.group(1).strip()  # type: ignore
        return result  # type: ignore

    def write_query(self, state: State) -> State:
        """Generate SQL query to fetch information."""
        prompt: Any = WRITE_QUERY_PROMPT.invoke(  # type: ignore
            {
                "dialect": self.db.dialect,
                "input": state["question"],
                "table_info": self.get_table_info(),
            }
        )
        result = self.invoke_prompt(prompt)  # type: ignore

        query = self.parse_query(result) or result.content  # type: ignore

        return {  # type: ignore
            **state,
            "query": query,  # type: ignore
            "num_turns": 1,
            "messages": [*prompt.messages, result],
        }

    def execute_query(self, state: State) -> State:
        """Execute SQL query."""
        execute_query_tool = QuerySQLDatabaseTool(db=self.db)
        execution_result = execute_query_tool.invoke(state["query"])  # type: ignore
        if not isinstance(execution_result, str):
            # Convert to string if it's not already
            execution_result = str(execution_result)
        if self.debug:
            termcolor.cprint(execution_result, "yellow")
        return {**state, "execution": execution_result}

    def check_query(self, state: State) -> State:
        """Check the SQL query for correctness."""
        prompt: Any = CHECK_QUERY_PROMPT.invoke(  # type: ignore
            {
                "dialect": self.db.dialect,
                "input": state["question"],
                "query": state["query"],
                "execution": self.truncate_execuion(state["execution"]),
                "table_info": self.get_table_info(),
            }
        )
        result = self.invoke_prompt(prompt)  # type: ignore

        res = {  # type: ignore
            **state,
            "feedback": result.content,  # type: ignore
            "messages": [*state.get("messages", []), *prompt.messages, result],
        }
        return res  # type: ignore

    def rewrite_query(self, state: State) -> State:
        """Rewrite SQL query if necessary."""
        prompt: Any = REWRITE_QUERY_PROMPT.invoke(  # type: ignore
            {
                "dialect": self.db.dialect,
                "input": state["question"],
                "query": state["query"],
                "execution": self.truncate_execuion(state["execution"]),
                "feedback": state["feedback"],
                "table_info": self.get_table_info(),
            }
        )
        result = self.invoke_prompt(prompt)  # type: ignore

        rewritten_query = self.parse_query(result)  # type: ignore

        return {
            **state,
            "query": rewritten_query or state["query"],
            "num_turns": state.get("num_turns", 0) + 1,
            "messages": [*prompt.messages, result],  # clear previous prompts
        }

    def should_continue(self, state: State) -> Literal[END, "rewrite_query"]:  # type: ignore
        """Determine if the agent should continue based on the result."""
        if state["messages"] and isinstance(state["messages"][-1], BaseMessage):  # type: ignore
            last_message = state["messages"][-1]
            if "THE QUERY IS CORRECT" in last_message.content:  # type: ignore
                if "THE QUERY IS INCORRECT" in last_message.content:  # type: ignore
                    # Both correct and incorrect messages found
                    # See which is the last one
                    correct_index = last_message.content.rfind("THE QUERY IS CORRECT")  # type: ignore
                    incorrect_index = last_message.content.rfind("THE QUERY IS INCORRECT")  # type: ignore
                    if correct_index > incorrect_index:
                        return END
                else:
                    return END

        if state.get("num_turns", 0) >= self.max_turns:
            return END

        return "rewrite_query"

    def graph(self) -> CompiledStateGraph[State]:
        builder = StateGraph(State)
        builder.add_node(self.write_query)  # type: ignore
        builder.add_node(self.execute_query)  # type: ignore
        builder.add_node(self.check_query)  # type: ignore
        builder.add_node(self.rewrite_query)  # type: ignore

        builder.add_edge(START, "write_query")
        builder.add_edge("write_query", "execute_query")
        builder.add_edge("execute_query", "check_query")
        builder.add_conditional_edges(
            "check_query",
            self.should_continue,  # type: ignore
        )
        builder.add_edge("rewrite_query", "execute_query")

        return builder.compile()  # type: ignore


def evaluate_query(query: str, ground_truth: str, database: str, raise_on_error: bool = True) -> float:
    # TODO(yuge): Maybe we can evaluate intermediate queries and assign more precise rewards.

    # included in the original evaluation script
    # query = query.replace("value", "1")

    try:
        database = os.path.abspath(database)
        if not os.path.exists(database):
            raise FileNotFoundError(f"Database file {database} does not exist.")

        # Parameters following the default setting
        exec_score = eval_exec_match(
            db=database,
            p_str=query,
            g_str=ground_truth,
            plug_value=False,
            keep_distinct=False,
            progress_bar_for_each_datapoint=False,
        )
        if exec_score == 1:
            return 1.0
        else:
            return 0.0
    except Exception as e:
        if raise_on_error:
            raise
        else:
            logger.exception(f"Error evaluating query: {e}")
            return 0.0


class LitSQLAgent(agl.LitAgent[Dict[str, Any]]):

    def __init__(
        self,
        trained_agents: Optional[str] = r"write",
        val_temperature: Optional[float] = None,
        max_turns: int = 3,
        table_info_truncate: int = 2048,
        execution_truncate: int = 2048,
    ) -> None:
        super().__init__(trained_agents=trained_agents)
        self.val_temperature = val_temperature
        self.spider_dir = os.environ.get("VERL_SPIDER_DATA_DIR", "data")
        self.max_turns = max_turns
        self.table_info_truncate = table_info_truncate
        self.execution_truncate = execution_truncate

    def rollout(
        self,
        task: Dict[str, Any],
        resources: agl.NamedResources,
        rollout: agl.Rollout,
    ) -> float | None:
        question = task["question"]
        start_time = time.time()
        llm: agl.LLM = cast(agl.LLM, resources["main_llm"])

        if rollout.mode == "train":
            original_db_path = os.path.join(self.spider_dir, "database", task["db_id"], task["db_id"] + ".sqlite")
        else:
            original_db_path = os.path.join(self.spider_dir, "test_database", task["db_id"], task["db_id"] + ".sqlite")
        ground_truth = task["query"]

        if not os.path.exists(original_db_path):
            logger.error(f"Database {original_db_path} does not exist. Skipping.")
            return None

        schema_path = os.path.join(os.path.dirname(original_db_path), "schema.sql")
        if os.path.exists(schema_path):
            with open(schema_path, "r") as f:
                schema = f.read()
        else:
            logger.error("Schema file not found: %s", schema_path)
            schema = "No schema available."

        rollout_id = rollout.rollout_id

        with tempfile.TemporaryDirectory() as temp_dir:
            db_path = os.path.join(temp_dir, os.path.basename(original_db_path))
            shutil.copyfile(original_db_path, db_path)
            logger.info(f"[Rollout {rollout_id}] Question: {question}")
            logger.info(f"[Rollout {rollout_id}] Ground Truth: {ground_truth}")

            # Run the agent
            agent = SQLAgent(
                "sqlite:///" + db_path,
                max_turns=self.max_turns,
                table_info_truncate=self.table_info_truncate,
                execution_truncate=self.execution_truncate,
                debug=False,
                db_schema=schema,
                endpoint=llm.get_base_url(rollout.rollout_id, rollout.attempt.attempt_id),  # type: ignore
                verl_replacement=(
                    {"model": llm.model, **llm.sampling_parameters}
                    if rollout.mode == "train"
                    else {
                        "model": llm.model,
                        "temperature": (
                            self.val_temperature
                            if self.val_temperature is not None
                            else llm.sampling_parameters.get("temperature", 0.0)
                        ),
                    }
                ),
            ).graph()
            try:
                # Required to make the langchain tracing work
                handler = self.tracer.get_langchain_handler()
                result = agent.invoke(  # type: ignore
                    {"question": question},  # type: ignore
                    {"callbacks": [handler] if handler else [], "recursion_limit": 100},
                )
            except Exception as e:
                logger.exception(f"[Rollout {rollout_id}] Error during agent invocation: {e}")
                return

            logger.info(f"[Rollout {rollout_id}] Generated Query: {result['query']}")

        end_time_rollout = time.time()

        with tempfile.TemporaryDirectory() as temp_dir:
            db_path = os.path.join(temp_dir, os.path.basename(original_db_path))
            shutil.copyfile(original_db_path, db_path)

            reward = evaluate_query(result["query"], ground_truth, db_path, raise_on_error=False)
            logger.info("[Rollout %s] Reward: %s", rollout_id, reward)

        end_time_eval = time.time()

        logger.info("[Rollout %s] Time taken for rollout: %.2f seconds", rollout_id, end_time_rollout - start_time)
        logger.info(
            "[Rollout %s] Time taken for evaluation: %.2f seconds", rollout_id, end_time_eval - end_time_rollout
        )

        return reward


def debug_sql_agent():
    spider_dev_data_path = os.path.join(os.environ.get("VERL_SPIDER_DATA_DIR", "data"), "dev.parquet")
    if not os.path.exists(spider_dev_data_path):
        raise FileNotFoundError(f"Spider dev data file {spider_dev_data_path} does not exist.")
    df = pd.read_parquet(spider_dev_data_path).head(10)  # type: ignore
    df = cast(List[Dict[str, Any]], df.to_dict(orient="records"))  # type: ignore
    print("Debug data:", df)

    trainer = agl.Trainer(
        n_workers=1,
        initial_resources={
            "main_llm": agl.LLM(
                endpoint=os.environ["OPENAI_API_BASE"],
                model="gpt-4.1-nano",
                sampling_parameters={"temperature": 0.7},
            )
        },
    )
    trainer.dev(LitSQLAgent(), df)


if __name__ == "__main__":
    debug_sql_agent()

解释如下

class SQLAgent:
    def __init__(self, db: str, max_turns: int = 5, debug: bool = False, ...):
        self.db = SQLDatabase.from_uri(db)
        ...
        self.llm = init_chat_model(...)

说明:
SQLAgent 是整个智能体的主体类。
在初始化时:

  • 建立数据库连接(基于 SQLDatabase)。
  • 初始化大语言模型(LLM),支持直接连接 OpenAI 或通过 Agent-Lightning 提供的虚拟接口。
  • 支持设置最大循环次数 max_turns,以及调试模式、表结构截断长度等。
def get_table_info(self) -> str:
    try:
        table_info = self.db.get_table_info()
        ...

说明:
用于提取数据库的表结构描述。

  • 优先调用 LangChain 的 get_table_info()
  • 若失败,则使用传入的 db_schema
  • 超过长度则进行截断,防止 Prompt 太长。
def invoke_prompt(self, prompt: Any) -> AnyMessage:
    result = self.llm.invoke(prompt)
    return result

说明:
这是所有 Prompt 调用的统一入口。

  • 调用大模型生成回复;
  • 在调试模式下会以彩色打印 Prompt 和回复;
  • 若出现错误,会调用备用提示生成随机 SQL 以防中断。
def parse_query(self, message: AnyMessage) -> str | None:
    for match in re.finditer(r".*```\w*\n(.*?)\n```.*", message.content, re.DOTALL):
        result = match.group(1).strip()
    return result

说明:
从模型回复的文本中提取 SQL。

  • 使用正则表达式查找 ```代码块内的内容。
  • 提取最后一个代码块的内容作为 SQL 返回。
def write_query(self, state: State) -> State:
    prompt = WRITE_QUERY_PROMPT.invoke({...})
    result = self.invoke_prompt(prompt)
    query = self.parse_query(result)
    return {...}

说明:
这是第一个图节点,用于生成 SQL。

  • 根据输入问题和表结构生成 SQL。
  • 解析结果后更新 state,包括生成的 SQL 和消息记录。
def execute_query(self, state: State) -> State:
    execute_query_tool = QuerySQLDatabaseTool(db=self.db)
    execution_result = execute_query_tool.invoke(state["query"])
    return {**state, "execution": execution_result}

说明:
执行生成的 SQL。

  • 使用 LangChain 提供的数据库查询工具。
  • 将查询结果(或错误信息)记录在 execution 字段中。
def check_query(self, state: State) -> State:
    prompt = CHECK_QUERY_PROMPT.invoke({...})
    result = self.invoke_prompt(prompt)
    return {...}

说明:
让模型扮演 SQL 检查专家,验证查询是否合理。

  • 提供问题、SQL、执行结果和表结构作为输入。
  • 模型输出错误说明,并以固定句式结尾。
  • 结果被存入 feedback,供下一步使用。
def rewrite_query(self, state: State) -> State:
    prompt = REWRITE_QUERY_PROMPT.invoke({...})
    result = self.invoke_prompt(prompt)
    rewritten_query = self.parse_query(result)
    return {...}

说明:
当上一步检查出错误时,进入该节点。

  • 模型根据反馈修改 SQL。
  • 解析后更新 state,并将回合数 num_turns + 1
def should_continue(self, state: State) -> Literal[END, "rewrite_query"]:
    if "THE QUERY IS CORRECT" in last_message.content:
        return END
    if state.get("num_turns", 0) >= self.max_turns:
        return END
    return "rewrite_query"

说明:
该函数控制循环逻辑:

  • 如果模型认为 SQL 正确,结束流程;
  • 若达到最大修正次数,也结束;
  • 否则返回 "rewrite_query" 节点,继续修正。
def graph(self) -> CompiledStateGraph[State]:
    builder = StateGraph(State)
    builder.add_node(self.write_query)
    builder.add_node(self.execute_query)
    builder.add_node(self.check_query)
    builder.add_node(self.rewrite_query)
    ...
    return builder.compile()

说明:
在这里定义整个 LangGraph 流程:

  1. 添加各个节点;
  2. 建立节点之间的执行顺序;
  3. 指定条件跳转逻辑(check_queryrewrite_queryEND)。
    最终返回一个可执行的图。
def evaluate_query(query, ground_truth, database, raise_on_error=True) -> float:
    exec_score = eval_exec_match(...)
    return 1.0 if exec_score == 1 else 0.0

说明:
比较生成的 SQL 与标准答案在数据库中的执行结果是否一致。

  • 若完全匹配,则得分 1;否则 0。
  • 用于强化学习中的奖励计算。
class LitSQLAgent(agl.LitAgent[Dict[str, Any]]):
    def rollout(self, task, resources, rollout) -> float | None:
        ...

说明:
LitSQLAgentSQLAgent 封装为可训练单元,用于 Agent-Lightning 框架。
主要职责:

  • 从 Spider 数据集中取出问题与数据库;
  • 创建临时数据库副本;
  • 运行 SQLAgent 的 LangGraph;
  • 最终评测生成的 SQL 并返回奖励分数。
def debug_sql_agent():
    spider_dev_data_path = os.path.join(os.environ.get("VERL_SPIDER_DATA_DIR", "data"), "dev.parquet")
    df = pd.read_parquet(spider_dev_data_path).head(10)
    trainer = agl.Trainer(...)
    trainer.dev(LitSQLAgent(), df)

说明:
这是一个调试入口。

  • 读取 Spider 的样例数据;
  • 构造一个最小的 Agent-Lightning 训练器;
  • 运行 10 条数据,验证 SQLAgent 的功能。

而train_sql_agent.py代码则与我们前面讲解的 sql_agent.py 是“前后联动”的:

  • sql_agent.py 负责定义一个可运行的 SQL Agent 图(LangGraph 流程)
  • train_sql_agent.py 则负责将这个 Agent 放入 强化学习算法(VERL 框架) 中进行训练。

具体解释如下:

# Copyright (c) Microsoft. All rights reserved.

"""Train an SQL agent on the Spider dataset using Agent-lightning.

This module provides a training script for SQL agents using different model configurations.
The script supports three different training configurations:

1. 'fast' - A lightweight configuration optimized for CI testing with reduced epochs
2. 'qwen' - Standard configuration using Qwen-2.5-Coder-1.5B-Instruct model
3. 'llama' - Configuration using LLaMA-3.2-1B-Instruct model with JSON formatting
"""

说明:
这部分文档字符串(Docstring)简要说明了脚本的用途与支持的三种训练模式:

模式名称 说明
fast 快速测试模式,用于 CI 测试,训练步数极少
qwen 使用 Qwen 2.5 Coder 1.5B 模型的标准训练配置
llama 使用 LLaMA 3.2 1B 模型的训练配置(JSON 格式输出)

脚本主要用于在 Spider 文本到 SQL 数据集上,通过 Agent-Lightning 框架执行强化学习训练。

from __future__ import annotations
import argparse, os
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, Optional
import pandas as pd
from sql_agent import LitSQLAgent
import agentlightning as agl

说明:

  • argparse:用于解析命令行参数(如选择 fast/qwen/llama 模式)。
  • deepcopy:用于深拷贝默认配置,避免修改原始模板。
  • pandas:读取 Spider 数据集(存储为 .parquet 文件)。
  • LitSQLAgent:从前面的 sql_agent.py 导入的类,是被训练的智能体。
  • agentlightning(简称 AGL):微软的强化学习训练框架,封装了算法与分布式训练逻辑。
RL_TRAINING_CONFIG: Dict[str, Any] = {
    "algorithm": {
        "adv_estimator": "grpo",
        "use_kl_in_reward": False,
    },
    ...
}

说明:
这是一个全局字典,定义了强化学习训练的默认参数模板,分为以下几部分👇

3.1 算法配置 (algorithm)

"algorithm": {
    "adv_estimator": "grpo",
    "use_kl_in_reward": False,
},
  • adv_estimator: grpo 表示使用 GRPO(Generalized Recurrent Policy Optimization) 算法。
    这是微软提出的一种类似 PPO 的改进算法,更适合语言模型类任务。
  • use_kl_in_reward=False 表示不在奖励中显式引入 KL 正则项(更自由的策略更新)。

3.2 数据配置 (data)

"data": {
    "train_files": "data/train_spider.parquet",
    "val_files": "data/test_dev_500.parquet",
    "train_batch_size": 32,
    "max_prompt_length": 4096,
    "max_response_length": 2048,
    "truncation": "error",
},

这部分指定了训练与验证数据路径,以及输入输出的最大长度约束。

  • Spider 数据集以 .parquet 格式保存,包含自然语言问题与对应 SQL。
  • max_prompt_lengthmax_response_length 控制上下文长度。
  • truncation="error" 表示当超出长度时直接报错。

3.3 模型与并行配置 (actor_rollout_ref)

这一节定义了三个角色:

  • rollout:生成经验(即执行 agent,产生数据);
  • actor:负责策略优化(即更新模型参数);
  • ref:参考模型(稳定更新、提供 baseline)。

同时,还定义了底层使用的模型路径与推理引擎。

"actor_rollout_ref": {
    "rollout": {
        "tensor_model_parallel_size": 1,
        "n": 4,
        "log_prob_micro_batch_size_per_gpu": 4,
        "multi_turn": {"format": "hermes"},
        "name": "vllm",
        "gpu_memory_utilization": 0.8,
        "engine_kwargs": {
            "vllm": {
                "enable_auto_tool_choice": True,
                "tool_call_parser": "hermes",
            }
        },
    },
    "actor": {
        "ppo_mini_batch_size": 32,
        ...
    },
    "ref": {
        "log_prob_micro_batch_size_per_gpu": 8,
        ...
    },
    "model": {
        "path": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        "use_remove_padding": True,
        "enable_gradient_checkpointing": True,
    },
},

解释要点:

  • rollout 段指定执行模型时的推理参数,底层使用 vLLM 引擎进行高效推理;
  • multi_turn.format = "hermes" 指多轮对话采用 Hermes 格式(LangChain 支持的标准消息格式);
  • actor 段控制优化过程(批大小、学习率、剪切比例等 PPO 超参);
  • model.path 指定要训练的模型,例如 Qwen 2.5 Coder 1.5B。

3.4 训练控制 (trainer)

"trainer": {
    "n_gpus_per_node": 1,
    "val_before_train": True,
    "logger": ["console", "wandb"],
    "project_name": "AgentLightning",
    "experiment_name": "spider",
    "total_epochs": 2,
}
  • 指定训练资源与日志系统(如 wandb 追踪)。
  • total_epochs=2 表示只训练两个 epoch。
  • 其他参数如 test_freq, critic_warmup, nnodes 等控制分布式行为。
def config_train_fast() -> Dict[str, Any]:
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    EXPERIMENT_NAME = f"spider_{timestamp}"
    PROJECT_NAME = "AgentLightningCI"
    ...
    config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
    config["trainer"]["total_epochs"] = 1
    ...
    return config

说明:
用于快速测试或**持续集成(CI)**场景。
特征包括:

  • 自动生成带时间戳的实验名;
  • 减小 GPU 内存占用(gpu_memory_utilization=0.6);
  • 使用更小的模型(Qwen2.5-Coder-0.5B);
  • 降低训练 epoch 数量与步数;
  • 输出项目名与实验名方便追踪。
def config_train_qwen() -> Dict[str, Any]:
    config = deepcopy(RL_TRAINING_CONFIG)
    return config

说明:
此函数直接返回默认模板,用于常规的 Qwen 1.5B 模型训练。
适合标准 GPU 环境(如单卡 A100 / A800 / L40S 等)。

def train(config: Dict[str, Any], active_agent: Optional[str]) -> None:
    agent = LitSQLAgent()
    algorithm = agl.VERL(config)
    trainer = agl.Trainer(n_runners=10, algorithm=algorithm, adapter={"agent_match": active_agent})
    ...
    trainer.fit(agent, train_dataset=train_data, val_dataset=val_data)

说明:
这是训练主逻辑,步骤如下👇

  1. 实例化智能体
    agent = LitSQLAgent() 即我们之前在 sql_agent.py 中定义的 SQL Agent。
  2. 加载强化学习算法
    agl.VERL(config) 初始化强化学习算法(基于 GRPO)。
  3. 创建训练器 Trainer
    agl.Trainer(...) 管理整个训练流程。
    参数:
    • n_runners=10:开启 10 个并发执行线程;
    • adapter={"agent_match": active_agent}:可指定要替换的代理模块(如某个特定节点)。
  4. 加载数据集
    使用 pandas 读取 Spider 数据集(训练/验证),并转换为字典列表形式。
  5. 启动训练
    trainer.fit(agent, train_dataset, val_dataset) 正式进入训练过程。
    Agent-Lightning 会自动完成:
    • rollout 数据生成(执行 agent);
    • reward 计算;
    • 策略更新;
    • 验证评估与日志记录。
def main() -> None:
    parser = argparse.ArgumentParser(...)
    parser.add_argument("config", choices=["fast", "qwen", "llama"], ...)
    parser.add_argument("--active-agent", type=str, ...)
    args = parser.parse_args()
    ...
    train(config, active_agent)

说明:
main() 负责解析命令行参数并启动训练。
使用方式例如:

python train_sql_agent.py qwen

或:

python train_sql_agent.py fast

可选参数 --active-agent 用于指定要训练的特定 agent 名称(若不提供则自动匹配)。

3. SQL Agent运行与调用流程

  • 创建.env文件
image-20251109201559874 image-20251109201617557
  • 启动vLLM服务
cd "/root/autodl-tmp/SQL Agent强化学习训练/model"

vllm serve ./Qwen2.5-Coder-1.5B-Instruct \
  --host 0.0.0.0 \
  --port 8000 \
  --max-model-len 4096 \
  --dtype bfloat16

image-20251109202138910

image-20251109202231669
  • 运行SQL Agent
cd root/autodl-tmp/SQL Agent强化学习训练/spider
# export $(grep -v '^#' .env | xargs)
python sql_agent.py
image-20251109203015868

而运行如下代码则会开始进行训练:

python train_sql_agent.py qwen
image-20251109203032992

实际运行效果如下所示:

image-20251109203247586
  • 最终运行指标:
image-20251109203316105 image-20251109203326793

4. 实验成果说明

​ 在本节中,我们将对 SQL-Agent 强化学习微调的最终实验结果进行系统性总结与分析。本次实验以 Qwen2.5-Coder 系列模型为核心,通过 Agent-Lightning 框架配合 veRL 的 GRPO 强化学习算法,对基于 LangGraph 构建的 SQL-Agent 进行了全流程的 Agentic RL 训练与评估。以下结论基于多组训练实验的真实数据统计与性能表现。

4.1 强化学习显著提升模型性能

​ 实验结果表明,经过 RL 训练后,所有模型的 SQL 生成准确率均较初始状态显著提升。尤其在 Qwen2.5-Coder-3B 模型上,经过 GRPO 训练后,在三轮推理交互(Three Turns)设置下,最终准确率达到 80.4%;即便在单轮推理(One Turn)下,也能保持 80.2% 的高准确率。这充分说明了强化学习能够有效提升 Agent 的策略能力,使模型不仅“生成得出答案”,更“生成得更正确”

4.2 上下文长度对性能的直接影响

​ 实验对比发现,上下文长度(context length) 的提升对模型性能具有明显正向作用。以 Qwen2.5-Coder-3B 为例,当上下文从 2048 tokens 扩展至 4096 tokens 时,三轮推理下的最终准确率由 76.4% 提升至 80.4%,单轮推理下的准确率则从 73.2% 提升至 80.2%。这一变化反映出更长的上下文窗口能够帮助模型在多轮 SQL 生成与反馈中更好地保持逻辑一致性,减少语义丢失。

4.3 交互轮次的边际收益

​ 从交互次数(turns)维度来看,更多的推理轮次在部分设置下确实能带来性能提升,但提升幅度有限。在 2048 context 下,Qwen2.5-Coder-3B 从单轮到三轮的准确率从 73.2% 提升到 76.4%; 然而当上下文提升到 4096 时,单轮与三轮的准确率几乎持平(80.2% vs 80.4%)。这说明在更强大的模型与更充足的上下文环境下,模型在单轮推理中已经能够完成自我校正与最优生成。

4.4 显式“检查”机制的收益与代价

​ 实验中增加了一个“显式检查(check)”训练步骤,即让 Agent 在执行 SQL 生成后主动验证并修正结果。这一机制确实带来小幅性能提升 —— 以 Qwen2.5-Coder-3B 为例,准确率从 76.4% 提升至 77.6%
​ 但代价是训练时间几乎翻倍,更新周期明显增加。这表明,虽然显式检查能增强模型的自我纠错能力,但其训练成本较高,不适合轻量化场景。

4. 实验成果说明

​ 在本节中,我们将对 SQL-Agent 强化学习微调的最终实验结果进行系统性总结与分析。本次实验以 Qwen2.5-Coder 系列模型为核心,通过 Agent-Lightning 框架配合 veRL 的 GRPO 强化学习算法,对基于 LangGraph 构建的 SQL-Agent 进行了全流程的 Agentic RL 训练与评估。以下结论基于多组训练实验的真实数据统计与性能表现。

4.1 强化学习显著提升模型性能

​ 实验结果表明,经过 RL 训练后,所有模型的 SQL 生成准确率均较初始状态显著提升。尤其在 Qwen2.5-Coder-3B 模型上,经过 GRPO 训练后,在三轮推理交互(Three Turns)设置下,最终准确率达到 80.4%;即便在单轮推理(One Turn)下,也能保持 80.2% 的高准确率。这充分说明了强化学习能够有效提升 Agent 的策略能力,使模型不仅“生成得出答案”,更“生成得更正确”

4.2 上下文长度对性能的直接影响

​ 实验对比发现,上下文长度(context length) 的提升对模型性能具有明显正向作用。以 Qwen2.5-Coder-3B 为例,当上下文从 2048 tokens 扩展至 4096 tokens 时,三轮推理下的最终准确率由 76.4% 提升至 80.4%,单轮推理下的准确率则从 73.2% 提升至 80.2%。这一变化反映出更长的上下文窗口能够帮助模型在多轮 SQL 生成与反馈中更好地保持逻辑一致性,减少语义丢失。

4.3 交互轮次的边际收益

​ 从交互次数(turns)维度来看,更多的推理轮次在部分设置下确实能带来性能提升,但提升幅度有限。在 2048 context 下,Qwen2.5-Coder-3B 从单轮到三轮的准确率从 73.2% 提升到 76.4%; 然而当上下文提升到 4096 时,单轮与三轮的准确率几乎持平(80.2% vs 80.4%)。这说明在更强大的模型与更充足的上下文环境下,模型在单轮推理中已经能够完成自我校正与最优生成。

4.4 显式“检查”机制的收益与代价

​ 实验中增加了一个“显式检查(check)”训练步骤,即让 Agent 在执行 SQL 生成后主动验证并修正结果。这一机制确实带来小幅性能提升 —— 以 Qwen2.5-Coder-3B 为例,准确率从 76.4% 提升至 77.6%
​ 但代价是训练时间几乎翻倍,更新周期明显增加。这表明,虽然显式检查能增强模型的自我纠错能力,但其训练成本较高,不适合轻量化场景。

详细的 Agent RL 原理及入门实战代码 加入 赋范空间 免费领取

image-20251109193644253

还有更多免费的前沿技术解读、Agent开发实战等资源等你来拿~

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐