动手实现一个LLaMA2大模型
动手实现一个LLaMA2大模型
1、定义超参数
from transformers import PretrainedConfig
class ModelConfig(PretrainedConfig):
model_type = "Tiny-K"
def __init__(
self,
dim: int = 768, # 模型维度
n_layers: int = 12, # Transformer的层数
n_heads: int = 16, # 注意力机制的头数
n_kv_heads: int = 8, # 键值头的数量
vocab_size: int = 6144, # 词汇表大小
hidden_dim: int = None, # 隐藏层维度
multiple_of: int = 64,
norm_eps: float = 1e-5, # 归一化层的eps
max_seq_len: int = 512, # 最大序列长度
dropout: float = 0.0, # dropout概率
flash_attn: bool = True, # 是否使用Flash Attention
**kwargs,
):
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.max_seq_len = max_seq_len
self.dropout = dropout
self.flash_attn = flash_attn
super().__init__(**kwargs)
2、构建归一化函数 RMSNorm
RMSNorm (Root Mean Square Layer Normalization) 是由 Biao Zhang 和 Rico Sennrich 在 2019 年提出的一种归一化机制。它是常用的 LayerNorm 的变体。
2.1 核心思想
传统的 LayerNorm 会计算神经元输出的均值和方差,并进行平移和缩放。
RMSNorm 做了一个大胆的简化:它认为 LayerNorm 的成功主要源于重缩放(Rescaling),而**平移不变性(均值中心化)**并不是必须的。
因此,RMSNorm 只计算均值的平方根(RMS),而不减去均值。
2.2 与 LayerNorm 的区别
- LayerNorm: 需要计算均值和方差 ,计算量稍大。
- RMSNorm: 只计算平方和的平均值的平方根。它省去了减去均值的步骤,计算更简单,速度更快(通常能提升 10%~40% 的效率),且效果往往与 LayerNorm 相当。
RMS ( x ) = 1 n ∑ i = 1 n x i 2 + ϵ x ˉ = x RMS ( x ) y = γ ⊙ x ˉ \begin{aligned} \text{RMS}(x) &= \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon} \\ \bar{x} &= \frac{x}{\text{RMS}(x)} \\ y &= \gamma \odot \bar{x} \end{aligned} RMS(x)xˉy=n1i=1∑nxi2+ϵ=RMS(x)x=γ⊙xˉ
2.3 代码段
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
# eps是为了防止除以0的情况
self.eps = eps
# weight是一个可学习的参数,全部初始化为1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算RMSNorm的核心部分
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
# torch.rsqrt是平方根的倒数,这样就得到了RMSNorm的分母部分,再加上eps防止分母为0
# 最后乘以x,得到RMSNorm的结果
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# forward函数是模型的前向传播
# 首先将输入x转为float类型,然后进行RMSNorm,最后再转回原来的数据类型
# 最后乘以weight,这是RMSNorm的一个可学习的缩放因子
output = self._norm(x.float()).type_as(x)
return output * self.weight
3、构建 LLaMA2 Attention
我们采用分组查询注意力机制(Grouped-Query Attention,GQA)
3.1 GQA
**分组查询注意力(Grouped-Query Attention,GQA)**它是为了平衡 全多头注意力(MHA) 的高性能和 多查询注意力(MQA) 的低推理成本而设计的。
3.1.1 核心设计思想
为了理解 GQA,我们将注意力机制中的三个核心组件(Query, Key, Value)的“头数”进行对比:
- Multi-Head Attention (MHA): 每个查询头(Query)都有对应的一组 Key 和 Value。虽然效果好,但推理时 KV 缓存(KV Cache)非常占显存。
- Multi-Query Attention (MQA): 所有的查询头共享同一组 Key 和 Value。虽然显著减少了显存占用和推理延迟,但模型表达能力下降,容易导致性能受损。
- Grouped-Query Attention (GQA): 它采取了“折中方案”。将查询头分成若干组,每一组查询头共享一组 Key 和 Value。
3.1.2 GQA 的优势
- 显存开销更低:相比 MHA,由于 Key 和 Value 的头数减少了,KV 缓存占用的显存大幅下降(通常只有 MHA 的 1/4 或 1/8)。
- 推理速度更快:减少了数据加载(IO)的开销。在自回归生成(解码阶段),由于需要频繁搬运 KV 缓存,KV 头数越少,模型响应越快。
- 效果不打折:GQA 的表现非常接近 MHA,远远优于 MQA,成功在速度和效果之间找到了“甜点”。
3.2 扩展键值维度
在 LLaMA2 模型中,我们需要将键和值的维度扩展到和查询的维度一样,这样才能进行注意力计算
3.2.1原理
为什么需要扩展键值维度?
Attention的硬约束:
Q.shape = [B, T, n_heads, d]
K.shape = [B, T, n_heads, d]
V.shape = [B, T, n_heads, d]
Q、K、V三个头的数量必须一致。
FlashAttention / 标准 Attention 代码假设 Q、K、V 的 head 数一致,而GQA中Q头多,KV头少,实际上repeat_kv是将GQA翻译成MHA的形状。
repeat_kv:为了 GQA,把每个 KV 头“复制”给多个 Q 头使用,使注意力计算形状对齐。
3.2.2 repeat_kv 实现三部曲
3.2.2.1 添加新维度
形状变化:
[B, T, n_kv_heads, d]
→
[B, T, n_kv_heads, 1, d]
这个新维度表达的含义是:一个 KV 头,要被N个 Q 头共享
3.2.2.2 expand
新维度初始值为1我们需要将上述N赋值,通过expand操作实现
[B, T, n_kv_heads, 1, d]
→
[B, T, n_kv_heads, N, d]
3.2.2.3 reshape
将新维度合并到n_kv_heads中去
[B, T, n_kv_heads, N, d]
→
[B, T, n_kv_heads * N, d]
n_heads = n_kv_heads * N
为什么需要合并?
因为后续的attention kernel 只认识**[B, T, n_heads, d]**,因此我们必须对下游算子做接口适配。
3.2.3 代码段
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
bs, slen, n_kv_heads, head_dim = x.shape
# 如果重复次数为1,则不需要重复,直接返回原始张量
if n_rep == 1:
return x
# 对张量进行扩展和重塑操作以重复键值对
return (
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
)
3.3 RoPE 旋转嵌入位置编码
3.3.1 什么是RoPE
旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为注意力机制提供更强的上下文信息,从而提高模型的性能。RoPE在功能上等价于 positional encoding
但它不是把“位置向量加到embedding”,而是通过对Q/K的旋转,把"相对位置信息"编码尽注意力分数本身。
在LLaMA/Qwen/DeepSeek/Yi/Beichuan里:RoPE完全取代了传统positional embedding
RoPE的主要思想:位置不是一个向量,而是一个旋转角度
3.3.2 具体实现
3.3.2.1 分组
将head_dim(每一个attention head 自己看到的向量维度) 分组,每组两个元素。
假设 head_dim = 8:
qi=(q0,q1,q2,q3,q4,q5,q6,q7)
RoPE 认为:
- (q0,q1) 是一组
- (q2,q3) 是一组
- …
每一组 2 维 = 平面坐标
3.2.2.2 对每一组做旋转:
对于某一组 (x,y),在位置 i:
( x ′ y ′ ) = ( cos θ i − sin θ i sin θ i cos θ i ) ( x y ) \begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos \theta_i & -\sin \theta_i \\ \sin \theta_i & \cos \theta_i \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix} (x′y′)=(cosθisinθi−sinθicosθi)(xy)
这就是二维平面旋转。
其中:
θ i , m = i ⋅ ω m \theta_{i,m} = i \cdot \omega_m θi,m=i⋅ωm
-
i:token位置
-
m:第m组维度
-
ω m = 10000 − 2 m / d \omega_m = 10000^{-2m/d} ωm=10000−2m/d
3.2.2.3 用复数写
把一个二维向量看成复数:
x + i y x + iy x+iy
旋转 = 乘以一个单位复数:
( x + i y ) ⋅ e i θ (x + iy) \cdot e^{i\theta} (x+iy)⋅eiθ
于是:
q ~ i = q i ⊙ e i θ i k ~ j = k j ⊙ e i θ j \begin{aligned} \tilde{q}_i &= q_i \odot e^{i\theta_i} \\ \tilde{k}_j &= k_j \odot e^{i\theta_j} \end{aligned} q~ik~j=qi⊙eiθi=kj⊙eiθj
3.2.2.4 总结
Attention 的核心是内积:
q ~ i ⊤ k ~ j \tilde{q}_i^\top \tilde{k}_j q~i⊤k~j
带入旋转后的形式:
q ~ i ⊤ k ~ j ∝ q i ⊤ k j ⋅ e i ( θ i − θ j ) \tilde{q}_i^\top \tilde{k}_j \propto q_i^\top k_j \cdot e^{i(\theta_i-\theta_j)} q~i⊤k~j∝qi⊤kj⋅ei(θi−θj)
只剩下
i − j i - j i−j
即两个token的相对位置。
3.3.3 代码段
precompute_freqs_cis:负责提前算好每个token位置i、每个维度对,要转多少角度
reshape_for_broadcast:广播freqs_cos到batch和head维度,实现x * freqs_cos
apply_rotary_emb:真正实现旋转操作
# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半
# 然后每个元素除以dim,再取theta的倒数,得到频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成一个从0到end的序列,长度为end
t = torch.arange(end, device=freqs.device)
# 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素
freqs = torch.outer(t, freqs).float()
# 计算频率的余弦值,得到实部
freqs_cos = torch.cos(freqs)
# 计算频率的正弦值,得到虚部
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# 获取x的维度数
ndim = x.ndim
# 断言,确保1在x的维度范围内
assert 0 <= 1 < ndim
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 将freqs_cis调整为新的形状,并返回
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# 重新塑形频率张量以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# 应用旋转,分别计算旋转后的实部和虚部
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
3.4 组装 LLaMA2 Attention
3.4.1 具体流程
-
输入
x: [bsz, seqlen, dim] -
线性投影 -> Q/K/V
xq = self.wq(x) xk = self.wk(x) xv = self.wv(x) xq: [bsz, seqlen, n_heads * head_dim] xk: [bsz, seqlen, n_kv_heads * head_dim] xv: [bsz, seqlen, n_kv_heads * head_dim] -
reshape 变成多头形式
xq = xq.view(bsz, seqlen, n_heads, head_dim) xk = xk.view(bsz, seqlen, n_kv_heads, head_dim) xv = xv.view(bsz, seqlen, n_kv_heads, head_dim) xq: [bsz, seqlen, n_heads, head_dim] xk: [bsz, seqlen, n_kv_heads, head_dim] xv: [bsz, seqlen, n_kv_heads, head_dim] -
应用RoPE
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) -
repeat_kv
xk = repeat_kv(xk, self.n_rep) xv = repeat_kv(xv, self.n_rep) -> xk, xv: [bsz, seqlen, n_heads, head_dim] -
transpose, 把head提到batch维
让每个head的token成为attention的计算单元,可以想象成一个seq里面的各个token对应位置的head作为一条流水线。attention是对同一个head里面的token序列做两两相似度,我们要让每个head像一个“独立的小batch”一样跑attention。
xq = xq.transpose(1, 2) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) -> xq,xk,xv: [bsz, n_heas, seqlen, head_dim]Linear → view → transpose 这是深度学习里面的一个非常固定的工程套路 -
做Attention计算
Attention ( Q , K , V ) = softmax ( f R o P E ( Q ) f R o P E ( K ) T d k + M ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{f_{RoPE}(Q)f_{RoPE}(K)^T}{\sqrt{d_k}} + M\right)V Attention(Q,K,V)=softmax(dkfRoPE(Q)fRoPE(K)T+M)V -
合并多头
output = output.transpose(1, 2).view(bsz, seqlen, -1) -> [bsz, seqlen, n_heads * head_dim] -
投影回残差流
output = self.wo(output) output = self.resid_dropout(output) -> [bsz, seqlen, dim] -
总结
Linear(Q/K/V) → reshape heads → RoPE(Q, K) → repeat_kv (GQA) → causal attention → concat heads → Linear
3.3.2代码段
class Attention(nn.Module):
def __init__(self, args: ModelConfig):
super().__init__()
# 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 确保总头数可以被键值头数整除。
assert args.n_heads % self.n_kv_heads == 0
# 模型并行处理大小,默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。
self.n_local_heads = args.n_heads // model_parallel_size
# 本地键值头数,等于键值头数除以模型并行处理大小。
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
# 重复次数,用于扩展键和值的尺寸。
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
# 定义权重矩阵。
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
# 输出权重矩阵。
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# 定义dropout。
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
# 保存dropout概率。
self.dropout = args.dropout
# 检查是否使用Flash Attention(需要PyTorch >= 2.0)。
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
# 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# 创建一个上三角矩阵,用于遮蔽未来信息。
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = x.shape
# 计算查询(Q)、键(K)、值(V)。
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 调整形状以适应头的维度。
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 应用旋转位置嵌入(RoPE)。
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# 对键和值进行扩展以适应重复次数。
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# 将头作为批次维度处理。
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# 根据是否支持Flash Attention,选择实现方式。
if self.flash:
# 使用Flash Attention。
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# 使用手动实现的注意力机制。
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
4、MLP层
LLaMA2采用的不是“普通FFN”,而是一个门控MLP (SwiGLU)
4.1 传统Transformer FNN
对每个token的表示做一次非线性重映射 (逐token的MLP)
FFN(x) = W2(GELU(W1(x)))
x → Linear(dim → hidden)
→ 添加非线性表达
→ Linear(hidden → dim)
4.2 SwiGLU
SwiGLU = 带SiLU 激活的门控线性单元,用门控机制对高维特征进行条件化选择,在更少的参数下实现更强的表达能力。
S w i G L U ( x ) = W 2 ( S i L U ( W 1 x ) ⊙ ( W 3 x ) ) SwiGLU(x)=W2(SiLU(W1x)⊙(W3x)) SwiGLU(x)=W2(SiLU(W1x)⊙(W3x))
大致分为三步运算:
-
内容分支
v = F.silu(self.w1(x))- w1: dim -> hidden_dim
- SiLU 激活
-
门控分支
g = self.w3(x)- w3:dim -> hidden_dim
- 负责控制每一维特征“开多少门”
-
门控融合
h = v * g- 每一维都被单独调制
- “这一维重要就放大,不重要就缩小”
4.2 代码段
class MLP(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
# 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍
# 然后将其减少到2/3,最后确保它是multiple_of的倍数
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义第三层线性变换,从输入维度到隐藏维度
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
# 定义dropout层,用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先,输入x通过第一层线性变换和SILU激活函数
# 然后,结果乘以输入x通过第三层线性变化的结果
# 最后,通过第二层线性变换和dropout层
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
5、Decoder Layer
- DecoderLayer= Attention + MLP + 两次归一化 + 两次残差连接
- 他的作用是先让token之间“互相看”(Attention),再让每个token自己“深度思考”(MLP)
代码段
class DecoderLayer(nn.Module):
def __init__(self, layer_id: int, args: ModelConfig):
super().__init__()
# 定义多头注意力的头数
self.n_heads = args.n_heads
# 定义输入维度
self.dim = args.dim
# 定义每个头的维度,等于输入维度除以头数
self.head_dim = args.dim // args.n_heads
# 定义LLaMA2Attention对象,用于进行多头注意力计算
self.attention = Attention(args)
# 定义LLaMAMLP对象,用于进行前馈神经网络计算
self.feed_forward = MLP(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
# 定义层的ID
self.layer_id = layer_id
# 定义注意力计算的归一化层
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
# 定义前馈神经网络计算的归一化层
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin):
# 前向传播函数
# 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h
# 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
6、 构建LLaMA2模型
6.1 组装流程
- 把token -> 向量 (Embedding)
- 把DecoderLayer 堆N层
- 用RoPE 提供位置信息
- 把隐藏状态 -> 词表概率 (LM Head)
- 同时支持训练 (算loss)和推理 (生成)
6.2 构建细节
6.2.1 Weight Tying
让“理解一个词” 和 “如何预测一个词” 使用同一套参数
减少参数 + 提升泛化 + 语言建模更加合理
self.tok_embeddings.weight = self.output.weight
6.2.2 自回归生成 (autoregressive generation)
从已有的token序列idx出发,每次预测下一个token,再把它拼回去,循环生成样本。本质是循环调用forward的推理分支。
def generate(self, idx, stop_id=None, max_new_tokens=256, temperature=1.0, top_k=None):
6.2.2.1 参数解释
| 参数 | 含义 |
|---|---|
idx |
初始 token 序列,shape = (B, T) |
stop_id |
遇到这个 token 就停止生成(如 EOS) |
max_new_tokens |
最多生成多少个新 token |
temperature |
控制随机性 |
top_k |
只在概率最高的 k 个 token 里采样 |
6.2.2.2 代码解读
- 记录原始长度,方便后面截断,最终返回时,只返回新生成的部分
index = idx.shape[1]
- 生成循环:每次生成一个token
for _ in range(max_new_tokens):
- 上下文裁剪:Transformer不能无限长,超过max_seq_len 时:只保留最后max_seq_len个token,类似于滑动窗口。因为自回归模型只关心最近的上下文。
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
- 前向传播(计算logits),只取最后一个位置,因为我们只需要下一个位置的token,前面的logits已经没用了。
logits = self(idx_cond).logits # 调用forward中的推理分支
logits = logits[:, -1, :]
logits shape = (B, T, vocab_size)
logits[b, t] 是一个长度为 vocab_size 的向量
在第 b 个样本、第 t 个位置,下一个 token 是“哪一个词”的可能性。
-
两种生成策略
- temperature=0(贪心解码):直接选概率最大的token
if temperature == 0.0: _, idx_next = torch.topk(logits, k=1, dim=-1)- temperature>0(随机采样)
logits = logits / temperature -
Top-K截断,防止胡说,只允许概率最高的k个token参与采样。
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
- Softmax采用
probs = F.softmax(logits, dim=-1) # 将logits归一化为概率,原来的logits为对每个词的可能性进行打分可以为任意实数。
idx_next = torch.multinomial(probs, num_samples=1)
- 拼接token(自回归核心)
idx = torch.cat((idx, idx_next), dim=1)
x₁ x₂ x₃ → 预测 x₄
x₁ x₂ x₃ x₄ → 预测 x₅
6.3 代码段
class Transformer(PreTrainedModel):
config_class = ModelConfig # 配置类
last_loss: Optional[torch.Tensor] # 记录最后一次计算的损失
def __init__(self, args: ModelConfig = None):
super().__init__(args)
# 初始化模型参数
self.args = args
# 词汇表大小
self.vocab_size = args.vocab_size
# 层数
self.n_layers = args.n_layers
# 词嵌入层
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
# Dropout层
self.dropout = nn.Dropout(args.dropout)
# Decoder层
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(DecoderLayer(layer_id, args))
# 归一化层
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
# 输出层
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
# 将词嵌入层的权重与输出层的权重共享
self.tok_embeddings.weight = self.output.weight
# 预计算相对位置嵌入的频率
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
# 初始化所有权重
self.apply(self._init_weights)
# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
# 初始化最后一次前向传播的损失属性
self.last_loss = None
self.OUT = CausalLMOutputWithPast() # 输出容器
self._no_split_modules = [name for name, _ in self.named_modules()] # 不分割的模块列表
def _init_weights(self, module):
# 初始化权重的函数
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
"""
- tokens: Optional[torch.Tensor], 输入 token 张量。
- targets: Optional[torch.Tensor], 目标 token 张量。
- kv_cache: bool, 是否使用键值缓存。
- kwargs: 其他关键字参数。
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
"""
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'attention_mask' in kwargs:
targets = kwargs['attention_mask']
# 前向传播函数
_bsz, seqlen = tokens.shape
# 通过词嵌入层和Dropout层
h = self.tok_embeddings(tokens)
h = self.dropout(h)
# 获取相对位置嵌入的频率
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
# 通过Decoder层
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)
# 通过归一化层
h = self.norm(h)
if targets is not None:
# 如果给定了目标,计算损失
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0, reduction='none')
else:
# 推理时的小优化:只对最后一个位置的输出进行前向传播
logits = self.output(h[:, [-1], :])
self.last_loss = None
# 设置输出
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', self.last_loss)
return self.OUT
@torch.inference_mode()
def generate(self, idx, stop_id=None, max_new_tokens=256, temperature=1.0, top_k=None):
"""
给定输入序列 idx(形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
在 model.eval() 模式下运行。效率较低的采样版本,没有使用键k/v cache。
"""
index = idx.shape[1]
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond).logits
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
if temperature == 0.0:
# 选择最有可能的索引
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# 缩放 logits 并应用 softmax
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
if idx_next == stop_id:
break
# 将采样的索引添加到序列中并继续
idx = torch.cat((idx, idx_next), dim=1)
return idx[:, index:] # 只返回生成的token
更多推荐


所有评论(0)