大模型pretrain蒸馏训练中(不是sft),如何利用教师模型对学生模型进行蒸馏训练?从蒸馏数据的生产到学生模型的训练,给出详细的步骤

下面用“老师带学生做完形填空”的比喻,把 Pretrain 阶段的蒸馏(不是 SFT)蒸馏数据生产学生模型训练 讲清楚。你可以把它当成一条“流水线”。


0. 先搞清楚:Pretrain 蒸馏到底在学什么?

预训练(pretrain)的大模型核心任务通常是 预测下一个 token(字/词片段)
蒸馏就是:学生不只学“标准答案”(语料里的真实下一个 token),还学 老师的“解题思路”——也就是老师对“可能的下一个 token”的一整套概率分布(软标签)。

  • 硬标签(hard label):语料里真实的下一个 token(就像题目标准答案)
  • 软标签(soft label):老师模型对所有候选 token 的概率(像老师说:A 也行但更偏 B,C 也合理但概率小)

1. 准备阶段:确定老师、学生、词表与数据

1.1 选老师模型(Teacher)

  • 大、强、效果好(比如更大参数、更好训练)
  • 能稳定推理(你要让它跑很多数据)

1.2 选学生模型(Student)

  • 更小、更快、你想部署的那个
  • 一般会用类似的结构(Transformer)

1.3 强烈建议:老师和学生用同一个 tokenizer/词表

原因很简单:蒸馏时老师给的是“每个 token 的概率”,如果词表不一样就很难对齐。

  • 同词表:直接对齐蒸馏概率 ✅
  • 不同词表:要做复杂映射(麻烦、误差大)⚠️

1.4 准备预训练语料(不需要人工标注)

来源可以是:

  • 大规模网页文本、书籍、百科、代码、论文等
  • 你也可以做“混合配方”(例如 70% 通用文本 + 30% 领域文本)

2. 蒸馏数据怎么“生产”?(离线 or 在线)

蒸馏数据生产的核心:把老师对每个位置的预测分布算出来并保存

有两种常见方式:

2.1 在线蒸馏(Online Distillation)

训练学生时,每个 batch 都顺便把同样输入丢给老师算一次概率。

  • 优点:不需要提前存数据,永远“最新”
  • 缺点:极其耗算力(老师一直要跑),大规模 pretrain 通常很贵

2.2 离线蒸馏(Offline Distillation)✅(更常用)

先让老师把海量语料“过一遍”,把 soft label 存成蒸馏数据;学生训练时只读文件。

  • 优点:训练学生时不需要老师,成本低很多
  • 缺点:要花一次性成本做蒸馏数据

下面重点讲 离线蒸馏 的完整步骤(工业界很常见)。


3. 离线蒸馏数据生产:一步步做

Step 3.1 清洗与切片(把文本切成训练样本)

  1. 文本清洗:去乱码、重复、太短/太长、明显广告等
  2. 分词:用 tokenizer 把文本变成 token 序列
  3. 切成固定长度块(例如长度 2048):
  • 输入:前面的 token
  • 目标:每个位置的“下一个 token”

Step 3.2 用老师模型跑推理,拿到 logits / 概率

对每个序列位置 (t),老师会输出一串 logits(对词表里每个 token 一个分数)。
把 logits 变成概率一般用 softmax(可加温度系数 (T)):

pteacher(i)=ezi/T∑jezj/T p_{\text{teacher}}(i)=\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} pteacher(i)=jezj/Tezi/T

  • (z_i):老师对第 (i) 个 token 的 logit
  • (T):温度(常见 (T>1) 让分布更“软”,能传递更多信息)

Step 3.3 只存 top-k(关键省钱技巧)

词表可能 50k~200k,如果每个位置存全分布,数据量爆炸。
所以通常只存 top-k(比如 k=32/64/128):

对每个位置保存:

  • top-k token 的 id 列表
  • 对应的概率或 log-prob
  • (可选)保存温度 (T)、序列 id、位置等

并对 top-k 做重归一化(只在 top-k 内当作分布):

p~(i)=p(i)∑j∈top-kp(j) \tilde{p}(i)=\frac{p(i)}{\sum_{j\in \text{top-}k} p(j)} p~(i)=jtop-kp(j)p(i)

直觉:老师说“最可能的 64 个答案分别多大概率”,剩下的都太小可以忽略。

Step 3.4 蒸馏数据文件格式(简单理解即可)

一条样本通常包含:

  • input_ids:输入 token 序列
  • teacher_topk_ids:每个位置的 top-k token
  • teacher_topk_probs:每个位置的 top-k 概率(或 log 概率)
  • attention_mask(如果需要)

为了训练速度,会做:

  • 分片(shard)成很多文件
  • 压缩、量化(例如把概率存 float16 或更低精度)

4. 学生模型怎么训练?(损失函数怎么写)

训练学生时,你会同时让学生做两件事:

  1. 像正常预训练一样学真实下一个 token(硬标签)
  2. 模仿老师的概率分布(软标签,蒸馏)

4.1 硬标签损失:交叉熵(普通预训练)

对真实下一个 token (y):

LCE=−log⁡pstudent(y) L_{\text{CE}}=-\log p_{\text{student}}(y) LCE=logpstudent(y)

4.2 软标签损失:KL 散度(蒸馏核心)

让学生分布接近老师分布:

LKD=T2⋅KL(pteacher(T),∣,pstudent(T)) L_{\text{KD}}=T^2 \cdot \mathrm{KL}\left(p_{\text{teacher}}^{(T)} ,|, p_{\text{student}}^{(T)}\right) LKD=T2KL(pteacher(T),,pstudent(T))

其中 (p^{(T)}) 表示用温度 (T) 计算出来的分布。
乘 (T^2) 是经典做法,让梯度尺度更合适(常见经验)。

如果你只存了 top-k,就在 top-k 上算近似 KL(并用上面重归一化后的 (\tilde{p}))。

4.3 总损失:两者加权

L=αLCE+(1−α)LKD L=\alpha L_{\text{CE}} + (1-\alpha)L_{\text{KD}} L=αLCE+(1α)LKD

  • (\alpha) 常见 0.2~0.8:

    • (\alpha) 大:更像普通预训练
    • (\alpha) 小:更“听老师的”

高中生直觉:你既做“课后习题”(真实语料答案),也做“老师讲解题”(老师概率分布),按比例安排时间。


5. 训练流程(从数据到跑起来)

Step 5.1 初始化学生模型

  • 随机初始化 ✅
  • 或者用一个小模型 checkpoint 继续训练 ✅(更稳、更快收敛)

Step 5.2 读入数据

每个 batch 读:

  • input_ids(学生输入)
  • labels(真实下一个 token,用于 CE)
  • teacher_topk_ids/probs(用于 KD)

Step 5.3 前向计算

  1. 学生输出 logits
  2. 计算 (L_{\text{CE}})
  3. 用 top-k 抽取学生在这些 token 上的概率,计算 (L_{\text{KD}})
  4. 合成总损失 (L)

Step 5.4 反向传播与更新

  • 常规优化器(AdamW 等)
  • 学习率 warmup + 衰减
  • 混合精度(fp16/bf16)提升速度
  • 梯度累积(显存不够就用)

Step 5.5 评估与监控(很重要)

至少看:

  • perplexity(困惑度):越低一般越好
  • 蒸馏对齐情况:学生分布是否接近老师
  • 泛化:在不同域文本上是否都好用

6. 常见“升级版”蒸馏(可选,但你先知道名字)

  1. 中间层蒸馏(Feature Distillation)
    不仅对齐输出分布,也对齐隐藏层表示(像让学生模仿老师的“草稿纸”)
  2. 多老师蒸馏
    不同老师擅长不同领域,混合其分布
  3. 数据再生成(Synthetic Data)
    让老师生成额外文本补充训练(注意质量控制与去重复)

7. 最小可行方案(你照着就能做的版本)

如果你想要一个“最能落地、最不绕”的配方:

  1. 老师学生同 tokenizer

  2. 拿你的预训练语料切块

  3. 离线跑老师,保存每个位置 top-64 概率(温度 (T=2))

  4. 学生训练用:

    • (\alpha=0.5)
    • 总损失 (L=\alpha L_{\text{CE}}+(1-\alpha)L_{\text{KD}})
  5. 先小规模验证(比如 1B tokens),再扩大


如果你愿意,我也可以给你一份“伪代码级”的训练循环(不绑定具体框架),把 top-k 版 KL 怎么算写得更像程序,让你更容易实现。你更想用 PyTorch 风格还是更抽象的流程图风格?

详细解释上述 teacher_topk_ids、teacher_topk_probs 的数据格式,假设k = 3, seq_len=10, vocab_size=30,给出详细的数据演示,不要有省略号

下面把 teacher_topk_idsteacher_topk_probs数据格式讲清楚,并按你给定的参数 k=3、seq_len=10、vocab_size=30 给出完整、无省略号的数据演示。


1) 这两个张量到底表示什么?

在语言模型预训练里,我们有一段 token 序列:

  • input_ids:长度为 seq_len 的 token id 序列(每个 id 都在 [0, vocab_size-1] 内)
  • labels:训练目标(通常是 input_ids 左移一位)

教师模型(teacher)在每个位置都会输出对“下一个 token”的概率分布(在全词表 30 个 token 上)。

但全分布太大,所以只保存每个位置概率最高的 top-k = 3 个 token:

  • teacher_topk_ids[t]:第 t 个位置,teacher 认为最可能的 3 个 token 的 id
  • teacher_topk_probs[t]:对应这 3 个 token 的 概率

2) 形状(shape)是什么?

单条样本(batch_size=1)最常见的形状是:

  • teacher_topk_ids: [(seq_len), (k)] = [10, 3],整数类型(int16/int32)
  • teacher_topk_probs: [(seq_len), (k)] = [10, 3],浮点类型(float16/float32)

如果带 batch(比如 batch_size=B):

  • teacher_topk_ids: [B, 10, 3]
  • teacher_topk_probs: [B, 10, 3]

3) 概率是否一定加起来等于 1?

两种常见存法:

A. Top-k 内重归一化(更常见,也更方便)

每个位置只在 top-3 里当作一个“小分布”,所以每行 3 个概率之和为 1:

KaTeX parse error: Expected 'EOF', got '_' at position 30: …} \text{teacher_̲topk_probs}[t][…

B. 保存原始 softmax 概率(未重归一化)

top-3 的概率和会小于 1,因为还有很多概率被丢掉了。

下面我给你的演示采用 A(重归一化),因为更直观。


4) 完整数据演示(k=3, seq_len=10, vocab_size=30)

4.1 一条样本的 input_idslabels

设词表大小 vocab_size=30,token id 范围是 0~29

输入:

{
  "vocab_size": 30,
  "seq_len": 10,
  "k": 3,
  "input_ids": [12, 5, 7, 3, 18, 21, 9, 2, 14, 6],
  "labels":    [ 5, 7, 3, 18, 21,  9, 2, 14,  6, -100],
  "attention_mask": [1,1,1,1,1,1,1,1,1,1]
}

解释一下 labels

  • 语言模型训练通常是“预测下一个 token”,所以 labels[t] = input_ids[t+1]
  • 最后一个位置没有“下一个 token”,通常用 -100 这种 ignore_index 表示不参与损失计算

4.2 对应的 teacher_topk_ids(形状 [10,3])

teacher_topk_ids[t] 表示在位置 t,teacher 对 labels[t] 这个“下一个 token”的预测中,最可能的 3 个 token id。

{
  "teacher_topk_ids": [
    [ 5,  8,  1],
    [ 7,  4,  9],
    [ 3, 10,  6],
    [18, 17, 11],
    [21, 22, 19],
    [ 9, 13,  8],
    [ 2, 15,  1],
    [14, 16, 12],
    [ 6, 20,  7],
    [ 4,  6,  5]
  ]
}

注意:

  • 每个数字都在 0~29 内,符合 vocab_size=30
  • 每行正好 3 个 id(因为 k=3)
  • 这里我让前 9 行都包含真实标签(比如第 0 行包含 label=5),便于你理解;真实情况不一定总包含

4.3 对应的 teacher_topk_probs(形状 [10,3])

每一行的 3 个概率与 teacher_topk_ids 一一对应,并且(按我们采用的存法)每行加起来等于 1。

{
  "teacher_topk_probs": [
    [0.62, 0.23, 0.15],
    [0.55, 0.30, 0.15],
    [0.50, 0.27, 0.23],
    [0.58, 0.22, 0.20],
    [0.47, 0.32, 0.21],
    [0.51, 0.28, 0.21],
    [0.66, 0.19, 0.15],
    [0.44, 0.33, 0.23],
    [0.49, 0.26, 0.25],
    [0.40, 0.35, 0.25]
  ]
}

举例解释第 0 个位置(t=0):

  • input_ids[0] = 12(输入看到 token 12)

  • 目标 labels[0] = 5(希望预测下一个 token 是 5)

  • teacher 认为 top-3 最可能的是:

    • token 5:概率 0.62
    • token 8:概率 0.23
    • token 1:概率 0.15

5) 把它们放在一起看(逐位置对齐)

下面是把 labels[t]teacher_topk_ids[t]teacher_topk_probs[t] 放一行里对齐:

t label[t] teacher_topk_ids[t] teacher_topk_probs[t]
0 5 [5, 8, 1] [0.62, 0.23, 0.15]
1 7 [7, 4, 9] [0.55, 0.30, 0.15]
2 3 [3, 10, 6] [0.50, 0.27, 0.23]
3 18 [18, 17, 11] [0.58, 0.22, 0.20]
4 21 [21, 22, 19] [0.47, 0.32, 0.21]
5 9 [9, 13, 8] [0.51, 0.28, 0.21]
6 2 [2, 15, 1] [0.66, 0.19, 0.15]
7 14 [14, 16, 12] [0.44, 0.33, 0.23]
8 6 [6, 20, 7] [0.49, 0.26, 0.25]
9 -100 [4, 6, 5] [0.40, 0.35, 0.25]

最后一行(t=9):

  • label[9] = -100 表示不算损失
  • teacher_topk_*[9] 仍然可以存在(有些实现为了对齐 shape 直接填了一个“无用但合法”的 top-k)

6) 实际训练时怎么用这两项(你理解格式就够用)

学生模型会输出对全词表(30 个 token)的 logits。蒸馏时只取老师给的 top-3 那几个 token,对学生的概率做比较即可(省算力、省显存)。

如果使用温度 (T),学生在 top-k 上的分布(示意):

pstudent(T)(i)=ezi/T∑j∈top-kezj/T p_{\text{student}}^{(T)}(i)=\frac{e^{z_i/T}}{\sum_{j\in \text{top-}k} e^{z_j/T}} pstudent(T)(i)=jtop-kezj/Tezi/T

然后让这个分布尽量接近 teacher_topk_probs[t]


如果你希望我把“这一条样本在训练时如何计算 KD loss(top-k 版 KL)”也用同样的 k=3, seq_len=10, vocab=30 做一个完整数值演示(同样不省略),我也可以继续往下写。

把“这一条样本在训练时如何计算 KD loss(top-k 版 KL)”也用同样的 k=3, seq_len=10, vocab=30 做一个完整数值演示(同样不省略)

下面用你前面那一条样本(k=3, seq_len=10, vocab_size=30)做一次完整的 top-k 版 KD loss(KL)数值演示。我会把每一步需要的数字都写出来,并且不使用省略号


0) 这次演示的约定

  1. 温度取 T=2T=2T=2
  2. teacher_topk_probs 视为:已经用同样温度得到,并且在 top-k 内重归一化后的概率(每行 3 个数相加等于 1)
  3. labels[9] = -100 表示最后一个位置不计入 loss,所以只对 t=0 到 t=8 共 9 个位置计算 KD loss
  4. 采用自然对数 ln⁡\lnln

1) 需要的公式(top-k 版)

1.1 学生在 top-k 上的概率(只对那 3 个 token 做 softmax)

对某个位置 t,teacher 给出 top-k token 的顺序是 teacher_topk_ids[t],学生取出这 3 个 token 的 logits(按同样顺序),得到三个数 z1,z2,z3z_1,z_2,z_3z1,z2,z3。然后:

qj=exp⁡(zj/T)exp⁡(z1/T)+exp⁡(z2/T)+exp⁡(z3/T) q_j=\frac{\exp(z_j/T)}{\exp(z_1/T)+\exp(z_2/T)+\exp(z_3/T)} qj=exp(z1/T)+exp(z2/T)+exp(z3/T)exp(zj/T)

这里 qjq_jqj 就是学生在 top-k 上的概率分布。

1.2 top-k KL(老师分布 p 到学生分布 q)

KL(p∣q)=∑j=13pjln⁡(pjqj) \mathrm{KL}(p|q)=\sum_{j=1}^{3} p_j \ln\left(\frac{p_j}{q_j}\right) KL(pq)=j=13pjln(qjpj)

1.3 KD loss(常见写法带 T2T^2T2

如果对有效位置做平均:
LKD=T2⋅19∑t=08KLt L_{\mathrm{KD}}=T^2 \cdot \frac{1}{9}\sum_{t=0}^{8}\mathrm{KL}_t LKD=T291t=08KLt


2) 先演示一次「从 vocab_size=30 的 logits 里 gather 出 top-k logits」

t=0 为例,假设学生对全词表 30 个 token 的 logits 是下面这个长度为 30 的数组(下标就是 token id):

{
  "student_logits_vocab30_t0": [
    -1.2,  0.7, -0.5,  0.1, -0.3,  2.1, -0.8,  0.4,  1.2, -0.6,
     0.2, -1.0, -0.2,  0.5, -0.7,  0.0, -0.9,  0.3, -0.4,  0.6,
    -1.1, -0.1,  0.8, -0.2,  0.9, -0.5,  0.1, -0.6,  0.2, -0.3
  ]
}

t=0 的 teacher top-k 是:

  • teacher_topk_ids[0] = [5, 8, 1]

所以我们按顺序取出学生 logits:

  • id=5 的 logit 是 2.1
  • id=8 的 logit 是 1.2
  • id=1 的 logit 是 0.7

得到:

  • student_topk_logits[0] = [2.1, 1.2, 0.7]

后面每个位置都同理:先从 30 维 logits 中按 top-k ids 取出 3 个数,再在这 3 个数上做 softmax 与 KL。


3) 本条样本的 teacher top-k(t=0 到 t=8)

我沿用你上一问的 teacher 数据:

  • t=0 ids [5, 8, 1] probs [0.62, 0.23, 0.15]
  • t=1 ids [7, 4, 9] probs [0.55, 0.30, 0.15]
  • t=2 ids [3, 10, 6] probs [0.50, 0.27, 0.23]
  • t=3 ids [18, 17, 11] probs [0.58, 0.22, 0.20]
  • t=4 ids [21, 22, 19] probs [0.47, 0.32, 0.21]
  • t=5 ids [9, 13, 8] probs [0.51, 0.28, 0.21]
  • t=6 ids [2, 15, 1] probs [0.66, 0.19, 0.15]
  • t=7 ids [14, 16, 12] probs [0.44, 0.33, 0.23]
  • t=8 ids [6, 20, 7] probs [0.49, 0.26, 0.25]

并假设学生在每个位置取出的 top-k logits(按 teacher_topk_ids 的顺序)是:

  • t=0 logits [2.1, 1.2, 0.7]
  • t=1 logits [1.8, 1.1, 0.3]
  • t=2 logits [1.4, 0.9, 0.6]
  • t=3 logits [2.0, 1.0, 0.8]
  • t=4 logits [1.5, 1.3, 0.9]
  • t=5 logits [1.6, 1.0, 0.4]
  • t=6 logits [2.2, 0.9, 0.5]
  • t=7 logits [1.2, 1.0, 0.6]
  • t=8 logits [1.7, 1.0, 0.9]

4) 逐位置完整计算(t=0 到 t=8)

下面每个位置都按同样顺序算:
缩放 logits(除以 T) → 指数 → 归一化得 q → 逐项算 KL


t=0

  • teacher_topk_ids[0] = [5, 8, 1]
  • p = [0.620000, 0.230000, 0.150000]
  • 学生 logits = [2.100000, 1.200000, 0.700000]
  • 缩放 logits(除以 T=2) = [1.050000, 0.600000, 0.350000]
  • exp(缩放) = [2.857651, 1.822119, 1.419068]
  • sumexp = 6.098837
  • 学生 q = [0.468557, 0.298765, 0.232678]

逐项 KL:

  • j=1 (id=5): ln(p/q)=ln(0.620000/0.468557)=0.280062,p*ln(p/q)=0.173639

  • j=2 (id=8): ln(p/q)=ln(0.230000/0.298765)=-0.261578,p*ln(p/q)=-0.060163

  • j=3 (id=1): ln(p/q)=ln(0.150000/0.232678)=-0.439022,p*ln(p/q)=-0.065853

  • KL_0 = 0.047623


t=1

  • teacher_topk_ids[1] = [7, 4, 9]
  • p = [0.550000, 0.300000, 0.150000]
  • 学生 logits = [1.800000, 1.100000, 0.300000]
  • 缩放 logits = [0.900000, 0.550000, 0.150000]
  • exp(缩放) = [2.459603, 1.733253, 1.161834]
  • sumexp = 5.354690
  • 学生 q = [0.459336, 0.323689, 0.216975]

逐项 KL:

  • j=1 (id=7): ln(p/q)=ln(0.550000/0.459336)=0.179845,p*ln(p/q)=0.098915

  • j=2 (id=4): ln(p/q)=ln(0.300000/0.323689)=-0.076117,p*ln(p/q)=-0.022835

  • j=3 (id=9): ln(p/q)=ln(0.150000/0.216975)=-0.368314,p*ln(p/q)=-0.055247

  • KL_1 = 0.020903


t=2

  • teacher_topk_ids[2] = [3, 10, 6]
  • p = [0.500000, 0.270000, 0.230000]
  • 学生 logits = [1.400000, 0.900000, 0.600000]
  • 缩放 logits = [0.700000, 0.450000, 0.300000]
  • exp(缩放) = [2.013753, 1.568312, 1.349859]
  • sumexp = 4.931924
  • 学生 q = [0.408336, 0.317965, 0.273699]

逐项 KL:

  • j=1 (id=3): ln(p/q)=ln(0.500000/0.408336)=0.202058,p*ln(p/q)=0.101029

  • j=2 (id=10): ln(p/q)=ln(0.270000/0.317965)=-0.163826,p*ln(p/q)=-0.044233

  • j=3 (id=6): ln(p/q)=ln(0.230000/0.273699)=-0.173504,p*ln(p/q)=-0.039686

  • KL_2 = 0.017110


t=3

  • teacher_topk_ids[3] = [18, 17, 11]
  • p = [0.580000, 0.220000, 0.200000]
  • 学生 logits = [2.000000, 1.000000, 0.800000]
  • 缩放 logits = [1.000000, 0.500000, 0.400000]
  • exp(缩放) = [2.718282, 1.648721, 1.491825]
  • sumexp = 5.858828
  • 学生 q = [0.463974, 0.281408, 0.254618]

逐项 KL:

  • j=1 (id=18): ln(p/q)=ln(0.580000/0.463974)=0.223812,p*ln(p/q)=0.129811

  • j=2 (id=17): ln(p/q)=ln(0.220000/0.281408)=-0.246260,p*ln(p/q)=-0.054177

  • j=3 (id=11): ln(p/q)=ln(0.200000/0.254618)=-0.241588,p*ln(p/q)=-0.048318

  • KL_3 = 0.027012


t=4

  • teacher_topk_ids[4] = [21, 22, 19]
  • p = [0.470000, 0.320000, 0.210000]
  • 学生 logits = [1.500000, 1.300000, 0.900000]
  • 缩放 logits = [0.750000, 0.650000, 0.450000]
  • exp(缩放) = [2.117000, 1.915541, 1.568312]
  • sumexp = 5.600853
  • 学生 q = [0.377793, 0.342016, 0.280191]

逐项 KL:

  • j=1 (id=21): ln(p/q)=ln(0.470000/0.377793)=0.217017,p*ln(p/q)=0.101998

  • j=2 (id=22): ln(p/q)=ln(0.320000/0.342016)=-0.066495,p*ln(p/q)=-0.021278

  • j=3 (id=19): ln(p/q)=ln(0.210000/0.280191)=-0.288806,p*ln(p/q)=-0.060017

  • KL_4 = 0.020703


t=5

  • teacher_topk_ids[5] = [9, 13, 8]
  • p = [0.510000, 0.280000, 0.210000]
  • 学生 logits = [1.600000, 1.000000, 0.400000]
  • 缩放 logits = [0.800000, 0.500000, 0.200000]
  • exp(缩放) = [2.225541, 1.648721, 1.221403]
  • sumexp = 5.095665
  • 学生 q = [0.436751, 0.323544, 0.239705]

逐项 KL:

  • j=1 (id=9): ln(p/q)=ln(0.510000/0.436751)=0.154753,p*ln(p/q)=0.078924

  • j=2 (id=13): ln(p/q)=ln(0.280000/0.323544)=-0.144525,p*ln(p/q)=-0.040467

  • j=3 (id=8): ln(p/q)=ln(0.210000/0.239705)=-0.132181,p*ln(p/q)=-0.027758

  • KL_5 = 0.010818


t=6

  • teacher_topk_ids[6] = [2, 15, 1]
  • p = [0.660000, 0.190000, 0.150000]
  • 学生 logits = [2.200000, 0.900000, 0.500000]
  • 缩放 logits = [1.100000, 0.450000, 0.250000]
  • exp(缩放) = [3.004166, 1.568312, 1.284025]
  • sumexp = 5.856503
  • 学生 q = [0.512979, 0.267796, 0.219225]

逐项 KL:

  • j=1 (id=2): ln(p/q)=ln(0.660000/0.512979)=0.251564,p*ln(p/q)=0.166033

  • j=2 (id=15): ln(p/q)=ln(0.190000/0.267796)=-0.343130,p*ln(p/q)=-0.065195

  • j=3 (id=1): ln(p/q)=ln(0.150000/0.219225)=-0.379624,p*ln(p/q)=-0.056632

  • KL_6 = 0.044206


t=7

  • teacher_topk_ids[7] = [14, 16, 12]
  • p = [0.440000, 0.330000, 0.230000]
  • 学生 logits = [1.200000, 1.000000, 0.600000]
  • 缩放 logits = [0.600000, 0.500000, 0.300000]
  • exp(缩放) = [1.822119, 1.648721, 1.349859]
  • sumexp = 4.820699
  • 学生 q = [0.377963, 0.342000, 0.280037]

逐项 KL:

  • j=1 (id=14): ln(p/q)=ln(0.440000/0.377963)=0.151690,p*ln(p/q)=0.066744

  • j=2 (id=16): ln(p/q)=ln(0.330000/0.342000)=-0.035703,p*ln(p/q)=-0.011782

  • j=3 (id=12): ln(p/q)=ln(0.230000/0.280037)=-0.196330,p*ln(p/q)=-0.045159

  • KL_7 = 0.009803


t=8

  • teacher_topk_ids[8] = [6, 20, 7]
  • p = [0.490000, 0.260000, 0.250000]
  • 学生 logits = [1.700000, 1.000000, 0.900000]
  • 缩放 logits = [0.850000, 0.500000, 0.450000]
  • exp(缩放) = [2.339647, 1.648721, 1.568312]
  • sumexp = 5.556680
  • 学生 q = [0.421114, 0.296696, 0.282190]

逐项 KL:

  • j=1 (id=6): ln(p/q)=ln(0.490000/0.421114)=0.151608,p*ln(p/q)=0.074288

  • j=2 (id=20): ln(p/q)=ln(0.260000/0.296696)=-0.132139,p*ln(p/q)=-0.034356

  • j=3 (id=7): ln(p/q)=ln(0.250000/0.282190)=-0.121143,p*ln(p/q)=-0.030284

  • KL_8 = 0.009647


5) 汇总:平均 KL 与最终 KD loss

9 个有效位置的 KL 值(t=0 到 t=8):

  • KL_0 = 0.047623
  • KL_1 = 0.020903
  • KL_2 = 0.017110
  • KL_3 = 0.027012
  • KL_4 = 0.020703
  • KL_5 = 0.010818
  • KL_6 = 0.044206
  • KL_7 = 0.009803
  • KL_8 = 0.009647

求和:

  • ∑t=08KLt=0.207824\sum_{t=0}^{8} \mathrm{KL}_t = 0.207824t=08KLt=0.207824

求平均(除以 9):

  • 19∑t=08KLt=0.023092\frac{1}{9}\sum_{t=0}^{8}\mathrm{KL}_t = 0.02309291t=08KLt=0.023092

带上 T2=4T^2=4T2=4 的 KD loss(按平均写法):
LKD=4×0.023092=0.092366 L_{\mathrm{KD}}=4 \times 0.023092 = 0.092366 LKD=4×0.023092=0.092366

补充:如果你的实现不是“平均”,而是“直接求和”,那就是:
LKD,sum=4×0.207824=0.831297 L_{\mathrm{KD,sum}}=4 \times 0.207824 = 0.831297 LKD,sum=4×0.207824=0.831297


如果你希望,我可以把这套计算再对应到一段非常接近 PyTorch 的伪代码(包括如何用 gather 取 top-k logits、如何 mask 掉 -100 的位置),并把每一步张量的 shape 都写清楚。

继续分析,基于上述计算KL散度的同时,一般业界还需要计算真实label的loss吗?详细分析原因

一般情况下:业界做 pretrain 蒸馏时,除了算 KD 的 KL 散度,仍然会同时计算真实 label 的 loss(也就是普通预训练的交叉熵 CE),然后把两者加权混合训练。

最常见的总损失长这样(你前面也见过):

L=αLCE+(1−α),T2⋅KL(pteacher(T)∣pstudent(T)) L=\alpha L_{\text{CE}} + (1-\alpha),T^2 \cdot \mathrm{KL}(p_{\text{teacher}}^{(T)}|p_{\text{student}}^{(T)}) L=αLCE+(1α),T2KL(pteacher(T)pstudent(T))

下面说清楚“为什么还要算真实 label 的 CE”,以及“什么时候可能不算”。


为什么通常还要算真实 label 的 CE loss?

1) 老师不是“真理”,CE 是把学生拴在真实数据上的“地面”

KD 让学生模仿老师,但老师也会犯错、也会带偏见(比如对某些罕见词/新词判断差)。
如果只学老师,学生会把老师的错误也学得很牢。CE 用真实语料里的下一个 token 做监督,相当于告诉学生:

“你可以参考老师的思路,但最终要对真实数据负责。”

这在老师并非完美、或蒸馏数据覆盖面不够时尤其重要。


2) 你用的是 top-k 蒸馏:KD 只教了 3 个候选,CE 才能教“剩下的 27 个”

你当前是 top-k 版 KD(每个位置只存 k=3 个 token 的概率)。这会带来一个硬限制:

  • KD 只会对这 3 个 token 产生直接的训练信号
  • 词表是 30(真实情况是 50k+),绝大多数 token 在 KD 里“从未被提及”

即使你在 top-k 上做了 softmax 得到 qqq,并算 KL:

KL(p∣q)=∑j=13pjln⁡pjqj \mathrm{KL}(p|q)=\sum_{j=1}^3 p_j\ln\frac{p_j}{q_j} KL(pq)=j=13pjlnqjpj

它也只约束 top-3 的相对关系。
CE 则会让学生必须把真实 label 的概率抬高,从而给到“全词表”更健康的学习信号(尤其对长尾 token)。


3) top-k 里可能不包含真实 label:只做 KD 会“学不到正确答案”

在你的演示数据里,很多行 top-k 恰好包含了真实 label,这是为了好理解;但真实训练里经常发生:

  • 真实 label 是个低频 token
  • 老师没有把它排进 top-k(比如排到第 80 名)
  • 于是 top-k 蒸馏文件里根本没记录它

这时如果只做 KD,学生会被迫在 top-k 里“选一个最像老师的”,但永远不会被明确要求去预测真实 label
CE 则直接对真实 label 施压:

LCE=−ln⁡pstudent(y) L_{\text{CE}}=-\ln p_{\text{student}}(y) LCE=lnpstudent(y)

所以 CE 是避免“top-k 截断带来的监督缺失”的关键补丁。


4) KD 会继承老师的“概率形状”,但 CE 帮学生更好贴合语料分布(降低 PPL)

Pretrain 最常看的指标之一是 perplexity(困惑度),本质上来自 CE。
只做 KD 时,学生优化的是“像老师”,不是“最大化真实语料似然”。就会出现:

  • 学生很像老师(KL 小)
  • 但对真实语料的 PPL 可能不如“KD+CE”的模型低

所以许多团队会把 CE 当作“主任务”,KD 当作“老师的提示/正则化”。


5) 稳定训练:CE 能防止 KD 带来的“奇怪收敛”或退化

只用 KD 时,训练目标完全由老师分布决定。容易遇到:

  • 老师分布过于尖锐/过于平(温度、校准问题)
  • 蒸馏数据噪声(老师推理出错、截断误差、量化误差)
  • 学生容量小,硬模仿导致训练震荡

加入 CE,相当于给优化加了一个更稳定、与数据一致的“参照系”,常见效果是:

  • loss 曲线更平滑
  • 不容易崩
  • 对超参(学习率、温度、k)没那么敏感

6) 纠偏与泛化:CE 能抵消老师在域外/新数据上的偏差

老师通常训练于某个数据配方,但你蒸馏时的语料可能:

  • 更干净或更脏
  • 更偏某个领域
  • 更新了很多新内容

如果只做 KD,学生会被“拉回老师的旧偏好”。
CE 会推动学生适应你当前训练语料的真实统计规律,从而更好泛化到你的目标分布。


那有没有情况可以“不算”真实 label 的 CE?

有,但相对少见,且一般会配合别的补偿。

情况 A:你保存的是“全词表分布”(不是 top-k)且老师极强

如果蒸馏数据保存的是完整 pteacherp_{\text{teacher}}pteacher(vocab 全量),理论上 KD 已经包含了非常丰富的信息。
一些训练会把 α\alphaα 设得很小甚至 0,但这通常要求:

  • 老师非常强且与目标数据分布一致
  • 蒸馏数据质量极高
  • 不做 top-k 截断(或 k 极大)

情况 B:你在做“纯模仿老师”的压缩部署,PPL 不是核心指标

比如你的目标就是“学生尽可能复刻老师输出行为”,而不是追求对真实语料的最优似然,这时可能更偏向 KD-only。
但在 pretrain 场景,大多数团队仍会保留一定比例的 CE。

情况 C:你用合成数据(teacher 生成)且没有可靠“真实标签”

如果数据是老师生成的序列,那“真实 label”其实也来自老师。此时 CE 可能和 KD 强耦合、信息重复。
不过即便如此,很多实现也会用 CE(对生成 token 的 next-token)来稳定训练,同时 KD 提供更软的分布信息。


业界最常见的做法(实用结论)

  • top-k 蒸馏(尤其 k 小,比如 3/32/64):几乎都会 KD + CE
  • 常见设置:α\alphaα 在 0.2~0.8 之间(取决于老师强度、k 大小、数据质量)
  • 有时还会用“训练日程”:前期偏 CE 稳住,后期加大 KD;或相反(看目标)

如果你愿意,我可以基于你这条样本继续把 “同时算 CE 与 KD 的完整数值例子” 也写出来:给定每个位置真实 label 的学生概率(或 logits),把 LCEL_{\text{CE}}LCELKDL_{\text{KD}}LKD、以及混合后的 LLL 全部算到具体数字。

把 “同时算 CE 与 KD 的完整数值例子” 也写出来

下面给出同一条样本在训练时 同时计算 CE(真实 label)与 KD(top-k 版 KL) 的一套完整数值演示,参数与前面一致:

  • vocab_size = 30
  • seq_len = 10,但 labels[9] = -100 忽略,所以有效位置是 t = 0..8 共 9 个
  • k = 3
  • KD 温度取 T=2T=2T=2
  • 对数用自然对数 ln⁡\lnln

1) 要用的公式

1.1 CE(真实 label 的交叉熵)

对每个位置 (t),学生对全词表 30 维 logits 记为 (z_{t,0…29})。

softmax 概率:
pt,i=ezt,i∑j=029ezt,j p_{t,i}=\frac{e^{z_{t,i}}}{\sum_{j=0}^{29} e^{z_{t,j}}} pt,i=j=029ezt,jezt,i

真实 label 为 (y_t),则
CE∗t=−ln⁡(p∗t,yt) \mathrm{CE}*t=-\ln(p*{t,y_t}) CEt=ln(pt,yt)

平均 CE:
LCE=19∑t=08CEt L_{\mathrm{CE}}=\frac{1}{9}\sum_{t=0}^{8}\mathrm{CE}_t LCE=91t=08CEt


1.2 KD(top-k 版 KL,带温度)

老师给每个位置的 top-k:

  • teacher_topk_ids[t] = [a,b,c]
  • teacher_topk_probs[t] = [p_1,p_2,p_3](top-k 内已归一化,三项和为 1)

学生只在这 3 个 token 上取 logits:

  • student_topk_logits[t] = [z_a, z_b, z_c]

温度 softmax(只在 top-k 上):
qj=exp⁡(zj/T)exp⁡(z1/T)+exp⁡(z2/T)+exp⁡(z3/T) q_j=\frac{\exp(z_j/T)}{\exp(z_1/T)+\exp(z_2/T)+\exp(z_3/T)} qj=exp(z1/T)+exp(z2/T)+exp(z3/T)exp(zj/T)

top-k KL:
KL∗t=∑∗j=13pjln⁡(pjqj) \mathrm{KL}*t=\sum*{j=1}^{3} p_j\ln\left(\frac{p_j}{q_j}\right) KLt=j=13pjln(qjpj)

平均 KD loss(常见乘 (T^2)):
LKD=T2⋅19∑t=08KLt L_{\mathrm{KD}}=T^2\cdot \frac{1}{9}\sum_{t=0}^{8}\mathrm{KL}_t LKD=T291t=08KLt


1.3 混合总损失(举例用 (\alpha=0.5))

L=αLCE+(1−α)LKD L=\alpha L_{\mathrm{CE}}+(1-\alpha)L_{\mathrm{KD}} L=αLCE+(1α)LKD


2) 本条样本数据(labels 与 teacher top-k)

labels = [5, 7, 3, 18, 21, 9, 2, 14, 6, -100],有效是 t=0..8

teacher top-k(t=0…8):

  • t=0 ids [5, 8, 1] probs [0.62, 0.23, 0.15]
  • t=1 ids [7, 4, 9] probs [0.55, 0.30, 0.15]
  • t=2 ids [3, 10, 6] probs [0.50, 0.27, 0.23]
  • t=3 ids [18, 17, 11] probs [0.58, 0.22, 0.20]
  • t=4 ids [21, 22, 19] probs [0.47, 0.32, 0.21]
  • t=5 ids [9, 13, 8] probs [0.51, 0.28, 0.21]
  • t=6 ids [2, 15, 1] probs [0.66, 0.19, 0.15]
  • t=7 ids [14, 16, 12] probs [0.44, 0.33, 0.23]
  • t=8 ids [6, 20, 7] probs [0.49, 0.26, 0.25]

3) 学生对全词表的 logits(用于 CE 的完整 30 维,每个 t 一条)

下面每一行是 student_logits_vocab30[t](长度 30,下标 0…29):

  • t=0
    [-1.2, 0.7, -0.5, 0.1, -0.3, 2.1, -0.8, 0.4, 1.2, -0.6, 0.2, -1.0, -0.2, 0.5, -0.7, 0.0, -0.9, 0.3, -0.4, 0.6, -1.1, -0.1, 0.8, -0.2, 0.9, -0.5, 0.1, -0.6, 0.2, -0.3]

  • t=1
    [-1.9, -2.0, -2.0, -2.0, 1.1, -2.0, -2.0, 1.8, -2.0, 0.3, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=2
    [-1.9, -2.0, -2.0, 1.4, -2.0, -2.0, 0.6, -2.0, -2.0, -2.0, 0.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=3
    [-1.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 0.8, -2.0, -2.0, -2.0, -2.0, -2.0, 1.0, 2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=4
    [-1.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 0.9, -2.0, 1.5, 1.3, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=5
    [-1.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 0.4, 1.6, -2.0, -2.0, -2.0, 1.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=6
    [-1.9, 0.5, 2.2, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 0.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=7
    [-1.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 0.6, -2.0, 1.2, -2.0, 1.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]

  • t=8
    [-1.9, -2.0, -2.0, -2.0, -2.0, -2.0, 1.7, 0.9, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 1.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.7]


4) 计算 CE(逐位置给出 sumexp、p_label、CE_t)

每个位置 (t):
pt,yt=ezt,yt∑j=029ezt,j,CE∗t=−ln⁡(p∗t,yt) p_{t,y_t}=\frac{e^{z_{t,y_t}}}{\sum_{j=0}^{29}e^{z_{t,j}}},\quad \mathrm{CE}*t=-\ln(p*{t,y_t}) pt,yt=j=029ezt,jezt,yt,CEt=ln(pt,yt)

逐位置结果(t=0…8):

t label_id label_logit exp(label_logit) sumexp (Σ exp(logit)) p_label CE_t = -ln(p_label)
0 5 2.100000 8.166170 39.510223 0.206685 1.576559
1 7 1.800000 6.049647 14.119307 0.428466 0.847543
2 3 1.400000 4.055200 12.052556 0.336460 1.089277
3 18 2.000000 7.389056 16.048513 0.460420 0.775616
4 21 1.500000 4.481689 14.326223 0.312831 1.162092
5 9 1.600000 4.953032 12.878773 0.384589 0.955580
6 2 2.200000 9.025013 16.848972 0.535642 0.624290
7 14 1.200000 3.320117 11.576152 0.286807 1.248947
8 6 1.700000 5.473947 14.367467 0.380996 0.964966

求和与平均:

  • ∑t=08CEt=9.244871\sum_{t=0}^{8}\mathrm{CE}_t = 9.244871t=08CEt=9.244871
  • LCE=9.2448719=1.027208L_{\mathrm{CE}}=\frac{9.244871}{9}=1.027208LCE=99.244871=1.027208

5) 计算 KD(top-k 版 KL,逐位置给出 q 与 KL_t)

每个位置 (t):

  • ids = [a,b,c]
  • 学生取 logits:([z_a,z_b,z_c])
  • 温度缩放:([z_a/T,z_b/T,z_c/T])
  • 得到 (q),再算
    KL∗t=∑∗j=13pjln⁡(pjqj) \mathrm{KL}*t=\sum*{j=1}^{3} p_j\ln\left(\frac{p_j}{q_j}\right) KLt=j=13pjln(qjpj)

逐位置结果(t=0…8,全部给出):

t teacher_topk_ids p_teacher student_topk_logits scaled_logits (÷T) exp(scaled) sumexp q_student KL_t
0 [5, 8, 1] [0.620000, 0.230000, 0.150000] [2.100000, 1.200000, 0.700000] [1.050000, 0.600000, 0.350000] [2.857651, 1.822119, 1.419068] 6.098837 [0.468557, 0.298765, 0.232678] 0.047623
1 [7, 4, 9] [0.550000, 0.300000, 0.150000] [1.800000, 1.100000, 0.300000] [0.900000, 0.550000, 0.150000] [2.459603, 1.733253, 1.161834] 5.354690 [0.459336, 0.323689, 0.216975] 0.020903
2 [3, 10, 6] [0.500000, 0.270000, 0.230000] [1.400000, 0.900000, 0.600000] [0.700000, 0.450000, 0.300000] [2.013753, 1.568312, 1.349859] 4.931924 [0.408336, 0.317965, 0.273699] 0.017110
3 [18, 17, 11] [0.580000, 0.220000, 0.200000] [2.000000, 1.000000, 0.800000] [1.000000, 0.500000, 0.400000] [2.718282, 1.648721, 1.491825] 5.858828 [0.463974, 0.281408, 0.254618] 0.027012
4 [21, 22, 19] [0.470000, 0.320000, 0.210000] [1.500000, 1.300000, 0.900000] [0.750000, 0.650000, 0.450000] [2.117000, 1.915541, 1.568312] 5.600853 [0.377793, 0.342016, 0.280191] 0.020703
5 [9, 13, 8] [0.510000, 0.280000, 0.210000] [1.600000, 1.000000, 0.400000] [0.800000, 0.500000, 0.200000] [2.225541, 1.648721, 1.221403] 5.095665 [0.436751, 0.323544, 0.239705] 0.010818
6 [2, 15, 1] [0.660000, 0.190000, 0.150000] [2.200000, 0.900000, 0.500000] [1.100000, 0.450000, 0.250000] [3.004166, 1.568312, 1.284025] 5.856503 [0.512979, 0.267796, 0.219225] 0.044206
7 [14, 16, 12] [0.440000, 0.330000, 0.230000] [1.200000, 1.000000, 0.600000] [0.600000, 0.500000, 0.300000] [1.822119, 1.648721, 1.349859] 4.820699 [0.377963, 0.342000, 0.280037] 0.009803
8 [6, 20, 7] [0.490000, 0.260000, 0.250000] [1.700000, 1.000000, 0.900000] [0.850000, 0.500000, 0.450000] [2.339647, 1.648721, 1.568312] 5.556680 [0.421114, 0.296696, 0.282190] 0.009647

求和与平均:

  • ∑t=08KLt=0.207824\sum_{t=0}^{8}\mathrm{KL}_t = 0.207824t=08KLt=0.207824
  • 19∑t=08KLt=0.023092\frac{1}{9}\sum_{t=0}^{8}\mathrm{KL}_t = 0.02309291t=08KLt=0.023092

带上 (T^2=4):
LKD=4×0.023092=0.092366 L_{\mathrm{KD}}=4\times 0.023092=0.092366 LKD=4×0.023092=0.092366


6) 合并 CE 与 KD(给出最终数值)

本例:

  • LCE=1.027208L_{\mathrm{CE}}=1.027208LCE=1.027208
  • LKD=0.092366L_{\mathrm{KD}}=0.092366LKD=0.092366

若取 (\alpha=0.5):
L=0.5×1.027208+0.5×0.092366=0.559787 L=0.5\times 1.027208 + 0.5\times 0.092366 = 0.559787 L=0.5×1.027208+0.5×0.092366=0.559787

(可选感受一下 (\alpha) 的影响)

  • (\alpha=0.2):(L=0.279335)
  • (\alpha=0.8):(L=0.840240)

如果你想把这一步落到实现上,我也可以把“同一批数据里同时算 CE 与 top-k KD”的张量形状与伪代码写出来,并明确哪里需要 mask-100

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐