上篇文章说明了归一化的作用,我们简要回顾一下:

归一化公式会强行把每一层的数据都变成均值为0、方差为1的分布。然而,对于某些层来说,数据原本的分布(比如均值是5,方差是10)才是最有用的、最能帮助模型做出正确预测的分布。强行把它"拉平",可能会削弱模型的表达能力,甚至让某些激活函数(如Sigmoid)落入饱和区。

  1. 方差大 → 权重值大
    当你使用较大的方差初始化权重时(比如标准差为1或更大),初始权重值会大概率落在远离0的区域(比如 -2, 3, -4 等)。

  2. 权重值大 → 线性组合输出大
    神经元的输出是:
    z=∑wixi+b z = \sum w_i x_i + b z=wixi+b
    如果输入$ x_i $是标准正态分布(均值为0,方差为1),而权重 wiw_iwi 的方差也很大,那么$ z $ 的方差会更大(因为方差是累加的)。

  3. 输出大 → 激活函数进入饱和区
    sigmoid 为例:
    σ(z)=11+e−z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1

  • 当$ z > 4 ,,\sigma(z) \approx 0.98$
  • 当 $ z < -4 ,,\sigma(z) \approx 0.02$
    这些区域称为饱和区,因为:
  • 输出几乎不再变化(梯度接近0)
  • 反向传播时,梯度几乎为0,导致梯度消失
    在这里插入图片描述

常用的归一化手段

BatchNorm(批归一化)以及 LayerNorm(层归一化) 还有大模型用得比较多的

  • BatchNorm(批归一化)横向对比。把一个批次里所有样本的同一个特征放在一起,求平均和标准差,然后标准化。

    • 比如:收集全班50个同学的语文成绩,一起标准化。
  • LayerNorm(层归一化)纵向自省。把一个样本的所有特征放在一起,求平均和标准差,然后标准化。

    • 比如:只看小明一个人的语数外理化生所有成绩,一起标准化。
归一化的区别

"平均值"和"标准差"是怎么算出来的?
假设你的数据是一个表格,行是样本(一个学生)列是特征(科目成绩)

语文 数学 英语
学生A(样本1) 90 80 70
学生B(样本2) 60 50 40
  • BatchNorm (BN) 的计算方式

    • 竖着看(按列)
    • 计算"语文"这列的平均值:(90 + 60) / 2 = 75
    • 然后用这个平均值和标准差去标准化"语文"这列的所有成绩(学生A和B的语文成绩)。
    • 对"数学"、"英语"列重复同样操作。
    • 总结:按特征(科目)维度处理,依赖同批次的其他样本。
  • LayerNorm (LN) 的计算方式

    • 横着看(按行)
    • 计算"学生A"这行的平均值:(90 + 80 + 70) / 3 = 80
    • 然后用这个平均值和标准差去标准化"学生A"自己的所有成绩(他的语数外成绩)。
    • 对"学生B"行重复同样操作。
    • 总结:按样本处理,不依赖其他样本,自己管自己。

简单记住
BN - 横向比(和别人比)
LN - 纵向比(和自己比)

BatchNorm 方法

假设:

  • 批次数据:X=[x1,x2,…,xm]X = [x_1, x_2, \ldots, x_m]X=[x1,x2,,xm],其中每个样本 xix_ixi 特征维度为 kkk 的向量,可以表示为xikx_{ik}xik
  • 批次大小:mmm
  • 小常数:ϵ\epsilonϵ(通常为 10−510^{-5}105,防止除零)

计算步骤

  1. 在某个特征维度上,比如k1的特征维度, $\mu_{k1} $ 计算均值向量
    μk1=1m∑i=1mxik1(1) \mu_{k1} = \frac{1}{m} \sum_{i=1}^{m} x_{ik1} \tag 1 μk1=m1i=1mxik1(1)
    其计算是每个批次的数据,把这些数据对应的k1特征维度求平均。

  2. 同样计算方差向量,也是计算是每个批次的数据,把这些数据对应的k1特征维度求方差
    σk12=1m∑i=1m(xi−μk1)2(2) \sigma^2_{k1} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{k1})^2 \tag 2 σk12=m1i=1m(xiμk1)2(2)

    • σk12\sigma^2_{k1}σk12 是每个特征维度上的方差向量
    • 平方运算为逐元素操作
  3. 输入xxxk1k1k1维度的标准化(归一化)
    x^ik1=xik1−μk1σk12+ϵ(3) \hat{x}_{ik1} = \frac{x_{ik1} - \mu_{k1}}{\sqrt{\sigma^2_{k1} + \epsilon}} \tag 3 x^ik1=σk12+ϵ xik1μk1(3)

    • 对每个样本进行标准化处理
  4. 同样如果有k个维度特征就,每个维度都计算他们对应的均值与方差。
    μ=[μk1,μk2,…,μkm] \mu = [\mu_{k1} , \mu_{k2} , \ldots, \mu_{km} ] μ=[μk1,μk2,,μkm]
    σ2=[σk12,σk22,…,σkm2] \sigma^2 = [\sigma^2_{k1} , \sigma^2_{k2} , \ldots, \sigma^2_{km} ] σ2=[σk12,σk22,,σkm2]

由于CV的输入shape规范一致,batchNorm进行归一化非常有效,能加速训练并提供一定的正则化效果。而不适用于序列模型:在 NLP 任务中,序列长度通常是变化的,BatchNorm 在处理变长序列时非常棘手。

LayerNorm 方法

与BatchNorm方法使用的变量保持一致:

  • 批次数据:X=[x1,x2,…,xm]X = [x_1, x_2, \ldots, x_m]X=[x1,x2,,xm],其中每个样本 xix_ixi 特征维度为 kkk 的向量,可以表示为xikx_{ik}xik
  • 批次大小:mmm
  • 小常数:ϵ\epsilonϵ(通常为 10−510^{-5}105,防止除零)

i表示样本维度,k表示一个样本中的特征维度

1. 对每个样本计算均值
对于每个样本 $ x_i $,计算其所有特征维度的均值:
μi=1k∑j=1kxij \mu_i = \frac{1}{k} \sum_{j=1}^{k} x_{ij} μi=k1j=1kxij

⚠️ 注意:这里是每个样本内部求均值,而不是像 BatchNorm 那样跨样本。

2. 对每个样本计算方差
同样,在每个样本内部计算方差:
σi2=1k∑j=1k(xij−μi)2 \sigma^2_i = \frac{1}{k} \sum_{j=1}^{k} (x_{ij} - \mu_i)^2 σi2=k1j=1k(xijμi)2

3. 对每个样本进行标准化
对每个样本的每个特征维度进行归一化:
x^ij=xij−μiσi2+ϵ \hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma^2_i + \epsilon}} x^ij=σi2+ϵ xijμi

对 batch size 不敏感,尤其适合处理变长序列和不同大小的 batch,广泛应用于 RNN 和 Transformer 中。

RMSNorm 方法

说到LayerNorm,这里不得不提RMSNorm,源自于论文"Root Mean Square Layer Normalization",被很多主流大模型所采用。

RMSNorm 不计算传统意义的均值 μ 和方差 σ²,而是只计算一个均方根(Root Mean Square, RMS) 值。具体步骤如下:

同样假设:

  • 批次数据:X=[x1,x2,…,xm]X = [x_1, x_2, \ldots, x_m]X=[x1,x2,,xm],其中每个样本 xix_ixi 特征维度为 kkk 的向量,可以表示为xikx_{ik}xik
  • 批次大小:mmm
  • 小常数:ϵ\epsilonϵ(通常为 10−510^{-5}105,防止除零)
  • 单个样本向量:
    xi=[xk1,xk2,…] x_i = [x_{k1}, x_{k2}, \ldots] xi=[xk1,xk2,]
  1. 计算 RMS(等价于"标准差"但无去均值)
    RMS(xi)=1k∑j=1kxij2+ϵ \mathrm{RMS}(x_{i}) = \sqrt{\frac{1}{k}\sum_{j=1}^{k} x_{ij}^2 + \epsilon} RMS(xi)=k1j=1kxij2+ϵ

注意:这里没有减去均值,直接用原始值平方。

  1. 完整批次数据的 RMS(整个矩阵的均方根)
    如果需要计算整个批次数据 XXXm×km \times km×k 矩阵)的 RMS(即所有元素的均方根):
    RMS(x)=1m⋅k∑i=1m∑j=1kxij2+ϵ \mathrm{RMS}(x) = \sqrt{ \frac{1}{m \cdot k} \sum_{i=1}^{m} \sum_{j=1}^{k} x_{ij}^2 + \epsilon } RMS(x)=mk1i=1mj=1kxij2+ϵ

  2. 单个样本的归一化
    x^i=xiRMS(x) \hat{x}_i = \frac{x_i}{\mathrm{RMS}(x)} x^i=RMS(x)xi

RMSNorm 仅保留均方根,参数量减半,更适合大模型的应用。并且其去掉均值后梯度噪声减小,深层网络梯度消失现象减轻,训练的稳定性更好。

归一化的可学习参数

归一化公式会强行把每一层的数据都变成均值为0、方差为1的分布。然而,对于某些层来说,数据原本的分布(比如均值是5,方差是10)才是最有用的、最能帮助模型做出正确预测的分布。强行把它"拉平",可能会削弱模型的表达能力,甚至让某些激活函数(如Sigmoid)落入饱和区。

我们引入两个可学习的参数 γβ,让网络自己来决定:
BatchNorm 和 LayerNorm 都用同一个公式,但计算范围不同:

输出 = γ * (输入 - 平均值) / 标准差 + β

  • γβ 是可学习的参数,让网络自己决定是否需要恢复一些原始分布。

具体学习如下:

  1. 初始化参数:首先,γβ 被定义为模型的可训练参数。在模型初始化时,γ 通常初始化为全1(这样开始时不做缩放),β 通常初始化为全0(这样开始时不做平移)。
  2. 前向传播:在前向传播过程中,它们会参与计算,影响最终的输出和模型的预测结果。
  3. 反向传播:在反向传播过程中,损失函数的梯度会通过链式法则一直回溯,同样也会计算到损失对 γβ 的梯度
    • d(Loss)/dγ
    • d(Loss)/dβ
  4. 优化器更新:在优化器更新模型所有权重参数的那一步,γβ 也会根据计算出的梯度进行更新
    • γ = γ - learning_rate * d(Loss)/dγ
    • β = β - learning_rate * d(Loss)/dβ

全局移动平均和移动方差

  • 训练阶段:使用当前批次的统计量(μ\muμσ2\sigma^2σ2
  • 测试阶段:通常使用训练过程中计算的移动平均(running mean)和移动方差(running variance)

由于在训练阶段,我们使用当前mini-batch的数据来计算均值(μbatchμ_{batch}μbatch)和方差(σbatch2σ_{batch}^2σbatch2)。这为模型引入了噪声,因为每个批次的统计量都略有不同。这种噪声实际上有助于模型的正则化,提升泛化能力。

然而,在推理(测试)阶段,我们通常一次只处理一个样本(或少量样本),或者需要得到确定性的、一致的结果。如果我们仍然使用单个样本的统计量,会出现两个问题:

  • 方差会非常小,甚至为0,导致分母爆炸。
  • 输出结果会高度依赖于最后一个样本,非常不稳定。

因此,我们需要一种能代表整个训练数据集的、稳定的统计量。这就是移动平均和移动方差的用武之地。

移动平均和移动方差,是在训练过程中逐步计算和更新的,但它们不参与反向传播

更新过程通常采用指数移动平均(Exponential Moving Average, EMA) 算法,公式如下:
在每一步训练(处理一个batch)后,执行以下更新:
runningmean=momentum∗runningmean+(1−momentum)∗batchmean running_{mean} = momentum * running_{mean} + (1 - momentum) * batch_{mean} runningmean=momentumrunningmean+(1momentum)batchmean
runningvar=momentum∗runningvar+(1−momentum)∗batchvar running_{var} = momentum * running_{var} + (1 - momentum) * batch_{var} runningvar=momentumrunningvar+(1momentum)batchvar

这里需要区分全局和当前批次移动:

  • running_mean, running_var: 我们要追踪的全局移动平均和移动方差。初始值通常为 0 和 1。
  • batch_mean, batch_var: 当前训练批次计算出的均值和方差。
  • momentum: 一个超参数(通常设为 0.9、0.99 或 0.999),用于控制历史信息与当前批次信息的权重。
    • momentum 越接近 1,running_meanrunning_var 的变化越缓慢,对最新一个批次的"反应"越不敏感,更依赖于历史值,结果更平滑。
    • momentum 越接近 0,则更依赖于当前批次的值。

根据公式3可得,推理阶段的公式为
x^i=xi−runningmeanrunningvar2+ϵ(4) \hat{x}_i = \frac{x_i - running_{mean} }{\sqrt{running_{var}^2 + \epsilon}} \tag 4 x^i=runningvar2+ϵ xirunningmean(4)

我们可以想象一个温度计在记录全年的每日平均温度:

  • 训练阶段(batch_mean:就像只看今天一天的温度,然后说"今天是夏天/冬天"。这个结论波动很大。
  • 移动平均(running_mean:就像计算过去30天的平均温度。它能更平滑、更稳定地反映季节变化的趋势,而不会被某一天的极端天气所过度影响。
  • 推理阶段:当有人问你"现在大概是什么温度?"时,你会根据那个平滑的"30天平均趋势"来回答,而不是仅根据今天这一天的温度。这个答案更可靠、更具代表性。

前归一化(PreNorm)与后归一化(PostNorm)

如下图所示:
在这里插入图片描述

PreNorm 将输入 xnx_nxn先进行归一化操作,然后再将经过多头注意力层或前馈神经网络层运算后的输出进行残差连接,这一过程可以通过以下公式表示:
xn+1=xn+F(Norm(xn)) x_{n+1} = x_n + F(\text{Norm}(x_n)) xn+1=xn+F(Norm(xn))

PostNorm 将输入 xnx_nxn 经过多头注意力层或前馈神经网络层运算后的输出进行残差连接,然后再进行归一化操作,这一过程可以通过以下公式表示:
xn+1=Norm(xn+F(xn)) x_{n+1} = \text{Norm}(x_n + F(x_n)) xn+1=Norm(xn+F(xn))

这里的FFF 为Muti-head Attention 或者 FFN 网络结构。

残差连接的理解

残差公式如下:H(x) = F(x) + x,+ x 就是跳跃连接,它将输入 x 直接"跳跃"到输出端,与残差映射 F(x) 进行逐元素相加。允许梯度直接反向传播(避免梯度消失),但也可能导致每层的输出幅度逐渐增大(尤其是深层网络),使得训练不稳定。

  • 残差连接:原始残差块的计算为 $x_{n+1} = x_n + f(x_n) $,其中 $ x_n 是输入,是输入,是输入,f(x_n)$ 是某一层(如Transformer中的Attention或FFN)的输出。恒等分支(即 $ x_n $)允许梯度直接反向传播,缓解梯度消失。

  • PostNorm 在残差之后进行归一化,将每层的输出重新缩放为中心化、标准化的分布,从而抑制了幅度累积增长,使训练过程更稳定。

原本残差连接中的恒等分支是xn+1=xnx_{n+1}=x_{n}xn+1=xn,即第n层的输出与输入相同,但是引入PostNorm后,恒等分支的权重被削弱了。同样,对于第n-1层、第n-2层,每靠近模型输入端一层,恒等分支的权重就会被削弱一次。具体推导公式如下:

  • PostNorm:在残差连接之后进行层归一化(LayerNorm),即 $x_{n+1} = \text{LayerNorm}(x_n + f(x_n)) $。
方差假设与归一化的作用
  • 假设 $ x_n $ 的方差为$ \sigma_1^2 = 1,, f(x_n) $ 的方差为$ \sigma_2^2 = 1 $(即输入和输出的方差都被初始化为1)。
  • 那么残差和 $ x_n + f(x_n)$ 的方差为 $ \text{Var}(x_n + f(x_n)) = \text{Var}(x_n) + \text{Var}(f(x_n)) = 1 + 1 = 2 $(假设 $x_n $ 和 f(xn)f(x_n)f(xn) 独立)。
  • PostNorm(LayerNorm)会强制将激活值的方差重新缩放为1。因此,它需要将 $x_n + f(x_n) $ 除以 $ \sqrt{2} $ 来将方差从2降为1(因为方差具有平方尺度:$ \text{Var}(aX) = a^2 \text{Var}(X) $,所以令 $ a = 1/\sqrt{2} $ 时,方差变为 (1/2)×2=1(1/2) \times 2 = 1(1/2)×2=1)。

因此,PostNorm操作后输出为:
xn+1=xn+f(xn)2 x_{n+1} = \frac{x_n + f(x_n)}{\sqrt{2}} xn+1=2 xn+f(xn)

恒等分支被削弱

在原始残差连接中,恒等分支(即 $x_n $)的系数是1。但这里由于除以了 $\sqrt{2} $,恒等分支的权重从1变成了 1/2≈0.7071/\sqrt{2} \approx 0.7071/2 0.707
xn+1=12⋅xn+12⋅f(xn) x_{n+1} = \frac{1}{\sqrt{2}} \cdot x_n + \frac{1}{\sqrt{2}} \cdot f(x_n) xn+1=2 1xn+2 1f(xn)

这意味着:

  • 输入 $ x_n $ 的贡献被削弱了(乘以因子 $1/\sqrt{2} < 1 $)。
  • 同样,f(xn)f(x_n)f(xn)的贡献也被削弱了。

越靠近输入的层,削弱越严重:因为每一层都会进行这样的归一化,所以第 $ n $ 层的输入 $ x_n $ 实际上已经经历了多次除以 $ \sqrt{2} $ 的操作(假设每层方差类似)。例如,第1层的输入在传到第2层时被除以$\sqrt{2} ,再传到第3层时又被除以,再传到第3层时又被除以,再传到第3层时又被除以 \sqrt{2} $(即累计除以 (2)2=2(\sqrt{2})^2 = 2(2 )2=2),以此类推。因此,较早层的输入信号在向前传播时会被持续缩小,导致梯度不稳定(因为梯度反向传播时也会经过同样的缩放)。

PreNorm 推导
  • PreNorm(先归一化):将LayerNorm放在残差连接之前,即$x_{n+1} = x_n + f(\text{LayerNorm}(x_n)) $。
  • PreNorm不会改变恒等分支的权重(始终为1),因此梯度更稳定,但训练可能更慢(因为梯度主要走恒等分支,优化器需要更努力地更新参数)。
  • PostNorm在训练初期更不稳定,但某些研究表明其最终性能可能更好(需要精细调参)。

公式:
xn+1=xn+f(xn)2 x_{n+1} = \frac{x_n + f(x_n)}{\sqrt{2}} xn+1=2 xn+f(xn)
是PostNorm为了控制方差而做的重新缩放操作。它削弱了残差连接中恒等分支(输入)的贡献(乘以 $1/\sqrt{2} $),并且这种削弱会逐层累积,导致模型底层梯度不稳定。这是PostNorm相比PreNorm的主要缺点之一

总结

归一化技术是现代深度学习中的核心组件,通过稳定训练过程、缓解梯度消失/爆炸问题,显著提升了深层网络的训练效果。BatchNorm、LayerNorm和RMSNorm分别适用于不同的场景:BatchNorm在CV领域的卷积网络中表现出色但对batch size敏感;LayerNorm对batch size不敏感,广泛应用于NLP任务;RMSNorm则通过简化计算和减少参数量,在大模型中显示出优势。此外,PreNorm和PostNorm的选择会影响训练稳定性和性能,需要根据具体任务进行调整。理解这些归一化技术的原理、计算方式和适用场景,对于设计和优化深度学习模型至关重要。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐