RMSNorm规范化
RMSNorm 的数学公式计算均方根(RMS)。归一化输入。缩放归一化后的输出。RMSNorm 相对于传统的 Layer Normalization 更加稳定,尤其是在长序列处理中表现更优。
RMSNorm(Root Mean Square Layer Normalization)
RMSNorm 是一种类似于 Layer Normalization 的归一化技术,但它使用的是均方根(Root Mean Square, RMS)而非标准差。RMSNorm 在实际应用中表现出更好的稳定性和计算效率。
- 均值:描述数据的中心位置。
μ = 1 n ∑ i = 1 n x i \mu = \frac{1}{n} \sum_{i=1}^{n} x_i μ=n1i=1∑nxi - 方差:描述数据的离散程度,单位是原数据单位的平方。
σ 2 = 1 n ∑ i = 1 n ( x i − μ ) 2 \sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 σ2=n1i=1∑n(xi−μ)2 - 标准差:描述数据的离散程度,单位与原数据相同。
σ = 1 n ∑ i = 1 n ( x i − μ ) 2 \sigma = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2} σ=n1i=1∑n(xi−μ)2 - 均方根:描述数据的总体水平,常用于信号处理等领域。
RMS = 1 n ∑ i = 1 n x i 2 \text{RMS} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2} RMS=n1i=1∑nxi2
数学公式
假设输入张量 x \mathbf{x} x 的形状为 ( N , D ) (N, D) (N,D),其中 N N N 是批量大小, D D D是特征维度。RMSNorm 的计算步骤如下:
1. 求均方根 : RMS ( x ) = 1 D ∑ i = 1 D x i 2 2. 归一化输入 : y = x RMS ( x ) + ϵ 3. 缩放 : z = γ ⋅ y , γ 是一个可学习的参数 . \begin{aligned} &1.求均方根: \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{D} \sum_{i=1}^{D} x_i^2} \\ &2. 归一化输入: \mathbf{y} = \frac{\mathbf{x}}{\text{RMS}(\mathbf{x}) + \epsilon} \\ &3. 缩放: \mathbf{z} = \gamma \cdot \mathbf{y},\ \gamma 是一个可学习的参数. \end{aligned} 1.求均方根:RMS(x)=D1i=1∑Dxi22.归一化输入:y=RMS(x)+ϵx3.缩放:z=γ⋅y, γ是一个可学习的参数.
代码实现
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super(RMSNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
variance = torch.var(x, dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * x
总结
- RMSNorm 的数学公式:
- 计算均方根(RMS)。
- 归一化输入。
- 缩放归一化后的输出。
RMSNorm 相对于传统的 Layer Normalization 更加稳定,尤其是在长序列处理中表现更优。
更多推荐

所有评论(0)