问:这个流程图是介绍lgb的goss的,请介绍下goss的具体计算流程


答:

这张图非常清晰地展示了LightGBM中GOSS(Gradient-based One-Side Sampling)算法的核心流程。这是一个用于加速梯度提升模型训练的革命性采样方法。

下面我将根据流程图,为您逐步解析GOSS的具体计算流程。

GOSS算法的核心思想

GOSS算法的基本思想非常直观:​在每一轮构建新子树时,不是使用全部数据来计算梯度,而是对训练样本进行智能采样,保留那些梯度绝对值大的样本(即模型目前还“搞不懂”的困难样本),并随机抽样少量梯度绝对值小的样本(即模型已经“学得比较好”的简单样本),同时通过对小梯度样本进行加权来弥补采样带来的数据分布偏差。​

这样做的好处是,训练效率大幅提升,同时由于保住了信息量最大的样本,模型精度损失很小。


GOSS具体计算流程分步解析

以下流程对应伪代码中的每一步:

第一步:初始化参数与模型集合

  • 输入​:
    • I: 训练数据集。
    • d: 迭代次数(即要生成的弱学习器/决策树的数量)。
    • a: 大梯度数据的采样比率。
    • b: 小梯度数据的采样比率。
    • loss: 损失函数。
    • L: 弱学习器(通常是决策树)。
  • 初始化​:
    • models ← {}:创建一个空集合,用于存储后续训练好的所有弱模型。
    • fact ← (1-a)/b:计算一个权重系数。这个系数至关重要,用于后续对小梯度样本进行加权,以补偿采样后数据分布的变化。

第二步:计算采样数量

  • topN ← a × len(I):计算需要保留的大梯度样本的数量。例如,有10000个样本,a=0.1,则topN=1000。
  • randN ← b × len(I):计算需要随机采样的小梯度样本的数量。例如,b=0.5,则randN=500。

第三步:开始迭代训练(for i=1 to d)​
这是算法的核心循环,每一轮迭代训练一棵新的树。

  1. 计算当前预测值与梯度​:

    • preds ← models.predict(I):使用当前已训练的所有模型(models集合)对全部训练数据 I 进行预测,得到当前模型的集成预测值。
    • g ← loss(I, preds):根据损失函数计算每个样本的负梯度(negative gradient)​。在梯度提升中,负梯度指示了当前模型下每个样本需要被修正的方向和强度,其绝对值大小代表了样本的“信息量”或“难度”。绝对值越大,说明模型对这个样本的预测误差越大,越需要关注。
  2. 样本排序与采样​:

    • sorted ← GetSortedIndices(abs(g)):根据梯度绝对值 abs(g) 从大到小对所有样本的索引进行排序。这样,列表开头的样本就是“信息量”最大的样本。
    • topSet ← sorted[1:topN]:直接选取排序后列表前 topN 个样本,组成大梯度样本集
    • randSet ← RandomPick(sorted[topN:len(I)], randN):从剩余的小梯度样本(sorted[topN:len(I)])中,​随机抽取 randN 个样本,组成小梯度样本集
  3. 合并样本集并分配权重​:

    • usedSet ← topSet + randSet:将上述两个样本集合并,形成本轮迭代最终用于训练的数据集
    • w[randSet] ×= fact:这是GOSS算法的关键一步!为所有被随机抽中的小梯度样本的权重乘以之前计算的系数 fact。例如,如果a=0.1, b=0.5,则fact=(1-0.1)/0.5=1.8。这意味着小梯度样本在计算损失时的“影响力”被增大了,以此来弥补我们没有使用全部小梯度样本所造成的数据分布偏差,确保训练过程的准确性。
  4. 训练新弱学习器并更新模型​:

    • newModel ← L(I[usedSet], -g[usedSet], w[usedSet]):使用合并后的样本集 I[usedSet]、这些样本对应的负梯度 -g[usedSet] 以及调整后的样本权重 w[usedSet] 来训练一个新的弱学习器(决策树)。
    • models.append(newModel):将新训练好的弱学习器加入到模型集合中。

第四步:循环结束与输出

  • 当循环执行完 d 次后,算法结束。最终得到的 models 集合就是训练好的Gradient Boosting模型,可用于对新数据进行预测。

总结

GOSS算法通过​“保留全部大梯度样本 + 随机采样部分小梯度样本 + 对小梯度样本加权”​​ 的策略,巧妙地实现了训练速度的大幅提升,同时基本保持了模型的精度。这使得LightGBM在处理大规模数据集时具有显著的速度优势,成为许多数据科学竞赛和工业应用中的首选工具。


相关文章:

1、算法篇(LightGBM理论介绍)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐