强化学习的实验往往复杂、参数众多、结果不稳定。而 Stable-Baselines3(简称 SB3) 的出现,正是为了让这些过程标准化、模块化,使强化学习的研究与工程实现都能遵循统一规范、可重复复现。

本文将系统讲解 SB3 的结构设计与使用规范,帮助你从“能跑通代码”到“理解它的逻辑”。文章内容只聚焦于 SB3 的核心规范与基本用法,下篇文章会涉及 ProjectAirSim 环境的封装与控制。

一、环境规范:Gymnasium 的标准接口

SB3 要求所有环境(Env)遵循 Gymnasium 的 API 规范。Gymnasium 是 Gym 的继任者,两者的接口几乎一致,但在返回值和命名上略有更新(Gymnasium 的 reset() 返回 (obs, info),step() 返回 (obs, reward, terminated, truncated, info))。

环境的核心逻辑包括四个部分:

  1. 初始化(__init__):定义 action_space 与 observation_space。这是 SB3 判断输入输出合法性的关键。如果你的动作空间是离散的,用 gym.spaces.Discrete(n);如果是连续的,用 gym.spaces.Box(low, high, shape, dtype)。

  2. 重置(reset()):在每个 episode 开始时调用,返回一个观测向量 obs 和一个字典 info。SB3 会根据它来初始化策略的输入。

  3. 执行(step(action)):环境的核心循环。输入一个动作 action,返回新的观测 obs、奖励 reward、终止标志 terminated、截断标志 truncated 以及 info。

  4. 结束与清理:当 terminated 或 truncated 为真时,一个 episode 结束,SB3 内部会自动重置环境并开始下一轮。

下面是一个最小可运行的自定义环境示例,它符合 SB3 的所有规范:

import gymnasium as gym
import numpy as np

class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        obs = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        info = {}
        return obs, info

    def step(self, action):
        reward = 1 if action == 0 else -1
        terminated = False
        truncated = False
        info = {}
        obs = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        return obs, reward, terminated, truncated, info

这段代码虽然简单,但它体现了 SB3 训练环境最核心的规范:空间定义明确、数据类型统一、返回结构标准化。如果这些规范不被遵守,SB3 将无法自动判断训练状态或记录日志。

Stable-Baselines3 (SB3) 提供的一个非常实用的环境验证工具,我们可以使用check_env(env) 用来验证自定义的Gym环境是否符合 Gym / Gymnasium 的接口规范,确保它能被 Stable-Baselines3 等强化学习算法(PPO、SAC、TD3 等)正确使用。

在你实现完环境类(例如 CustomEnv)之后,加上这几行:

from stable_baselines3.common.env_checker import check_env 
from your_env_file import CustomEnvEnv 

env = CustomEnv() 
check_env(env)

它会检查的内容包括:

检查项目

说明

reset() 方法

是否返回 (obs, info) 二元组

step(action) 方法

是否返回 (obs, reward, done, truncated, info) 五元组

observation_space / action_space

是否为 gym.spaces 对象

obs 是否匹配 observation_space

检查维度、类型、范围

action 是否匹配 action_space

检查取值范围是否有效

done 和 truncated

是否为布尔类型

reward

是否为数值类型

info

是否为字典类型

环境可重复 reset

多次调用 reset() 是否能正常运行

环境可 step

连续调用 step() 是否报错

二、模型定义与训练流程

SB3 中的每个算法(如 PPO、A2C、SAC)本质上是一个强化学习训练器,它会自动管理策略网络、价值网络、优化器、回放缓冲区、梯度更新与日志输出。你只需提供一个符合规范的环境实例,就可以像使用一个高层 API 一样开始训练。例如:

from stable_baselines3 import PPO 

env = CustomEnv() 
model = PPO("MlpPolicy", env, verbose=1) 
model.learn(total_timesteps=10000) 
model.save("ppo_custom_env")

SB3 的策略名(如 "MlpPolicy", "CnnPolicy")决定了模型结构。如果环境观测是向量,就用 "MlpPolicy";如果是图像,则用 "CnnPolicy"。而算法对象(如 PPO、SAC)只控制训练方式,不影响输入结构。

三、保存与加载:复现训练的关键

强化学习实验往往需要多次中断与继续训练。SB3 的 save() 和 load() 方法正是为此设计的。每个模型对象(如 PPO、SAC)都可以直接保存当前权重、超参数与优化状态。例如:

model.save("ppo_checkpoint") 
model = PPO.load("ppo_checkpoint")

四、相关连接

AirSim:https://github.com/microsoft/AirSim.git

Project AirSim:https://github.com/iamaisim/ProjectAirSim.git

为了方便使用,我在 Project AirSim 基础上略作修改:https://github.com/QinCheng0928/UAV_UGV_Navigation_ProjectAirSim.git

本文仅为个人学习与理解笔记,水平有限,如有错误或理解偏差,欢迎评论区指正,轻喷。

希望这篇文章能帮到正在研究 Project AirSim 的同学!

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐