GPT-2 模型核心代码实现

以下代码基于 PyTorch 框架,实现 GPT-2 的核心组件,包括多头注意力机制、前馈网络和 Transformer 块。代码严格遵循原始论文设计,并添加了关键注释。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.qkv_proj = nn.Linear(embed_dim, 3*embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        
        output = torch.matmul(attn_probs, v)
        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)

前馈网络实现

GPT-2 使用两层全连接网络,中间通过 GELU 激活函数:

class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )
        
    def forward(self, x):
        return self.net(x)

Transformer 块集成

将多头注意力和前馈网络组合成完整的 Transformer 块:

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = FeedForward(embed_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        
    def forward(self, x, mask=None):
        attn_out = self.attention(self.ln1(x), mask)
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        x = x + ffn_out
        return x

关键性能优化技术

  1. 缩放点积注意力:通过 $\sqrt{d_k}$ 缩放因子防止梯度消失: $$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

  2. 残差连接:每个子层后保留原始输入路径,缓解梯度消失问题

  3. 层归一化:对每个 Transformer 块的输入进行归一化处理,加速训练收敛

  4. GELU 激活:相比 ReLU 提供更平滑的梯度特性: $$ \text{GELU}(x) = x\Phi(x) $$ 其中 $\Phi$ 是标准正态分布累积函数

完整模型调用示例

model = TransformerBlock(
    embed_dim=768,  # GPT-2 small 的隐藏层维度
    num_heads=12,   # 注意力头数量
    hidden_dim=3072 # 前馈网络中间层维度
)
input_tensor = torch.randn(1, 256, 768)  # (batch, seq_len, embed_dim)
output = model(input_tensor)

GPT-2的核心架构与设计原理

GPT-2基于Transformer的解码器结构,采用自回归机制生成文本。模型的核心是多层堆叠的Transformer块,每个块包含多头自注意力机制和前馈神经网络。自注意力机制通过计算词与词之间的相关性权重,动态捕捉上下文依赖关系。GPT-2的参数量从1.5亿(小型版本)到15亿(完整版本)不等,巨大的容量使其能够学习复杂的语言模式。

关键设计包括:

  • 单向注意力掩码:仅允许当前词关注左侧历史词,确保生成方向性。
  • 位置编码:通过正弦函数注入位置信息,替代传统RNN的时序处理。
  • 层归一化前置:将归一化置于残差连接前,提升训练稳定性。

训练数据与预训练策略

GPT-2的性能优势源于其数据规模和多样性。训练数据来自Reddit出站链接的800万网页(WebText数据集),覆盖技术文档、新闻、小说等多领域文本。数据清洗时保留原始格式(如代码缩进),避免过度规范化导致的语义损失。

预训练采用任务无关的极大似然目标: [ \mathcal{L} = -\sum_{t=1}^T \log P(x_t | x_{<t}) ] 通过预测下一个词的任务,模型隐式学习语法、事实知识和推理能力。训练使用256个TPU v3核心,批处理大小达512,采用Adam优化器(学习率2.5e-4,余弦衰减)。

上下文长度与零样本迁移能力

GPT-2将上下文窗口扩展至1024词,远超GPT的512词。长上下文支持更复杂的连贯性生成,例如保持多段落故事的情节一致性。零样本能力的关键在于:

  • 多任务隐式学习:预训练数据包含问答、翻译等任务的自然文本描述,使模型在未显式训练的情况下泛化。
  • 概率链式分解:自回归生成天然适配开放域任务,如通过条件概率( P(\text{答案}|\text{问题}) )完成问答。

生成策略与可控性技术

GPT-2的生成质量依赖以下技术:

  • Top-k采样:从概率最高的k个候选词中随机选择,平衡多样性与合理性(默认k=40)。
  • 温度系数τ:调整softmax输出分布,高温增加多样性,低温提高确定性: [ P(x) = \frac{\exp(z_i/\tau)}{\sum_j \exp(z_j/\tau)} ]
  • 重复惩罚:抑制已生成词的概率,避免循环输出。

局限性与后续改进方向

GPT-2的局限性包括事实性错误和长程依赖丢失。后续模型(如GPT-3)通过以下方向改进:

  • 规模扩展至千亿参数,增强记忆容量。
  • 引入Few-shot提示工程,显式利用上下文示例。
  • 结合强化学习微调(如RLHF),对齐人类偏好。

代码示例(PyTorch风格伪代码):

# GPT-2生成过程简化示例
def generate(input_ids, max_length):
    for _ in range(max_length):
        logits = model(input_ids)  # 前向传播
        logits = logits[:, -1, :] / temperature
        probs = top_k_filtering(logits, k=40)
        next_token = torch.multinomial(probs, 1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐