RMSNorm(Root Mean Square Layer Normalization)

RMSNorm 是一种类似于 Layer Normalization 的归一化技术,但它使用的是均方根(Root Mean Square, RMS)而非标准差。RMSNorm 在实际应用中表现出更好的稳定性和计算效率。

  • 均值:描述数据的中心位置。
    μ = 1 n ∑ i = 1 n x i \mu = \frac{1}{n} \sum_{i=1}^{n} x_i μ=n1i=1nxi
  • 方差:描述数据的离散程度,单位是原数据单位的平方。
    σ 2 = 1 n ∑ i = 1 n ( x i − μ ) 2 \sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 σ2=n1i=1n(xiμ)2
  • 标准差:描述数据的离散程度,单位与原数据相同。
    σ = 1 n ∑ i = 1 n ( x i − μ ) 2 \sigma = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2} σ=n1i=1n(xiμ)2
  • 均方根:描述数据的总体水平,常用于信号处理等领域。
    RMS = 1 n ∑ i = 1 n x i 2 \text{RMS} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2} RMS=n1i=1nxi2

数学公式

假设输入张量 x \mathbf{x} x 的形状为 ( N , D ) (N, D) (N,D),其中 N N N 是批量大小, D D D是特征维度。RMSNorm 的计算步骤如下:
1. 求均方根 : RMS ( x ) = 1 D ∑ i = 1 D x i 2 2. 归一化输入 : y = x RMS ( x ) + ϵ 3. 缩放 : z = γ ⋅ y ,   γ 是一个可学习的参数 . \begin{aligned} &1.求均方根: \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{D} \sum_{i=1}^{D} x_i^2} \\ &2. 归一化输入: \mathbf{y} = \frac{\mathbf{x}}{\text{RMS}(\mathbf{x}) + \epsilon} \\ &3. 缩放: \mathbf{z} = \gamma \cdot \mathbf{y},\ \gamma 是一个可学习的参数. \end{aligned} 1.求均方根:RMS(x)=D1i=1Dxi2 2.归一化输入:y=RMS(x)+ϵx3.缩放:z=γy, γ是一个可学习的参数.

代码实现

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super(RMSNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = torch.var(x, dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x

总结

  1. RMSNorm 的数学公式
    • 计算均方根(RMS)。
    • 归一化输入。
    • 缩放归一化后的输出。

RMSNorm 相对于传统的 Layer Normalization 更加稳定,尤其是在长序列处理中表现更优。

LayerNorm和BatchNorm 参考

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐