从旋转位置编码RoPE到YaRN的原理与实现
旋转位置编码(RoPE)是当前大模型的主流位置编码方法,相比传统正余弦位置编码具有更好的外推能力。RoPE通过旋转矩阵将位置信息嵌入到查询和键向量中,能显式保留相对位置关系。其核心思想是将位置编码转化为旋转操作,使注意力分数仅依赖相对位置差而非绝对位置。实现时,先计算分层频率并生成旋转角度矩阵,然后对输入向量进行旋转变换。RoPE支持序列长度的平滑扩展,解决了长上下文建模的难题,已成为LLM的标准
旋转位置编码
旋转位置编码公式以及与正余弦位置编码对比
正余弦位置编码是原始 Transformer 中使用的标准方法,而旋转位置编码(RoPE)是当前大模型的主流选择,能够通过旋转角度编码位置信息,从而支持序列长度的平滑外推,解决了正余弦位置编码在超长上下文中难以外推的缺点。
既能适应序列长度的不同,在序列长度不同时相对距离要一致
正余弦位置编码与旋转位置编码:
两者都利用了旋转矩阵能得到的平移性,即乘上一个系数等得到增加一个变换量后的结果,只是正余弦位置编码是加上了位置项,使得QKT运算后多出了一个加上的相对位置项,而旋转位置编码是乘上旋转矩阵,使得一个是多出了QKT运算后多出了一个乘上的相对位置系数
共同特征是随着嵌入维度增加,频率逐渐减小,周期变长,从而关注更远的距离。
- 序列维度增加 → 向前旋转相同角度
- 嵌入维度增加 → 旋转角度更小,周期更长
正弦余弦编码是通过加法引入绝对位置,并依靠模型自身从复杂的注意力计算中隐式地学习出相对位置关系。它利用了正弦余弦函数的和角公式。
旋转位置编码是通过乘法(旋转)引入绝对位置,并利用旋转矩阵的数学性质,显式地、天然地在注意力分数中仅体现出相对位置关系。它利用了旋转矩阵的复合性(RmTRn=Rn−mR_m^T R_n = R_{n-m}RmTRn=Rn−m)。
公式:
对于位置 ttt 与维度 2i2i2i、2i+12i+12i+1:
PE(t,2i)=sin(t100002i/d),PE(t,2i+1)=cos(t100002i/d) PE_{(t,2i)} = \sin\left(\frac{t}{10000^{2i/d}}\right), \quad PE_{(t,2i+1)} = \cos\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i)=sin(100002i/dt),PE(t,2i+1)=cos(100002i/dt)
平移性质:
[sin(Δt+t)cos(Δt+t)]=[cos(Δt)−sin(Δt)sin(Δt)cos(Δt)][sin(t)cos(t)] \begin{bmatrix} \sin(\Delta t + t) \\ \cos(\Delta t + t) \end{bmatrix} = \begin{bmatrix} \cos(\Delta t) & -\sin(\Delta t) \\ \sin(\Delta t) & \cos(\Delta t) \end{bmatrix} \begin{bmatrix} \sin(t) \\ \cos(t) \end{bmatrix} [sin(Δt+t)cos(Δt+t)]=[cos(Δt)sin(Δt)−sin(Δt)cos(Δt)][sin(t)cos(t)]
公式:
θi=10000−2i/d\theta_i = 10000^{-2i/d}θi=10000−2i/d
旋转位置编码(RoPE)
RΘ,mdx=(x0x1x2x3⋮xd−2xd−1)⊗(cosmθ0cosmθ0cosmθ1cosmθ1⋮cosmθd/2−1cosmθd/2−1)+(−x1x0−x3x2⋮−xd−1xd−2)⊗(sinmθ0sinmθ0sinmθ1sinmθ1⋮sinmθd/2−1sinmθd/2−1)R^d_{\Theta,m} \mathbf{x} = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix}RΘ,mdx=
x0x1x2x3⋮xd−2xd−1
⊗
cosmθ0cosmθ0cosmθ1cosmθ1⋮cosmθd/2−1cosmθd/2−1
+
−x1x0−x3x2⋮−xd−1xd−2
⊗
sinmθ0sinmθ0sinmθ1sinmθ1⋮sinmθd/2−1sinmθd/2−1
正余弦位置编码(Sinusoidal PE)
xembedded=(x0x1x2x3⋮xd−2xd−1)+(sin(mθ0)cos(mθ0)sin(mθ1)cos(mθ1)⋮sin(mθd/2−1)cos(mθd/2−1))\mathbf{x}_{embedded} = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} + \begin{pmatrix} \sin(m\theta_0) \\ \cos(m\theta_0) \\ \sin(m\theta_1) \\ \cos(m\theta_1) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix}xembedded= x0x1x2x3⋮xd−2xd−1 + sin(mθ0)cos(mθ0)sin(mθ1)cos(mθ1)⋮sin(mθd/2−1)cos(mθd/2−1)
相对位置信息推导
RoPE 中的相对位置
设查询向量在位置 mmm,键向量在位置 nnn:
qm=RΘ,mdxm,kn=RΘ,ndxn\mathbf{q}_m = R^d_{\Theta,m} \mathbf{x}_m, \quad \mathbf{k}_n = R^d_{\Theta,n} \mathbf{x}_nqm=RΘ,mdxm,kn=RΘ,ndxn
注意力权重:
qmTkn=xmT(RΘ,md)TRΘ,ndxn\mathbf{q}_m^T \mathbf{k}_n = \mathbf{x}_m^T (R^d_{\Theta,m})^T R^d_{\Theta,n} \mathbf{x}_nqmTkn=xmT(RΘ,md)TRΘ,ndxn
由于旋转矩阵性质:(RΘ,m)TRΘ,n=RΘ,(n−m)(R_{\Theta,m})^T R_{\Theta,n} = R_{\Theta,(n-m)}(RΘ,m)TRΘ,n=RΘ,(n−m)
qmTkn=xmTRΘ,(n−m)dxn\mathbf{q}_m^T \mathbf{k}_n = \mathbf{x}_m^T R^d_{\Theta,(n-m)} \mathbf{x}_nqmTkn=xmTRΘ,(n−m)dxn
结果:直接得到相对位置 (n−m)(n-m)(n−m) 的旋转变换
Sinusoidal PE 中的相对位置
设嵌入后的向量:xm′=xm+PE(m)\mathbf{x}'_m = \mathbf{x}_m + \mathbf{PE}(m)xm′=xm+PE(m),xn′=xn+PE(n)\mathbf{x}'_n = \mathbf{x}_n + \mathbf{PE}(n)xn′=xn+PE(n)
注意力权重:
(xm′)Txn′=(xm+PE(m))T(xn+PE(n))(\mathbf{x}'_m)^T \mathbf{x}'_n = (\mathbf{x}_m + \mathbf{PE}(m))^T (\mathbf{x}_n + \mathbf{PE}(n))(xm′)Txn′=(xm+PE(m))T(xn+PE(n))
=xmTxn+xmTPE(n)+PE(m)Txn+PE(m)TPE(n)= \mathbf{x}_m^T \mathbf{x}_n + \mathbf{x}_m^T \mathbf{PE}(n) + \mathbf{PE}(m)^T \mathbf{x}_n + \mathbf{PE}(m)^T \mathbf{PE}(n)=xmTxn+xmTPE(n)+PE(m)Txn+PE(m)TPE(n)
其中关键项:
PE(m)TPE(n)=∑i=0d/2−1[sin(mθi)sin(nθi)+cos(mθi)cos(nθi)]\mathbf{PE}(m)^T \mathbf{PE}(n) = \sum_{i=0}^{d/2-1} [\sin(m\theta_i)\sin(n\theta_i) + \cos(m\theta_i)\cos(n\theta_i)]PE(m)TPE(n)=∑i=0d/2−1[sin(mθi)sin(nθi)+cos(mθi)cos(nθi)]
=∑i=0d/2−1cos((m−n)θi)= \sum_{i=0}^{d/2-1} \cos((m-n)\theta_i)=∑i=0d/2−1cos((m−n)θi)
结果:通过三角恒等式间接得到相对位置 (m−n)(m-n)(m−n) 信息
旋转位置编码实现流程说明与代码
输出:(batch_size, seq_len, head_dim) 的旋转位置编码,用于 Query/Key 向量旋转。
1. 计算频率倒数
inv_freq[k]=1base2k/dim,k=0,1,…,dim/2−1 \text{inv\_freq}[k] = \frac{1}{\text{base}^{2k/dim}}, \quad k=0,1,\dots,dim/2-1 inv_freq[k]=base2k/dim1,k=0,1,…,dim/2−1
2. 获取位置信息
- 默认连续
[0,1,2,...,seq_len-1] - 或自定义
position_ids(支持非连续、批量等)
3. 外积计算角度矩阵
angles=position_ids⊗inv_freq⇒(batch_size,seq_len,dim/2) \text{angles} = \text{position\_ids} \otimes \text{inv\_freq} \quad \Rightarrow \quad (batch\_size, seq\_len, dim/2) angles=position_ids⊗inv_freq⇒(batch_size,seq_len,dim/2)
4. 复制角度并计算 sin/cos
- 方式 A(连接)(通常实现方式):
- 将前半维和后半维分别组成两部分:
x=[x1,x2,x3,x4]⇒前半维=[x1,x2], 后半维=[x3,x4] x = [x_1, x_2, x_3, x_4] \quad \Rightarrow \quad \text{前半维}=[x_1,x_2],\ \text{后半维}=[x_3,x_4] x=[x1,x2,x3,x4]⇒前半维=[x1,x2], 后半维=[x3,x4]
- 旋转操作:
x′=[x1cosθ1−x3sinθ1, x2cosθ2−x4sinθ2, x3cosθ1+x1sinθ1, x4cosθ2+x2sinθ2] x' = [x_1 \cos\theta_1 - x_3 \sin\theta_1, \ x_2 \cos\theta_2 - x_4 \sin\theta_2, \ x_3 \cos\theta_1 + x_1 \sin\theta_1, \ x_4 \cos\theta_2 + x_2 \sin\theta_2] x′=[x1cosθ1−x3sinθ1, x2cosθ2−x4sinθ2, x3cosθ1+x1sinθ1, x4cosθ2+x2sinθ2]
- 对应代码中的
rotate_half是把前后半维交换,并加上符号。
angles_full=cat([angles,angles],dim=-1)⇒(batch,seq_len,dim) \text{angles\_full} = \text{cat}([\text{angles}, \text{angles}], \text{dim=-1}) \quad \Rightarrow (batch, seq\_len, dim) angles_full=cat([angles,angles],dim=-1)⇒(batch,seq_len,dim)
- 方式 B(交错)(严格对应数学公式):
- 将每两个维度交错组成复数对:
x=[x1,x2,x3,x4]⇒[(x1,x2),(x3,x4)] x = [x_1, x_2, x_3, x_4] \quad \Rightarrow \quad [(x_1,x_2), (x_3,x_4)] x=[x1,x2,x3,x4]⇒[(x1,x2),(x3,x4)]
- 旋转操作:
x′=[x1cosθ1−x2sinθ1, x2cosθ1+x1sinθ1, x3cosθ2−x4sinθ2, x4cosθ2+x3sinθ2] x' = [x_1 \cos\theta_1 - x_2 \sin\theta_1, \ x_2 \cos\theta_1 + x_1 \sin\theta_1, \ x_3 \cos\theta_2 - x_4 \sin\theta_2, \ x_4 \cos\theta_2 + x_3 \sin\theta_2] x′=[x1cosθ1−x2sinθ1, x2cosθ1+x1sinθ1, x3cosθ2−x4sinθ2, x4cosθ2+x3sinθ2]
angles_full=interleave([angles,angles]) \text{angles\_full} = \text{interleave}([\text{angles}, \text{angles}]) angles_full=interleave([angles,angles])
- 计算:
cos=cos(angles_full),sin=sin(angles_full) \cos = \cos(\text{angles\_full}), \quad \sin = \sin(\text{angles\_full}) cos=cos(angles_full),sin=sin(angles_full)
5. 应用旋转变换
- 对输入向量 xxx(Q 或 K):
x′=x⋅cos+rotate_half(x)⋅sin x' = x \cdot \cos + \text{rotate\_half}(x) \cdot \sin x′=x⋅cos+rotate_half(x)⋅sin
rotate_half根据 Cat/Interleave 不同方式选择维度交换策略
关键点
- 频率分层:不同维度对使用不同旋转频率
- 相对位置编码:相同相对距离的 token 对具有相同相对角度差
import torch
import torch.nn as nn
def demonstrate_rope_complete_process():
"""演示RoPE的完整流程"""
print("=== 旋转位置编码(RoPE) 完整流程 ===\n")
# 参数设置
seq_len = 4
head_dim = 8
batch_size = 1
print(f"参数: seq_len={seq_len}, head_dim={head_dim}, batch_size={batch_size}")
# ============ 步骤1:计算频率倒数 ============
print("\n步骤1:计算频率倒数")
base = 10000.0
# 频率倒数:1 / (base^(2i/dim)),只需要dim/2个
dim_pairs = torch.arange(0, head_dim, 2, dtype=torch.float32) # [0, 2, 4, 6]
inv_freq = 1.0 / (base ** (dim_pairs / head_dim))
print(f"dim_pairs: {dim_pairs}")
print(f"inv_freq形状: {inv_freq.shape} = {inv_freq}")
# ============ 步骤2:获取位置信息 ============
print("\n步骤2:获取位置信息")
# 方式1:使用seq_len
position_ids = torch.arange(seq_len, dtype=torch.float32).unsqueeze(0) # [1, seq_len]
# 方式2:直接传入position_ids(更灵活)
# position_ids = torch.tensor([[0, 1, 5, 8]], dtype=torch.float32) # 非连续位置
print(f"position_ids形状: {position_ids.shape} = {position_ids}")
# ============ 步骤3:外积计算角度矩阵 ============
print("\n步骤3:外积计算角度矩阵")
# position_ids: [batch_size, seq_len] -> [batch_size, seq_len, 1]
# inv_freq: [dim/2] -> [1, 1, dim/2]
pos_expanded = position_ids.unsqueeze(-1) # [1, 4, 1]
freq_expanded = inv_freq.unsqueeze(0).unsqueeze(0) # [1, 1, 4]
# 外积:[1, 4, 1] * [1, 1, 4] = [1, 4, 4] (广播)
angles = pos_expanded * freq_expanded # [batch_size, seq_len, dim/2]
print(f"angles形状: {angles.shape}")
print(f"angles[0]:\n{angles[0].round(decimals=3)}")
# ============ 步骤4:复制并计算sin/cos ============
print("\n步骤4:复制并计算sin/cos")
# 两种复制方式:
# 方式A:连接复制 (Concatenate) - 更常用
angles_cat = torch.cat([angles, angles], dim=-1) # [1, 4, 8]
cos_cat = angles_cat.cos()
sin_cat = angles_cat.sin()
print(f"连接方式 - cos形状: {cos_cat.shape}")
print(f"cos_cat[0, 0]: {cos_cat[0, 0].round(decimals=3)}")
# 方式B:交错复制 (Interleave)
angles_interleave = torch.stack([angles, angles], dim=-1) # [1, 4, 4, 2]
angles_interleave = angles_interleave.flatten(start_dim=-2) # [1, 4, 8]
cos_interleave = angles_interleave.cos()
sin_interleave = angles_interleave.sin()
print(f"交错方式 - cos形状: {cos_interleave.shape}")
print(f"cos_interleave[0, 0]: {cos_interleave[0, 0].round(decimals=3)}")
# ============ 步骤5:应用旋转变换 ============
print("\n步骤5:应用旋转变换")
# 模拟查询向量Q
q = torch.randn(batch_size, seq_len, head_dim) # [1, 4, 8]
print(f"原始查询向量Q形状: {q.shape}")
print(f"Q[0, 0]: {q[0, 0].round(decimals=3)}")
# 方式A对应的旋转函数
def rotate_half_cat(x):
"""连接方式的旋转:前半部分取负号移到后面"""
x1 = x[..., :x.shape[-1] // 2] # 前半部分
x2 = x[..., x.shape[-1] // 2:] # 后半部分
return torch.cat([-x2, x1], dim=-1) # [-x2, x1]
# 方式B对应的旋转函数
def rotate_half_interleave(x):
"""交错方式的旋转:奇偶位置交换并变号"""
x = x.reshape(*x.shape[:-1], -1, 2) # [..., dim/2, 2]
x_rotated = torch.stack([-x[..., 1], x[..., 0]], dim=-1) # 交换并变号
return x_rotated.flatten(start_dim=-2)
# 应用旋转变换
print("\n连接方式旋转结果:")
q_rotated_cat = rotate_half_cat(q)
q_embed_cat = q * cos_cat + q_rotated_cat * sin_cat
print(f"旋转后Q形状: {q_embed_cat.shape}")
print(f"Q_embed_cat[0, 0]: {q_embed_cat[0, 0].round(decimals=3)}")
print("\n交错方式旋转结果:")
q_rotated_interleave = rotate_half_interleave(q)
q_embed_interleave = q * cos_interleave + q_rotated_interleave * sin_interleave
print(f"旋转后Q形状: {q_embed_interleave.shape}")
print(f"Q_embed_interleave[0, 0]: {q_embed_interleave[0, 0].round(decimals=3)}")
if __name__ == "__main__":
demonstrate_rope_complete_process()
RoPE在Transformers中的使用流程简化版
def rotate_half(x):
"""旋转输入的一半维度"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""应用旋转位置编码"""
# 添加head维度
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def rope_init_fn(config):
dim = config.hidden_size//config.num_attention_heads
inv_freq = 1.0 / (config.base ** (torch.arange(0, dim, 2) / dim))
return inv_freq
class RotaryEmbedding(nn.Module):
def __init__(self, config: ModelConfig, device=None):
super().__init__()
self.inv_freq = rope_init_fn(config)
@torch.no_grad()
def forward(self, position_ids):
"""前向传播:计算cos和sin"""
# 扩展inv_freq和position_ids以便批量计算
# position_ids:(batch_size,seq_len), inv_freq:(dim//2,)
# position_ids_expanded:(batch_size,1,seq_len),inv_freq_expanded:(batch_size,dim//2,1)
inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :]
# 计算频率 * 位置 = 角度
# freqs:(batch_size,seq_len,dim//2)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
# 复制频率以匹配完整的头维度
emb = torch.cat((freqs, freqs), dim=-1) # (batch_size,seq_len)
# 计算cos和sin,并应用注意力缩放
cos = emb.cos()
sin = emb.sin()
return cos, sin
class Attention(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
pass
def forward(self, x):
b, s, h = x.shape # (batch_size, seq_len, hidden_size)
# 得到(batch_size, num_heads, seq_len, head_dim)
query = self.q_proj(x).view(b,s,self.num_heads,self.head_dim).transpose(1,2)
key = self.k_proj(x).view(b,s,self.num_heads,self.head_dim).transpose(1,2)
value = self.v_proj(x).view(b,s,self.num_heads,self.head_dim).transpose(1,2)
# 添加位置编码
position_ids = torch.arange(s).unsqueeze(0)
cos, sin = self.rotary_emb(position_ids)
query, key = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)
...
YaRN - 长上下文外推技术
针对 RoPE 在中高维度外推时可能不稳定的缺点,YaRN 对中高维度进行缩放,同时保留低维度的原始高频,实现了更长上下文的稳定建模。对于长上下文文本,rope仍能处理其中的短距离关系(学习过),但长距离未学习过。Position Interpolation方法是将位置索引拉伸到长上下文范围,而Yarn通过对频率插值来适应长上下文。
对于YaRN,核心为计算出两种频率(外推频率、插值频率),每个维度根据依据权重混合这两种频率,使用渐变的权重让高频主要使用原始外推频率,低频主要使用插值频率。
极简版实现(修改上面的rope_init_fn)
# config.factor为缩放因子代表外推倍数,用于减少低频旋转角度
def rope_init_fn_yarn(config):
dim = config.hidden_size//config.num_attention_heads # 头维度
# 1. 计算基础频率
freqs = config.base ** (torch.arange(0, dim, 2,) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (config.factor * pos_freqs) # 减少低频旋转角度
# 2. 简单渐变混合(维度上线性渐变)
extrapolation_factor = 1 - torch.linspace(0, 1, dim // 2) # (1 -> 0)
# 3. 混合最终频率
inv_freq = inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor
return inv_freq
transformers中的实现核心流程:
- 生成一个随factor进行log增长的放大系数,即当增大上下文区间的同时适当放大注意力分数,会使用到参数mscale(控制缩放幅度)、mscale_all_dim(进行维度归一)
- beta_fast、beta_slow为设定的圈数界限,经验值为32和1。根据圈数界限计算出对应的维度区间
- 依据上下界给每个维度生成0~1之间的分段系数,区间左侧完全使用外推频率(即原始频率),区间右侧完全使用插值频率(即分母成了factor后放慢的频率),区间中线性混合
def _compute_yarn_parameters(config)
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
original_max_position_embeddings = (
config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
)
def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if mscale and mscale_all_dim:
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = config.rope_scaling.get("beta_fast") or 32
beta_slow = config.rope_scaling.get("beta_slow") or 1
# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
"""Find dimension range bounds based on rotations"""
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
truncate = config.rope_scaling.get("truncate", True)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, attention_factor
更多推荐


所有评论(0)