斯坦福大学 | CS336 | 从零开始构建语言模型 | Spring 2025 | 笔记 | Assignment 1: Training Loop
斯坦福大学 | CS336 | 从零开始构建语言模型 | Spring 2025 | 笔记 | Assignment 1: Training Loop
目录
前言
本篇文章记录 CS336 作业 Assignment 1: Basics 中的 Training a Transformer LM、Training loop 以及 Generate text 作业要求,仅供自己参考😄
Assignment 1:https://github.com/stanford-cs336/assignment1-basics
reference:https://chatgpt.com/
1. Training a Transformer LM 作业要求
以下内容均翻译自 cs336_spring2025_assignment1_basics.pdf,请大家查看原文档获取更详细的内容
到目前为止,我们已经完成了 数据预处理(通过分词器 tokenizer)以及 模型本身(Transformer)的实现,接下来还需要编写完整的代码来支持模型训练,主要包括以下几个部分:
- 损失函数(Loss):需要定义训练所使用的损失函数,这里采用 交叉熵损失(cross-entropy)
- 优化器(Optimizer):需要定义用于最小化该损失的优化器,例如 AdamW
- 训练循环(Training loop):需要实现所有支撑训练过程的基础设施,包括数据加载、模型检查点保存以及训练过程的整体管理
1.1 Cross-entropy loss
回顾一下,Transformer 语言模型会为长度为 m + 1 m+1 m+1 的每个序列 x x x 定义一个条件概率分布 p θ ( x i + 1 ∣ x 1 : i ) p_\theta(x_{i+1} \mid x_{1:i}) pθ(xi+1∣x1:i),其中 i = 1 , … , m i = 1, \ldots, m i=1,…,m
给定一个由长度为 m m m 的序列组成的训练集 D D D,我们将标准的 交叉熵损失函数(也即 负对数似然)定义为:
ℓ ( θ ; D ) = 1 ∣ D ∣ m ∑ x ∈ D ∑ i = 1 m − log p θ ( x i + 1 ∣ x 1 : i ) . (16) \ell(\theta; D) = \frac{1}{|D| m} \sum_{x \in D} \sum_{i=1}^{m}- \log p_\theta(x_{i+1} \mid x_{1:i}). \tag{16} ℓ(θ;D)=∣D∣m1x∈D∑i=1∑m−logpθ(xi+1∣x1:i).(16)
需要注意的是,在 Transformer 中,一次前向传播就可以得到所有 i = 1 , … , m i = 1, \ldots, m i=1,…,m 的条件概率 p θ ( x i + 1 ∣ x 1 : i ) p_\theta(x_{i+1} \mid x_{1:i}) pθ(xi+1∣x1:i)
具体来说,Transformer 会在每个位置 i i i 计算一个 logits 向量 o i ∈ R vocab_size o_i \in \mathbb{R}^{\text{vocab\_size}} oi∈Rvocab_size,从而得到:
p ( x i + 1 ∣ x 1 : i ) = softmax ( o i ) [ x i + 1 ] = exp ( o i [ x i + 1 ] ) ∑ a = 1 vocab_size exp ( o i [ a ] ) . (17) p(x_{i+1} \mid x_{1:i}) = \text{softmax}(o_i)[x_{i+1}] = \frac{\exp(o_i[x_{i+1}])} {\sum_{a=1}^{\text{vocab\_size}} \exp(o_i[a])}. \tag{17} p(xi+1∣x1:i)=softmax(oi)[xi+1]=∑a=1vocab_sizeexp(oi[a])exp(oi[xi+1]).(17)
Note:这里的 o i [ k ] o_i[k] oi[k] 表示 logits 向量 o i o_i oi 中索引为 k k k 的元素
一般来说,交叉熵损失是针对 logits 向量 o i ∈ R vocab_size o_i \in \mathbb{R}^{\text{vocab\_size}} oi∈Rvocab_size 以及目标词 x i + 1 x_{i+1} xi+1 来定义的
Note:该表达式对应于目标词 x i + 1 x_{i+1} xi+1 的狄拉克 delta 分布与模型预测的 softmax ( o i ) \text{softmax}(o_i) softmax(oi) 分布之间的交叉熵
在实现交叉熵损失时,需要像实现 softmax 一样,特别注意 数值稳定性问题
交叉熵损失已经足以用于训练模型,但在评估模型性能时,我们通常还希望报告 困惑度(perplexity)。对于一个长度为 m m m 的序列,如果我们在各个位置上得到的交叉熵损失为 ℓ 1 , … , ℓ m \ell_1, \ldots, \ell_m ℓ1,…,ℓm,则困惑度定义为:
perplexity = exp ( 1 m ∑ i = 1 m ℓ i ) . (18) \text{perplexity} = \exp\left( \frac{1}{m} \sum_{i=1}^{m} \ell_i \right). \tag{18} perplexity=exp(m1i=1∑mℓi).(18)
Problem (cross_entropy): Implement Cross entropy (2 points)
Deliverable:编写一个函数,用于计算 交叉熵损失,该函数以模型预测的 logits( o i o_i oi)和目标标签( x i + 1 x_{i+1} xi+1)作为输入,并计算交叉熵损失 ℓ i = − log ( softmax ( o i ) [ x i + 1 ] ) \ell_i = - \log \big( \text{softmax}(o_i)[x_{i+1}] \big) ℓi=−log(softmax(oi)[xi+1]),你的函数需要满足以下要求:
- 为数值稳定性,在计算前减去 logits 中的最大值
- 在可能的情况下,消去不必要的 log \log log 和 exp \exp exp 运算
- 支持任意额外的 batch 维度,并在 batch 维度上返回 平均损失。与 §1.3 中的约定一致,我们假设所有 batch 类维度始终位于最前面,而词表大小维度位于最后
完成后,实现测试适配器 [adapters.run_cross_entropy],然后运行:
uv run pytest -k test_cross_entropy
以测试你的实现是否正确
1.2 The SGD Optimizer
现在我们已经定义了损失函数,接下来将开始探索 优化器(optimizers)。最简单的基于梯度的优化器是 随机梯度下降(Stochastic Gradient Descent, SGD),我们从随机初始化的参数 θ 0 \theta_0 θ0 开始,随后,在每一步 t = 0 , … , T − 1 t = 0, \ldots, T-1 t=0,…,T−1 中,我们执行如下更新:
θ t + 1 ← θ t − α t ∇ L ( θ t ; B t ) , (19) \theta_{t+1} \leftarrow \theta_t - \alpha_t \nabla L(\theta_t; B_t), \tag{19} θt+1←θt−αt∇L(θt;Bt),(19)
其中, B t B_t Bt 是从数据集 D D D 中随机采样得到的一个数据批次(batch),而学习率 α t \alpha_t αt 以及批大小 B t B_t Bt 都是 超参数
1.2.1 Implementing SGD in PyTorch
为了实现我们自己的优化器,我们将继承 PyTorch 的 torch.optim.Optimizer 基类,一个优化器子类 必须实现两个方法:
def __init__(self, params, ...)
该方法用于初始化优化器。其中,params 是需要被优化的一组参数,或者参数组,如果用户希望对模型不同部分使用不同的超参数,例如不同的学习率,就可以使用参数组
请务必将 params 传递给基类的 __init__ 方法,基类会保存这些参数,以便在 step 中使用。你可以根据具体优化器的需求接收额外的参数(例如学习率是一个常见参数),并将这些参数以字典的形式传递给基类构造函数,其中字典的键是你为这些参数选择的名称(字符串)
def step(self)
该方法用于 执行一次参数更新。在训练循环中,它会在反向传播(backward pass)之后被调用,因此此时你可以访问最近一个 batch 上计算得到的梯度
在这个方法中,你需要遍历每一个参数张量 p,并 原地(in place) 修改它们,即直接更新 p.data。参数的更新应当基于对应的梯度 p.grad(如果该梯度存在),其中 p.grad 是损失函数相对于该参数的梯度张量
PyTorch 的优化器 API 中存在一些细节问题,用一个具体示例来解释会更加清晰。为了让示例更有代表性,我们将实现一种 带学习率衰减的 SGD 变体:学习率会随着训练过程逐渐减小,从初始学习率 α \alpha α 开始,随着时间推移步长越来越小:
θ t + 1 = θ t − α t + 1 ∇ L ( θ t ; B t ) (20) \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{t+1}} \nabla L(\theta_t; B_t) \tag{20} θt+1=θt−t+1α∇L(θt;Bt)(20)
我们来看一下这一版本的 SGD 是如何作为一个 PyTorch 优化器来实现的:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math
class SGD(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3):
if lr < 0:
raise ValueError(f"Invalid learning rate: {lr}")
defaults = {"lr": lr}
super().__init__(params, defaults)
def step(self, closure: Optional[Callable] = None):
loss = None if closure is None else closure()
for group in self.param_groups:
lr = group["lr"] # Get the learning rate.
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p] # Get state associated with p.
t = state.get("t", 0) # Get iteration number from the state, or initial value.
grad = p.grad.data # Get the gradient of loss with respect to p.
p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
state["t"] = t + 1 # Increment iteration number.
return loss
在 __init__ 方法中,我们将需要优化的参数以及默认的超参数一起传递给优化器的基类构造函数,这些参数可能会被分成多个参数组,每个组可以拥有不同的超参数。如果传入的参数只是一个 torch.nn.Parameter 的集合,基类构造函数会自动创建一个参数组,并为其分配默认的超参数
在 step 方法中,我们首先遍历每一个参数组,然后遍历该组中的每一个参数,并按照公式 (20) 对参数进行更新。这里我们将 迭代次数 作为与每个参数相关联的状态进行存储:在更新梯度时先读取该值,完成参数更新后再将其递增
PyTorch 的优化器 API 允许用户传入一个可调用的 closure,用于在执行优化器更新之前重新计算损失函数。虽然在我们接下来要使用的优化器中并不需要这个功能,但为了符合 PyTorch 的接口规范,这里仍然将其包含在实现中
为了直观地看到这一过程是如何工作的,我们可以使用下面这个 最小化的训练循环(training loop)示例:
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD([weights], lr=1)
for t in range(100):
opt.zero_grad() # Reset the gradients for all learnable parameters.
loss = (weights**2).mean() # Compute a scalar loss value.
print(loss.cpu().item())
loss.backward() # Run backward pass, which computes gradients.
opt.step() # Run optimizer step.
这就是训练循环的典型结构:在每一次迭代中,我们都会先计算损失,然后执行一次优化器的更新步骤。在训练语言模型时,可学习的参数将来自模型本身(在 PyTorch 中,可以通过 nn.parameters() 获取这些参数)。损失函数通常是在从数据集中采样得到的一个 batch 上计算的,但 训练循环的基本结构始终是相同的
Problem (learning_rate_tuning): Tuning the learning rate (1 point)
正如我们将会看到的那样,在所有超参数中,对训练过程影响最大的就是 学习率。下面我们通过一个简单示例来直观地观察这一点,请将上面的 SGD 示例分别使用另外三个学习率取值来运行:1e1、1e2 和 1e3,并且只训练 10 次迭代
对于每一种学习率,损失函数的变化情况如何?它是下降得更快、更慢,还是会发生发散(也就是说,在训练过程中反而不断增大)?
Deliverable:用一到两句话描述你观察到的这些学习率对应的训练行为。
1.3 AdamW
现代语言模型通常会使用比 SGD 更加复杂的优化器进行训练,近年来使用的大多数优化器都是 Adam 优化器 [Kingma and Ba 2015] 的各种变体。在本课程中,我们将使用 AdamW [Loshchilov and Hutter 2019],这是一种在近期研究中被广泛采用的方法
AdamW 对 Adam 进行了改进:通过引入 权重衰减(weight decay) 来增强正则化效果。具体来说,在每一次迭代中,都会将参数向 0 拉近,并且这种权重衰减是 与梯度更新解耦(decoupled)的,我们将按照 [Loshchilov and Hutter 2019] 论文中算法 2 的描述来实现 AdamW
AdamW 是一种 有状态(stateful) 的优化器:对于每一个参数,它都会维护该参数一阶矩和二阶矩的运行估计,因此,AdamW 通过额外的内存开销,换取了更好的训练稳定性和收敛性
除了学习率 α \alpha α 之外,AdamW 还包含一组超参数 ( β 1 , β 2 ) (\beta_1, \beta_2) (β1,β2),用于控制一阶与二阶矩估计的更新速度,以及一个权重衰减系数 λ \lambda λ。在典型应用中, ( β 1 , β 2 ) (\beta_1, \beta_2) (β1,β2) 通常设为 ( 0.9 , 0.999 ) (0.9, 0.999) (0.9,0.999),但在大型语言模型中,例如 LLaMA [Touvron+ 2023] 和 GPT-3 [Brown+ 2020],常常使用 ( 0.9 , 0.95 ) (0.9, 0.95) (0.9,0.95) 这一组合
该算法可以形式化地写成如下形式,其中 ϵ \epsilon ϵ 是一个很小的常数(例如 10 − 8 10^{-8} 10−8),用于在 v v v 极小的情况下提升数值稳定性

Note:这里的 t t t 从 1 开始计数,现在你将实现这个优化器
Problem (adamw): Implement AdamW (2 points)
Deliverable:将 AdamW 优化器实现为 torch.optim.Optimizer 的一个子类,你的类应当在 __init__ 中接收学习率 α \alpha α 以及超参数 β \beta β、 ϵ \epsilon ϵ 和 λ \lambda λ。为了帮助你维护状态,基类 Optimizer 为你提供了一个字典 self.state,它将每个 nn.Parameter 对象映射到一个字典,用于存储与该参数相关的任意信息(对于 AdamW 来说,这些信息就是动量估计)
请实现 [adapters.get_adamw_cls],并确保你的实现能够通过下面的测试:
uv run pytest -k test_adamw
Problem (adamwAccounting): Resource accounting for training with AdamW (2 points)
我们来计算在使用 AdamW 进行训练时所需的 内存和计算量,假设所有张量都使用 float32 表示
(a)运行 AdamW 需要多少峰值内存?
请基于以下几部分的内存占用对你的答案进行拆分说明:
- 模型参数(parameters)
- 激活值(activations)
- 梯度(gradients)
- 优化器状态(optimizer state)
请用 batch_size 以及模型的超参数(vocab_size,context_length、num_layers、d_model、num_heads)来表达你的答案,并假设 d f f = 4 × d m o d e l d_{ff}=4\times d_{model} dff=4×dmodel
为简化起见,在计算 激活值内存 时,只考虑以下组件:
- Transformer block
- RMSNorm(s)
- 多头自注意力子层:
- QKV 投影
- Q ⊤ K Q^{\top}K Q⊤K 矩阵乘
- softmax
- 对 value 的加权求和
- 输出投影
- 逐位置前馈网络:
- W 1 W_1 W1 矩阵乘
- SiLU
- W 2 W_2 W2 矩阵乘
- 最终 RMSNorm
- 输出 embedding
- logits 上的交叉熵损失
Deliverable:给出参数、激活值、梯度以及优化器状态各自的代数表达式,并给出总内存表达式。
(b)针对 GPT-2 XL 规模的模型实例化你的结果
将你的结果带入 GPT-2 XL 形状的模型,使表达式只依赖于 batch_size,在 80GB 显存限制下,你最大可以使用多大的 batch_size?
Deliverable:一个形如 a ⋅ batch_size + b a\cdot \text{batch\_size}+b a⋅batch_size+b 的表达式(其中 a , b a,b a,b 为数值常数),以及一个表示最大 batch_size 的具体数值。
(c)AdamW 的单步更新需要多少 FLOPs?
Deliverable:一个代数表达式,并附上简要说明。
(d)训练时间估计(基于 MFU)
模型 FLOPs 利用率(MFU)定义为实际观测到的吞吐量(token/s)与硬件理论峰值 FLOPs 吞吐量的比值 [Chowdhery+ 2022]
已知:
- NVIDIA A100 GPU 的 float32 理论峰值为 19.5 TFLOP/s
- 假设你可以达到 50% MFU
- 在单张 A100 上训练 GPT-2 XL
- 训练 400K steps
- batch size = 1024
- 假设 反向传播的 FLOPs 是前向传播的两倍(参考 [Kaplan+ 2020] [Hoffmann+ 2022])
那么完成训练大约需要多少天?
Deliverable:训练所需的天数,并给出简要说明。
1.4 Learning rate scheduling
在训练过程中,能够最快降低损失的学习率数值往往会随着训练阶段而变化,在训练 Transformer 模型时,通常会使用一种 学习率调度策略(learning rate schedule):在训练初期使用较大的学习率,以更快地进行参数更新;随着模型训练的推进,在逐渐将学习率衰减到较小的数值。在本次作业中,我们将实现用于训练 LLaMA 的 余弦退火(cosine annealing)学习率调度 方法 [Touvron+ 2023]
Note:在实践中,有时也会使用一种 学习率重启(restart) 的调度方式,即在训练过程中让学习率再次升高,以帮助模型跳出局部最优解
一个调度器本质上就是一个函数,它接收当前的训练步数 t t t 以及其他相关参数(例如初始和最终学习率),并返回在第 t t t 步用于梯度更新的学习率,最简单的调度策略是 常数调度,即无论 t t t 取何值,都返回同一个学习率
余弦退火学习率调度 需要以下参数:
1. 当前迭代步数 t t t
2. 最大(初始)学习率 α max \alpha_{\max} αmax
3. 最小(最终)学习率 α min \alpha_{\min} αmin
4. warm-up 迭代步数 T w T_w Tw
5. 余弦退火阶段的迭代步数 T c T_c Tc
在第 t t t 次迭代时,学习率 α t \alpha_t αt 的定义如下:
(Warm-up 阶段) 如果 t < T w t < T_w t<Tw,则:
α t = t T w α max \alpha_t = \frac{t}{T_w} \alpha_{\max} αt=Twtαmax
(余弦退火阶段) 如果 T w ≤ t ≤ T c T_w \le t \le T_c Tw≤t≤Tc,则:
α t = α min + 1 2 ( 1 + cos ( t − T w T c − T w π ) ) ( α max − α min ) \alpha_t = \alpha_{\min} + \frac{1}{2}\left(1 + \cos\left(\frac{t - T_w}{T_c - T_w}\pi\right)\right)(\alpha_{\max} - \alpha_{\min}) αt=αmin+21(1+cos(Tc−Twt−Twπ))(αmax−αmin)
(退火后阶段) 如果 t > T c t > T_c t>Tc,则:
α t = α min \alpha_t = \alpha_{\min} αt=αmin
Problem (learning_rate_schedule): Implement cosine learning rate schedule with warmup (1 point)
Deliverable:编写一个函数,该函数接收参数 t t t、 α max \alpha_{\max} αmax、 α min \alpha_{\min} αmin、 T w T_w Tw 和 T c T_c Tc,并根据上文定义的调度策略返回在第 t t t 步使用的学习率 α t \alpha_t αt
随后,实现测试适配器 [adapters.get_lr_cosine_schedule],并确保它能够通过以下测试命令:
uv run pytest -k test_get_lr_cosine_schedule
1.5 Gradient clipping
在训练过程中,我们有时会遇到一些训练样本,它们会产生非常大的梯度,从而导致训练过程的不稳定,为缓解这一问题,实践中常用的一种技术是 梯度裁剪(gradient clipping)。其核心思想是在每一次反向传播之后、执行优化器更新之前,对梯度的范数施加一个上限约束
具体来说,设所有参数对应的梯度为 g g g,我们首先计算其 ℓ 2 \ell_2 ℓ2 范数 ∥ g ∥ 2 \lVert g \rVert_2 ∥g∥2,如果该范数小于某个最大阈值 M M M,则保持梯度 g g g 不变,否则,将梯度按比例缩放,缩放因子为:
M ∥ g ∥ 2 + ϵ \frac{M}{\lVert g \rVert_2 + \epsilon} ∥g∥2+ϵM
其中 ϵ \epsilon ϵ 是一个很小的常数(例如 10 − 6 10^{-6} 10−6),用于提高数值稳定性。需要注意的是,经过缩放后的梯度范数将略小于 M M M
Problem (gradient_clipping): Implement gradient clipping (1 point)
Deliverable:编写一个函数来实现梯度裁剪,你的函数应当接收一个参数列表以及一个最大的 ℓ 2 \ell_2 ℓ2 范数,并 就地(in place) 修改每一个参数对应的梯度,请使用 ϵ = 10 − 6 \epsilon=10^{-6} ϵ=10−6(即 PyTorch 的默认取值)
随后,实现测试适配器 [adapters.run_gradient_clipping],并确保你的实现能够通过以下测试:
uv run pytest -k test_gradient_clipping
2. Training loop 作业要求
以下内容均翻译自 cs336_spring2025_assignment1_basics.pdf,请大家查看原文档获取更详细的内容
现在,我们终于可以把目前为止构建好的几个主要组件组合在一起了:分词后的数据、模型以及优化器
2.1 Data Loader
分词后的数据(例如你在 tokenizer_experiments 中准备的数据)可以看作是一条单一的 token 序列即 x = ( x 1 , … , x n ) x = (x_1, \ldots, x_n) x=(x1,…,xn),尽管原始数据可能由多个独立文档组成(例如不同的网页或源代码文件),一种常见做法是将它们全部拼接成一条连续的 token 序列,并在文档之间加入分隔符(例如 <|endoftext|> token)
数据加载器(data loader) 会将这条长序列转换成一串批次(batches),每个批次由 B B B 条长度为 m m m 的序列组成,并与对应的 “下一个 token” 序列配对,下一个 token 序列的长度同样为 m m m。例如,当 B = 1 , m = 3 B=1,m=3 B=1,m=3 时, ( [ x 2 , x 3 , x 4 ] , [ x 3 , x 4 , x 5 ] ) ([x_2, x_3, x_4], [x_3, x_4, x_5]) ([x2,x3,x4],[x3,x4,x5]) 就是一个可能的批次
采用这种方式加载数据在训练时有多方面的优势:首先,任何满足 i ≤ i < n − m i\le i<n-m i≤i<n−m 的位置 i i i 都可以构成一个合法的训练序列,因此采样训练样本非常简单;其次,由于所有训练序列的长度都是固定的,不需要对输入序列进行填充(padding),这有助于提高硬件利用率(同时也允许使用更大的 batch size);最后,我们也不需要一次性将整个数据集完全加载到内存中即可进行训练采样,这使得处理那些原本可能无法完全放入内存的大规模数据集变得更加容易
如果数据集大到无法一次性加载进内存,该怎么办呢?我们可以使用一种名为 mmap 的 Unix 系统调用,它可以将磁盘上的文件映射到虚拟内存中,并且只在访问到对应内存位置时才 按需加载 文件内容,这样一来,你就可以 “假装” 整个数据集都已经在内存中了
NumPy 通过 np.memmap(或者在你最初使用 np.save 保持数组时,在 np.load 中设置参数 mmap_mode='r')来实现这一机制,它会返回一个 类似 numpy 数组的对象,在你访问其中元素时才会动态加载对应的数据
在训练过程中从数据集(即一个 numpy 数组)中进行采样时,一定要确保你是以 内存映射模式 加载数据集的(根据保存数组的方式,使用 np.memmap 或在 np.load 中指定 mmap_mode='r'),同时,还要确保你显式指定了与所加载数组相匹配的 dtype
此外,显式检查内存映射后的数据是否正确也是一个好习惯,例如,确认其中不包含超出预期词表大小范围的数值
Problem (data_loading): Implement data loading (2 points)
Deliverable:编写一个函数,该函数接收以下输入:
- 一个 numpy 数组 x x x(包含 token ID 的整数数组)
batch_sizecontext_length- 一个 PyTorch 设备字符串(例如
'cpu'或'cuda:0')
并返回一对张量:采样得到的输入序列以及对应的下一个 token 目标序列,这两个张量的形状都应为 ( batch_size , context_length ) (\text{batch\_size},\text{context\_length}) (batch_size,context_length),并且内容均为 token ID,同时二者都应被放置在所指定的设备上
为了使用我们提供的测试来验证你的实现,你需要首先实现测试适配器 [adapters.run_get_batch],并确保你的实现能够通过以下测试
uv run pytest -k test_get_batch
Low-Resource/Downscaling Tip: Data loading on CPU or Apple Silicon
如果你计划在 CPU 或 Apple Silicon 上训练你的语言模型,你需要将数据移动到正确的设备上(同样地,后续模型本身也应使用相同的设备),如果你使用的是 CPU,可以使用设备字符串 'cpu',如果你使用的是 Apple Silicon(M 系列芯片),可以使用设备字符串 'mps'
关于 MPS 的更多信息,可以参考以下资源:
2.2 Checkpointing
除了加载数据之外,在训练过程中我们还需要 保存模型,在实际运行训练任务时,我们往往希望能够在训练由于某些原因中途停止(例如作业超时、机器故障等)后,继续恢复训练。即便一切顺利,我们之后也可能希望访问训练过程中的 中间模型(例如用于事后分析训练动态,或在训练的不同阶段中从模型中采样等)
一个检查点(checkpoint)应当包含 恢复训练所需的全部状态信息,至少,我们需要能够恢复模型的权重,如果使用的是 有状态的优化器(例如 AdamW),还需要保存优化器的状态(以 AdamW 为例,就是一阶和二阶矩的估计值),此外,为了能够继续学习率调度(learning rate schedule),我们还需要知道训练停止时所处的 迭代步数
PyTorch 让保存这些内容变得非常方便:每一个 nn.Module 都提供了 state_dict() 方法,用于返回一个包含所有可学习权重的字典,之后可以通过配套的 load_state_dict() 方法将这些权重恢复,同样的机制也适用于任何 nn.optim.Optimizer
最后,torch.save(obj, dest) 可以将一个对象(例如一个字典,其中既可以包含张量,也可以包含整数等普通 Python 对象)保存到文件或类文件对象中,而这些内容之后可以通过 torch.load(src) 再次加载回内存
Problem (checkpointing): Implement model checkpointing (1 point)
请实现下面两个函数,用于 保存和加载检查点(checkpoint):
def save_checkpoint(model, optimizer, iteration, out)
该函数应当将前三个参数中的 所有状态 写入到类文件对象 out 中,你可以使用模型和优化器各自的 state_dict 方法来获取它们对应的状态,并使用 torch.save(obj, out) 将对象保存到 out 中(PyTorch 在这里既支持路径,也支持类文件对象)。通常的做法是让 obj 成为一个字典,但只要你之后能够正确加载该检查点,也可以使用任何你喜欢的格式
该函数期望接收以下参数:
model:torch.nn.Moduleoptimizer:torch.optim.Optimizeriteration:intout:str | os.PathLike | typing.BinaryIO | typing.IO[bytes]
def load_checkpoint(src, model, optimizer)
该函数应当从 src(路径或类文件对象)中加载一个检查点,并从中恢复模型和优化器的状态,函数应返回 保存在检查点中的迭代步数。你可以使用 torch.load(src) 来读取在 save_checkpoint 中保存的内容,然后分别调用模型和优化器的 load_state_dict 方法,将它们恢复到之前的状态
该函数期望接收以下参数:
src:str | os.PathLike | typing.BinaryIO | typing.IO[bytes]model:torch.nn.Moduleoptimizer:torch.optim.Optimizer
请实现适配器函数 [adapters.run_save_checkpoint] 和 [adapters.run_load_checkpoint],并确保它们能够通过以下测试:
uv run pytest -k test_checkpointing
2.3 Training loop
现在,终于到了将你已经实现的所有组件整合到 主训练脚本 中的时候了,一个值得投入精力的做法是:让训练过程能够 方便地使用不同的超参数启动(例如通过命令行参数的方式),因为在后续的实现中,你将会多次运行训练,以研究不同选择对训练过程产生的影响
Problem (training_together): Put it together (4 points)
Deliverable:编写一个脚本,用于运行训练循环,在 用户提供的输入数据 上训练你的模型,具体来说,我们建议你的训练脚本至少支持以下功能:
- 能够配置和控制 模型以及优化器中的各类超参数
- 使用 np.memmap 以 内存高效 的方式加载训练集和验证集等大型数据集
- 将 checkpoint(检查点)序列化并保存 到用户指定的路径
- 周期性地记录 训练和验证阶段的性能指标(例如输出到控制台或记录到诸如 Weights & Biases [wandb.ai] 这样的外部服务中)
3. Generating text 作业要求
以下内容均翻译自 cs336_spring2025_assignment1_basics.pdf,请大家查看原文档获取更详细的内容
现在我们已经能够训练模型了,最后还需要实现的一项功能是:让模型生成文本。回顾一下,语言模型接收一个(可能是批量化的)长度为 sequence_length 的整数序列作为输入,并输出一个大小为 ( sequence_length × vocab_size ) (\text{sequence\_length}\times \text{vocab\_size}) (sequence_length×vocab_size) 的矩阵,其中序列中每个位置对应一个概率分布,用来预测该位置之后的下一个词。接下来,我们将编写一些函数,把这一输出形式转化为用于生成新序列的采样方案
Softmax
按照标准约定,语言模型的输出是最后一层线性层的结果(即 logits),因此,我们需要通过 softmax 操作将其转换为归一化的概率分布,这一点我们此前已经在公式 (10) 中见过
softmax ( v ) i = exp ( v i ) ∑ j = 1 n exp ( v j ) . (10) \text{softmax}(v)_i = \frac{\exp(v_i)}{\sum_{j=1}^{n} \exp(v_j)}. \tag{10} softmax(v)i=∑j=1nexp(vj)exp(vi).(10)
Decoding
为了从模型中生成文本(即进行解码),我们会向模型提供一个前缀 token 序列(也就是 “提示词 / prompt”),并让模型输出一个在整个词表上的概率分布,用于预测序列中的下一个词。随后,我们将从该词表概率分布中进行采样,以确定下一个输出的 token
更具体地说,解码过程中会接收一个序列 x 1 … t x_{1\ldots t} x1…t,并通过下式返回下一个 token x t + 1 x_{t+1} xt+1:
P ( x t + 1 = i ∣ x 1 … t ) = exp ( v i ) ∑ j exp ( v j ) P(x_{t+1} = i \mid x_{1\ldots t}) = \frac{\exp(v_i)}{\sum_j \exp(v_j)} P(xt+1=i∣x1…t)=∑jexp(vj)exp(vi)
其中
v = TransformerLM ( x 1 … t ) t ∈ R vocab_size v = \text{TransformerLM}(x_{1\ldots t})_t \in \mathbb{R}^{\text{vocab\_size}} v=TransformerLM(x1…t)t∈Rvocab_size
这里,TransformerLM 是我们的模型:它接收一个长度为 sequence_length 的输入序列,并输出一个大小为 ( sequence_length × vocab_size ) (\text{sequence\_length}\times \text{vocab\_size}) (sequence_length×vocab_size) 的矩阵,我们取该矩阵的最后一行,因为我们关注的是在第 t t t 个位置对下一个词的预测
通过不断地从这些条件概率中进行采样,并将刚刚生成的 token 追加到下一步解码的输入中,我们就得到了一个最基础的解码器,这个过程会一直持续,直到生成序列结束 token <|endoftext|>,或者达到用户指定的最大生成 token 数
Decoder tricks
我们将主要在 小模型 上进行实验,而小模型有时会生成质量很低的文本,下面介绍两种简单但有效的解码技巧,可以缓解这些问题
第一种方法是温度缩放(temperature scaling),我们在 softmax 中引入一个温度参数 τ \tau τ,新的 softmax 定义为:
softmax ( v , τ ) i = exp ( v i / τ ) ∑ ∗ j = 1 ∣ vocab_size ∣ exp ( v j / τ ) . (24) \text{softmax}(v, \tau)_i = \frac{\exp(v_i / \tau)}{\sum*{j=1}^{|\text{vocab\_size}|} \exp(v_j / \tau)}. \tag{24} softmax(v,τ)i=∑∗j=1∣vocab_size∣exp(vj/τ)exp(vi/τ).(24)
注意,当 τ → 0 \tau \to 0 τ→0 时,向量 v v v 中最大的元素会占据主导地位,softmax 的输出将趋近于一个在该最大元素处集中的 one-hot 向量
第二种方法是 nucleus sampling(也称 top-p 采样),该方法通过截断低概率词来修改采样分布,设 q q q 为通过(经过温度缩放的)softmax 得到的、大小为 vocab_size 的概率分布,使用超参数 p p p 的 nucleus sampling 按如下方式生成下一个 token:
P ( x t + 1 = i ∣ q ) = { q i ∑ j ∈ V ( p ) q j if i ∈ V ( p ) 0 otherwise P(x_{t+1} = i \mid q) = \begin{cases} \dfrac{q_i}{\sum_{j \in V(p)} q_j} & \text{if } i \in V(p) \\ 0 & \text{otherwise} \end{cases} P(xt+1=i∣q)=⎩ ⎨ ⎧∑j∈V(p)qjqi0if i∈V(p)otherwise
其中, V ( p ) V(p) V(p) 是满足 ∑ j ∈ V ( p ) q j ≥ p \sum_{j \in V(p)} q_j \ge p ∑j∈V(p)qj≥p 的 最小索引集合,这个集合可以很容易地计算:先按概率大小对分布 q q q 进行排序,然后依次选取概率最大的词表元素,直到累计概率达到目标阈值 p p p
Problem (decoding): Decoding (3 points)
Deliverable:实现一个函数,用于从你的语言模型中进行解码,我们建议你的实现至少支持以下功能:
- 针对用户给定的提示(prompt)生成补全文本:输入一段序列 x 1 … t x_{1\ldots t} x1…t,并不断采样生成后续内容,直到遇到
<|endoftext|>token 为止 - 允许用户控制生成 token 的最大数量:在未生成结束符之前,限制最多可生成的 token 数
- 支持温度参数(temperature)控制:在采样之前,对预测的下一个词分布应用带温度的 softmax(temperature scaling)
- 支持 top-p 采样 [Holtzman+ 2020]:又称为 nucleus sampling,并允许用户指定阈值参数 p p p
结语
这篇文章我们系统性地梳理了 CS336 Assignment 1 中 Training Loop 相关的全部作业要求,从最基础的交叉熵损失函数与优化器接口出发,逐步覆盖了学习率调度、梯度裁剪、资源核算、数据加载、检查点保存以及最终的文本生成流程
更详细的内容大家可以查看官方提供的相关文档
下篇文章我们就来一起看看 Training Loop 具体该如何实现,敬请期待🤗
参考
更多推荐


所有评论(0)