About ML, DL and RL - 为什么在超参数优化(HPO)里要 剪枝(pruning)?超参数优化过程中的剪枝什么意思?适用场景是什么?
pruner=
·
为什么在超参数优化(HPO)里要 剪枝(pruning)?
现象 | 如果不剪枝会怎样? | 引入剪枝后有什么好处? |
---|---|---|
时间浪费 | 每一次 trial 都完整训练到 num_epochs ,哪怕它在第 2 个 epoch 就已经明显表现最差。 |
只让“有希望”的 trial 继续,把 GPU/CPU 留给潜力更大的组合,整体搜索更快终止。 |
资源浪费 | 训练早期即已落后的模型仍然会继续前向/反向传播,浪费显存、电费。 | 释放显存、减少 I/O,能在同样预算里尝试更多 trial。 |
统计效率低 | 真正有用的结果被大量无价值、长时间运行的 trial 淹没,拖慢找到最优超参数的速度。 | 更快积累“好 trial”的反馈,优化器能够及时调整搜索方向(例如在贝叶斯优化中更新后验)。 |
一句话总结:剪枝 = 提前终止明显无望的试验,用更少成本让搜索收敛得更快更好。
Optuna 中剪枝的工作原理
-
中期指标上报
trial.report(metric_value, step)
这一步告诉 Optuna:“在
step
(你这里是 epoch)时,这个 trial 的验证损失(或准确率)是多少”。 -
比较规则(由
Pruner
决定)
例如MedianPruner
:- 先等待
n_startup_trials
个完整 trial(默认 5 个)收集“基线”中期指标。 - 之后的 trial 若在同一步骤的指标 劣于 所有已完成 trial 指标的中位数,就会被判定为“不值得继续”。
- 先等待
-
触发剪枝
if trial.should_prune(): raise optuna.TrialPruned()
- 这句抛出
TrialPruned
异常,让 Optuna 把该 trial 标记为 PRUNED 并立即结束训练循环。 - 训练脚本立刻返回,GPU 资源被下一次 trial 复用。
- 这句抛出
为什么“光定义 pruner=
还不够”?
study = optuna.create_study(pruner=MedianPruner(...))
只是告诉 Optuna “我想用剪枝”。- 真正的触发点 必须在 objective 函数里显式上报指标并检查
should_prune()
。 - 如果漏掉这两行,Optuna 看不到任何中期指标 → 认为“没有可剪枝信息” → 全部 trial 都会跑满
num_epochs
,等于白定义了Pruner
。
与早停(early stopping)有何区别?
剪枝 | 早停 | |
---|---|---|
目标 | 跨 trial:淘汰整体表现落后的 trial | 单个模型 内部:防止过拟合、训练震荡 |
决策依据 | 与 其它 trial 的同阶段表现比较 | 与自身历史最优指标比较 |
触发后 | 整个 trial 结束,换下一个超参数组合 | 结束当前 trial 的剩余 epoch,但保留模型权重 |
在实际工程里,两者通常 同时使用:
- 剪枝负责把“注定不会成功”的 trial 砍掉;
- 早停则帮“有希望的 trial”在最佳 epoch 定格,防止过拟合。
代码落地示例
def objective(trial):
# 1) 读出/采样超参数
lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
...
# 2) 训练循环
for epoch in range(num_epochs):
train_loss = train_one_epoch(model, ...)
val_loss = evaluate(model, ...)
# --- 剪枝逻辑 begin ---
trial.report(val_loss, epoch) # a. 上报
if trial.should_prune(): # b. 判断
raise optuna.TrialPruned() # c. 触发
# --- 剪枝逻辑 end ---
return val_loss # 3) 返回最终指标(供排序)
-
何时选择衡量指标?
- 常见:验证集损失(越低越好)或准确率 / F1(越高越好)。
- 指标要 单调,才能与
MedianPruner
直接比较。
-
n_startup_trials
设多大?- 数据集、模型噪声越大,建议多收集些完整 trial 做基线(例如 10~20)。
-
能否自定义更智能的 Pruner?
- 可以:
ThreshholdPruner
、PercentilePruner
、或实现BasePruner
子类,按业务规则剪枝。
- 可以:
何时不需要剪枝?
- 训练极快(每个 trial 几秒内完成)——剪枝节省的时间有限。
- trial 数很少(例如只搜 5 种超参数)——提前终止也看不出明显收益。
- 指标波动极大(早期无法预测最终好坏)——可能误砍掉后期反超的 trial。
- 这时可改用“温和”的
MedianPruner(n_warmup_steps=...)
或调大n_startup_trials
。
- 这时可改用“温和”的
TL;DR
剪枝 = HPO 的“早停”,用来节省时间与算力。
Optuna 只有在 objective 里report
+should_prune
+TrialPruned()
才会真正执行剪枝;否则定义再多的Pruner
也只是摆设。
当触发剪枝的时候的代码执行逻辑
问题:
请看下面的代码,我这样写,当触发剪枝之后,所记录的指标还会返回值吗?
# 5️⃣ 评估
val_loss, v_r2, v_mse, v_rmse, v_mae, _, _ = evaluate(model, val_loader, loss_fn, device)
# --- 剪枝逻辑 begin ---
trial.report(val_loss, epoch) # a. 上报
if trial.should_prune(): # b. 判断
raise optuna.TrialPruned() # c. 触发
# --- 剪枝逻辑 end ---
# 记录所有指标
trial.set_user_attr("train_loss", train_loss)
trial.set_user_attr("val_loss", val_loss)
trial.set_user_attr("val_r2", v_r2)
trial.set_user_attr("val_rmse", v_rmse)
trial.set_user_attr("val_mae", v_mae)
分析:
不会。
一旦执行到
if trial.should_prune():
raise optuna.TrialPruned()
Python 立刻抛出 TrialPruned
异常,函数随即提前结束,后面的 return …
或任何语句都不会再运行。因此:
你关心的对象 | 剪枝发生后会怎样? |
---|---|
objective() 的返回值 |
没有返回值(Python 视为 None ),Optuna 用异常来区分 PRUNED trial,而不是靠 return 。 |
val_loss |
在你 trial.report(val_loss, epoch) 的那一刻已经被写进 trial.intermediate_values ,Optuna 会把 这一轮(触发剪枝的 epoch) 的 val_loss 作为该 trial 最后一次可用的中期指标保存。 |
其它指标 (v_r2 , v_mse …) |
只存在于局部变量;除非你在 raise 之前就调用过 trial.set_user_attr(...) 、写到日志文件或 TensorBoard,否则 Optuna 不会自动帮你保存它们。 |
之后的代码 (trial.set_user_attr(...) 等) |
完全不会执行——因为异常已经中断流程。 |
所以:
- 不会有“返回”;Optuna 通过捕获
TrialPruned
来结束当前 trial。- 能保留下来的只有你已经
report()
过的中期指标(例如val_loss
)以及在抛异常前你主动持久化的任何信息。- 如果还想留存
v_r2
等辅助指标,需要在raise
之前手动trial.set_user_attr("v_r2", v_r2)
,或者写到你自己的日志系统里。
时间点示意
epoch = 7
↓ 1. 计算指标
val_loss = ...
v_r2 = ...
↓ 2. trial.report(val_loss, 7)
intermediate_values[7] = val_loss ← Optuna 现在知道这一步的指标
↓ 3. if trial.should_prune(): True
↓ 4. raise TrialPruned() ← 这里函数立即终止
# ---- 后面的代码不会再执行 ----
因此,“最后一次保存的数据” 就是第 7 个 epoch(也就是触发剪枝的那一轮)的 val_loss
;其它变量除非提前保存,否则随函数结束而丢弃。
更多推荐
所有评论(0)