Project AirSim简介(2):自定义ProjectAirSim环境基础
强化学习的实验往往复杂、参数众多、结果不稳定。而 Stable-Baselines3(简称 SB3) 的出现,正是为了让这些过程标准化、模块化,使强化学习的研究与工程实现都能遵循统一规范、可重复复现。
本文将系统讲解 SB3 的结构设计与使用规范,帮助你从“能跑通代码”到“理解它的逻辑”。文章内容只聚焦于 SB3 的核心规范与基本用法,下篇文章会涉及 ProjectAirSim 环境的封装与控制。
一、环境规范:Gymnasium 的标准接口
SB3 要求所有环境(Env)遵循 Gymnasium 的 API 规范。Gymnasium 是 Gym 的继任者,两者的接口几乎一致,但在返回值和命名上略有更新(Gymnasium 的 reset() 返回 (obs, info),step() 返回 (obs, reward, terminated, truncated, info))。
环境的核心逻辑包括四个部分:
-
初始化(__init__):定义 action_space 与 observation_space。这是 SB3 判断输入输出合法性的关键。如果你的动作空间是离散的,用 gym.spaces.Discrete(n);如果是连续的,用 gym.spaces.Box(low, high, shape, dtype)。
-
重置(reset()):在每个 episode 开始时调用,返回一个观测向量 obs 和一个字典 info。SB3 会根据它来初始化策略的输入。
-
执行(step(action)):环境的核心循环。输入一个动作 action,返回新的观测 obs、奖励 reward、终止标志 terminated、截断标志 truncated 以及 info。
-
结束与清理:当 terminated 或 truncated 为真时,一个 episode 结束,SB3 内部会自动重置环境并开始下一轮。
下面是一个最小可运行的自定义环境示例,它符合 SB3 的所有规范:
import gymnasium as gym
import numpy as np
class CustomEnv(gym.Env):
def __init__(self):
super().__init__()
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32)
def reset(self, *, seed=None, options=None):
obs = np.array([0.0, 0.0, 0.0], dtype=np.float32)
info = {}
return obs, info
def step(self, action):
reward = 1 if action == 0 else -1
terminated = False
truncated = False
info = {}
obs = np.array([0.0, 0.0, 0.0], dtype=np.float32)
return obs, reward, terminated, truncated, info
这段代码虽然简单,但它体现了 SB3 训练环境最核心的规范:空间定义明确、数据类型统一、返回结构标准化。如果这些规范不被遵守,SB3 将无法自动判断训练状态或记录日志。
Stable-Baselines3 (SB3) 提供的一个非常实用的环境验证工具,我们可以使用check_env(env) 用来验证自定义的Gym环境是否符合 Gym / Gymnasium 的接口规范,确保它能被 Stable-Baselines3 等强化学习算法(PPO、SAC、TD3 等)正确使用。
在你实现完环境类(例如 CustomEnv)之后,加上这几行:
from stable_baselines3.common.env_checker import check_env
from your_env_file import CustomEnvEnv
env = CustomEnv()
check_env(env)
它会检查的内容包括:
|
检查项目 |
说明 |
|---|---|
|
reset() 方法 |
是否返回 (obs, info) 二元组 |
|
step(action) 方法 |
是否返回 (obs, reward, done, truncated, info) 五元组 |
|
observation_space / action_space |
是否为 gym.spaces 对象 |
|
obs 是否匹配 observation_space |
检查维度、类型、范围 |
|
action 是否匹配 action_space |
检查取值范围是否有效 |
|
done 和 truncated |
是否为布尔类型 |
|
reward |
是否为数值类型 |
|
info |
是否为字典类型 |
|
环境可重复 reset |
多次调用 reset() 是否能正常运行 |
|
环境可 step |
连续调用 step() 是否报错 |
二、模型定义与训练流程
SB3 中的每个算法(如 PPO、A2C、SAC)本质上是一个强化学习训练器,它会自动管理策略网络、价值网络、优化器、回放缓冲区、梯度更新与日志输出。你只需提供一个符合规范的环境实例,就可以像使用一个高层 API 一样开始训练。例如:
from stable_baselines3 import PPO
env = CustomEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
model.save("ppo_custom_env")
SB3 的策略名(如 "MlpPolicy", "CnnPolicy")决定了模型结构。如果环境观测是向量,就用 "MlpPolicy";如果是图像,则用 "CnnPolicy"。而算法对象(如 PPO、SAC)只控制训练方式,不影响输入结构。
三、保存与加载:复现训练的关键
强化学习实验往往需要多次中断与继续训练。SB3 的 save() 和 load() 方法正是为此设计的。每个模型对象(如 PPO、SAC)都可以直接保存当前权重、超参数与优化状态。例如:
model.save("ppo_checkpoint")
model = PPO.load("ppo_checkpoint")
四、相关连接
AirSim:https://github.com/microsoft/AirSim.git
Project AirSim:https://github.com/iamaisim/ProjectAirSim.git
为了方便使用,我在 Project AirSim 基础上略作修改:https://github.com/QinCheng0928/UAV_UGV_Navigation_ProjectAirSim.git
本文仅为个人学习与理解笔记,水平有限,如有错误或理解偏差,欢迎评论区指正,轻喷。
希望这篇文章能帮到正在研究 Project AirSim 的同学!
更多推荐


所有评论(0)