为什么在超参数优化(HPO)里要 剪枝(pruning)

现象 如果不剪枝会怎样? 引入剪枝后有什么好处?
时间浪费 每一次 trial 都完整训练到 num_epochs,哪怕它在第 2 个 epoch 就已经明显表现最差。 只让“有希望”的 trial 继续,把 GPU/CPU 留给潜力更大的组合,整体搜索更快终止。
资源浪费 训练早期即已落后的模型仍然会继续前向/反向传播,浪费显存、电费。 释放显存、减少 I/O,能在同样预算里尝试更多 trial。
统计效率低 真正有用的结果被大量无价值、长时间运行的 trial 淹没,拖慢找到最优超参数的速度。 更快积累“好 trial”的反馈,优化器能够及时调整搜索方向(例如在贝叶斯优化中更新后验)。

一句话总结:剪枝 = 提前终止明显无望的试验,用更少成本让搜索收敛得更快更好。


Optuna 中剪枝的工作原理

  1. 中期指标上报

    trial.report(metric_value, step)
    

    这一步告诉 Optuna:“在 step(你这里是 epoch)时,这个 trial 的验证损失(或准确率)是多少”。

  2. 比较规则(由 Pruner 决定)
    例如 MedianPruner

    • 先等待 n_startup_trials 个完整 trial(默认 5 个)收集“基线”中期指标。
    • 之后的 trial 若在同一步骤的指标 劣于 所有已完成 trial 指标的中位数,就会被判定为“不值得继续”。
  3. 触发剪枝

    if trial.should_prune():
        raise optuna.TrialPruned()
    
    • 这句抛出 TrialPruned 异常,让 Optuna 把该 trial 标记为 PRUNED 并立即结束训练循环。
    • 训练脚本立刻返回,GPU 资源被下一次 trial 复用。

为什么“光定义 pruner= 还不够”?

  • study = optuna.create_study(pruner=MedianPruner(...)) 只是告诉 Optuna “我想用剪枝”
  • 真正的触发点 必须在 objective 函数里显式上报指标并检查 should_prune()
  • 如果漏掉这两行,Optuna 看不到任何中期指标 → 认为“没有可剪枝信息” → 全部 trial 都会跑满 num_epochs,等于白定义了 Pruner

与早停(early stopping)有何区别?

剪枝 早停
目标 跨 trial:淘汰整体表现落后的 trial 单个模型 内部:防止过拟合、训练震荡
决策依据 其它 trial 的同阶段表现比较 与自身历史最优指标比较
触发后 整个 trial 结束,换下一个超参数组合 结束当前 trial 的剩余 epoch,但保留模型权重

在实际工程里,两者通常 同时使用

  • 剪枝负责把“注定不会成功”的 trial 砍掉;
  • 早停则帮“有希望的 trial”在最佳 epoch 定格,防止过拟合。

代码落地示例

def objective(trial):
    # 1) 读出/采样超参数
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    ...

    # 2) 训练循环
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, ...)
        val_loss   = evaluate(model, ...)

        # --- 剪枝逻辑 begin ---
        trial.report(val_loss, epoch)        # a. 上报
        if trial.should_prune():             # b. 判断
            raise optuna.TrialPruned()       # c. 触发
        # --- 剪枝逻辑 end ---

    return val_loss  # 3) 返回最终指标(供排序)
  • 何时选择衡量指标?

    • 常见:验证集损失(越低越好)或准确率 / F1(越高越好)。
    • 指标要 单调,才能与 MedianPruner 直接比较。
  • n_startup_trials 设多大?

    • 数据集、模型噪声越大,建议多收集些完整 trial 做基线(例如 10~20)。
  • 能否自定义更智能的 Pruner?

    • 可以:ThreshholdPrunerPercentilePruner、或实现 BasePruner 子类,按业务规则剪枝。

何时不需要剪枝?

  • 训练极快(每个 trial 几秒内完成)——剪枝节省的时间有限。
  • trial 数很少(例如只搜 5 种超参数)——提前终止也看不出明显收益。
  • 指标波动极大(早期无法预测最终好坏)——可能误砍掉后期反超的 trial。
    • 这时可改用“温和”的 MedianPruner(n_warmup_steps=...) 或调大 n_startup_trials

TL;DR

剪枝 = HPO 的“早停”,用来节省时间与算力
Optuna 只有在 objectivereport + should_prune + TrialPruned() 才会真正执行剪枝;否则定义再多的 Pruner 也只是摆设。

当触发剪枝的时候的代码执行逻辑

问题:

请看下面的代码,我这样写,当触发剪枝之后,所记录的指标还会返回值吗?


    # 5️⃣ 评估
    val_loss,  v_r2, v_mse,  v_rmse,  v_mae,  _, _ = evaluate(model, val_loader,  loss_fn, device)

    # --- 剪枝逻辑 begin ---
    trial.report(val_loss, epoch)        # a. 上报
    if trial.should_prune():             # b. 判断
        raise optuna.TrialPruned()       # c. 触发
    # --- 剪枝逻辑 end ---

    # 记录所有指标
    trial.set_user_attr("train_loss", train_loss)
    trial.set_user_attr("val_loss",   val_loss)
    trial.set_user_attr("val_r2",     v_r2)
    trial.set_user_attr("val_rmse",   v_rmse)
    trial.set_user_attr("val_mae",    v_mae)

分析:

不会。
一旦执行到

if trial.should_prune():
    raise optuna.TrialPruned()

Python 立刻抛出 TrialPruned 异常,函数随即提前结束,后面的 return … 或任何语句都不会再运行。因此:

你关心的对象 剪枝发生后会怎样?
objective() 的返回值 没有返回值(Python 视为 None),Optuna 用异常来区分 PRUNED trial,而不是靠 return
val_loss 在你 trial.report(val_loss, epoch) 的那一刻已经被写进 trial.intermediate_values,Optuna 会把 这一轮(触发剪枝的 epoch)val_loss 作为该 trial 最后一次可用的中期指标保存。
其它指标 (v_r2, v_mse …) 只存在于局部变量;除非你在 raise 之前就调用过 trial.set_user_attr(...)、写到日志文件或 TensorBoard,否则 Optuna 不会自动帮你保存它们。
之后的代码 (trial.set_user_attr(...) 等) 完全不会执行——因为异常已经中断流程。

所以:

  • 不会有“返回”;Optuna 通过捕获 TrialPruned 来结束当前 trial。
  • 能保留下来的只有你已经 report() 过的中期指标(例如 val_loss)以及在抛异常前你主动持久化的任何信息。
  • 如果还想留存 v_r2 等辅助指标,需要在 raise 之前手动 trial.set_user_attr("v_r2", v_r2),或者写到你自己的日志系统里。

时间点示意

epoch = 7
           ↓ 1. 计算指标
val_loss = ...
v_r2  = ...

           ↓ 2. trial.report(val_loss, 7)
intermediate_values[7] = val_loss   ← Optuna 现在知道这一步的指标

           ↓ 3. if trial.should_prune(): True
           ↓ 4. raise TrialPruned()      ← 这里函数立即终止
# ---- 后面的代码不会再执行 ----

因此,“最后一次保存的数据” 就是第 7 个 epoch(也就是触发剪枝的那一轮)的 val_loss;其它变量除非提前保存,否则随函数结束而丢弃。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐