大模型知识蒸馏技术
进一步思考,“知识蒸馏” 其实是我们最拿手的传统教育模式,在古代,背会唐诗三百首,不会写也会偷,在当代,虽然我们的思维方式不是那么异想天开,但我们的考试成绩秒杀全宇宙,这背后都是对范式的记忆和应用,即,将知识的范式当作知识本身来学习,因此才得以弯道超车,省却了大量时间,这岂不就是大模型蒸馏。这个过程中,我是大模型,女儿是小模型,我由于 30 多年的积累训练,早就有了各种解题范式,而她不需要重复这
最近辅导女儿数学,同时也在看大模型,发现我的辅导方法和大模型蒸馏技术异曲同工,即使教不会心法,也能教个妙招,就着这个话题写篇作文。先介绍蒸馏技术,再谈谈中学数学辅导和拔高。
蒸馏旨在将大型模型(老师模型)的知识迁移到小型模型(学生模型),过程中让学生模型学习老师模型 “软化” 后的输出,而不仅仅获得正确答案。
知识蒸馏的动机在于,传统的训练使用 one-hot 向量硬标签会丢失很多关联信息,例如,一张皮鞋的图片被错误分类为经理,与错误分类为工人,其错误的严重程度是不同的,但硬标签无法体现这一点,老师模型输出的概率分布则包含这种隐形的暗知识。
设老师模型的 logits 输出为 z t \text{z}^t zt,学生模型的 logits 输出为 z s \text{z}^s zs,其中 logits 指 softmax 层之前的原始得分。为了得到平滑的概率分布,引入温度参数 T(通常 T > 1)。带温度 T 的 softmax 函数定义如下:
q i = exp ( z i T ) ∑ j = 1 K exp ( z j T ) q_i=\dfrac{\exp(\dfrac{z_i}{T})}{\sum_{j=1}^{K}\exp(\dfrac{z_j}{T})} qi=∑j=1Kexp(Tzj)exp(Tzi)
其中 K 为类别总数。T 的取值很讲究:
- 当 T < 1,z 中最大值被严重放大,对其他 z i < z m z_i<z_m zi<zm, exp ( z i T ) = exp ( z i − z m T ) ⋅ exp ( z m T ) \text{exp}(\dfrac{z_i}{T})=\text{exp}(\dfrac{z_i-z_m}{T})\cdot \text{exp}(\dfrac{z_m}{T}) exp(Tzi)=exp(Tzi−zm)⋅exp(Tzm),其中 z i − z m < 0 z_i−z_m<0 zi−zm<0,当 T → 0 T\to 0 T→0 时, exp ( z i − z m T ) → 0 \text{exp}(\dfrac{z_i−z_m}{T})\to 0 exp(Tzi−zm)→0,代入公式,整体趋向 one-hot;
- 当 T 非常大,泰勒展开, exp ( z i T ) ≈ 1 + z i T + 小量 \text{exp}(\dfrac{z_i}{T})\approx 1+\dfrac{z_i}{T}+小量 exp(Tzi)≈1+Tzi+小量,代入公式,分子约为 1,分母约为 K,是为均匀分布;
- 当 T > 1 但又不太大,如 T = 3, z i T \dfrac{z_i}{T} Tzi 变小, exp ( z i T ) \text{exp}(\dfrac{z_i}{T}) exp(Tzi) 差异被压缩,更加平滑;
因此,当 T=1 时,即为标准 softmax,只有 T>1 时,概率分布更加平滑,错误类别之间的相对概率大小得以保留,这是知识传递的基础。
在训练过程中,我们使用相同的温度 T 分别处理老师和学生的 logits,得到 softened 分布:
p t = softmax ( z t T ) , p s = softmax ( z s T ) \text{p}^t=\text{softmax}(\dfrac{\text{z}^t}{T}), \quad\text{p}^s=\text{softmax}(\dfrac{\text{z}^s}{T}) pt=softmax(Tzt),ps=softmax(Tzs)
p t p^t pt 相当于老师先做了一遍题目,蒸馏的目标是让学生模型的 softened 分布 p s \text{p}^s ps 逼近老师模型的 softened 分布 p t \text{p}^t pt。
衡量两个概率分布差异的常用指标是 Kullback-Leibler 散度,KL 散度定义为:
D KL ( p t ∥ p s ) = ∑ i = 1 C p i t log p i t p i s D_{\text{KL}}(\text{p}^t\parallel\text{p}^s)=\sum_{i=1}^{C}p^t_i\log\dfrac{p^t_i}{p^s_i} DKL(pt∥ps)=∑i=1Cpitlogpispit
KL 散度衡量当用其中一个分布 Q 近似真实分布 P 时,所损失的信息量。如果两个分布完全一样,KL 散度为 0,差异越大,KL 散度越大,用 KL 散度可等价交叉熵:
- 交叉熵: H ( P , Q ) = − ∑ i P ( i ) log Q ( i ) H(P,Q)=-\sum_i P(i)\log Q(i) H(P,Q)=−∑iP(i)logQ(i);
- KL 散度与交叉熵的关系: D K L ( P ∥ Q ) = H ( P , Q ) − H ( P ) D_{KL}(P\parallel Q)=H(P,Q)-H(P) DKL(P∥Q)=H(P,Q)−H(P);
- 其中 H§ 是分布 P 的熵。在蒸馏中,老是分布 P 固定,因此它的熵 H§ 是常数;
- So,最小化 KL散度 D K L ( P ∥ Q ) D_{KL}(P\parallel Q) DKL(P∥Q) 等价于最小化交叉熵 H(P, Q)$,两者在优化目标上等价;
将 softened 分布代入,并考虑到训练效率,在实现时通常对学生 logits 使用 log_softmax。因此,蒸馏损失 L distill \mathcal{L}_{\text{distill}} Ldistill 定义为:
L distill = D KL ( p t ∥ p s ) ⋅ T 2 \mathcal{L}_{\text{distill}}=D_{\text{KL}}(\text{p}^t\parallel\text{p}^s)\cdot T^2 Ldistill=DKL(pt∥ps)⋅T2
乘以 T 2 T^2 T2 是出于梯度缩放考虑,对损失函数关于学生 logits z i s z^s_i zis 求导,可得到梯度表达式:
∂ L distill ∂ z i s = 1 T ( p i s − p i t ) \dfrac{\partial\mathcal{L}_{\text{distill}}}{\partial z^s_i}=\dfrac{1}{T}(p^s_i-p^t_i) ∂zis∂Ldistill=T1(pis−pit)
可看出,梯度中包含因子 1 T \dfrac{1}{T} T1。乘以 T 2 T^2 T2 后,有效梯度变为 T ( p i s − p i t ) T(p^s_i-p^t_i) T(pis−pit),这使得在温度 T 变化时,梯度的尺度相对稳定,有利于优化。
然而仅使用蒸馏损失可能导致学生模型过度依赖老师而忽略真实答案。因此,总损失由蒸馏损失和标准的交叉熵损失加权。设真实标签的 one-hot 向量为 y,标准交叉熵损失为:
L CE = − ∑ i = 1 C y i log ( softmax ( z i s ) T = 1 ) \mathcal{L}_{\text{CE}} = -\sum_{i=1}^{C} y_i \log(\text{softmax}(z^s_i)_{T=1}) LCE=−∑i=1Cyilog(softmax(zis)T=1)
最终的总损失函数为:
L total = α ⋅ L distill + ( 1 − α ) ⋅ L CE \mathcal{L}_{\text{total}}=\alpha\cdot\mathcal{L}_{\text{distill}}+(1-\alpha)\cdot \mathcal{L}_{\text{CE}} Ltotal=α⋅Ldistill+(1−α)⋅LCE
其中 α ∈ [ 0 , 1 ] ) \alpha\in[0,1]) α∈[0,1]) 是个超参数,用于平衡两种损失的贡献。
训练流程如下:
- 使用训练数据对冻结参数的老师模型进行前向传播,计算 softened 概率分布 p t \text{p}^t pt 并缓存;
- 在训练学生模型的每个批次中,同时计算学生的 softened 分布 p s \text{p}^s ps 和标准 logits;
- 根据上述公式计算总损失,并仅对学生模型的参数进行反向传播和优化。
推理阶段,学生模型使用标准 softmax(T=1)进行预测。
蒸馏的威力在于 softened 分布蕴含相似性信息,而标准答案的硬标签并不具备这些信息。通过最小化与老师分布的 KL 散度,学生模型不仅学习如何分类,还学习了老师模型更丰富的柔性,模型虽小但看起来显得很大。
以上描述的知识蒸馏与我辅导女儿数学特别是几何时的效果一样,都非常好,我一并分享一下。
我采用几个步骤:
- 自己先做一遍题目,因为我一直都很喜欢解题,这对我而言没什么难度,反而难题是一种享受;
- 告诉女儿我的思路,为保证她确实理解,我会让她写篇 600 字作文,就像我中学时记录的一样;
- 让她对照老师讲的标准答案,并比较我的答案和标准答案的区别;
- 让她回答当时为什么既没有想到标准答案,也没有想到我的答案;
- 不断重复上述;
这个过程中,我是大模型,女儿是小模型,我由于 30 多年的积累训练,早就有了各种解题范式,而她不需要重复这 30 的过程,只需要记住我的范式即可。
进一步思考,“知识蒸馏” 其实是我们最拿手的传统教育模式,在古代,背会唐诗三百首,不会写也会偷,在当代,虽然我们的思维方式不是那么异想天开,但我们的考试成绩秒杀全宇宙,这背后都是对范式的记忆和应用,即,将知识的范式当作知识本身来学习,因此才得以弯道超车,省却了大量时间,这岂不就是大模型蒸馏。
老师花了大量资源(时间,金钱,人力等)学习,学生只学习老师的结论,以及他的方法论。
填鸭式知识传承,知识蒸馏的目标本质上就是要提高填鸭的效率和保真度,但这也存在一些关键瓶颈:
- 对老师(教材)的绝对依赖与知识边界的固化;
- 对标准答案的执念以及泛化能力的瓶颈;
- 批判,解构,联想,第一性探索能力瓶颈;
这些同时也是 AI 大模型的瓶颈。我就不信经理们能搞定。
浙江温州皮鞋湿,下雨进水不会胖。
更多推荐
所有评论(0)