【实战】基于DiT的MNIST扩散模型:从原理到代码全解析
本文完整实现了基于DiT的MNIST扩散模型,从原理到代码拆解了模型的核心设计与实现细节。该模型兼顾了Transformer的全局注意力优势和扩散模型的生成能力,是入门生成式AI的优质案例。
【实战】基于DiT的MNIST扩散模型:从原理到代码全解析
摘要
扩散模型作为当前生成式AI的核心技术,在图像生成领域展现出卓越性能,但传统U-Net架构的扩散模型存在结构复杂、部署成本高的问题。本文以经典的MNIST手写数字数据集为载体,详细讲解基于Transformer的DiT(Diffusion Model with Transformer)扩散模型的实现原理与代码细节,从模型架构设计、扩散过程建模到完整训练与采样流程,全方位拆解轻量化DiT模型的构建方法。实验表明,该模型仅需约2260万参数即可实现高质量的MNIST数字生成,兼顾性能与易用性,是入门扩散模型与Transformer结合的绝佳案例。
关键词:扩散模型;DiT;Transformer;MNIST;图像生成;PyTorch
一、研究背景与动机
扩散模型(Diffusion Model)通过模拟“加噪-去噪”的马尔可夫过程,能够生成高质量的图像数据,已成为GAN、VAE之外的主流生成模型。传统扩散模型多基于U-Net架构,依赖卷积操作提取空间特征,而近年来基于Transformer的DiT模型凭借全局注意力机制,在图像生成任务中展现出更优的特征建模能力。
MNIST数据集作为计算机视觉的“Hello World”,具有数据量小、任务简单的特点,非常适合作为扩散模型入门的实验载体。本文基于DiT架构实现面向MNIST的轻量化扩散模型,核心目标包括:
- 理解DiT模型的核心设计(adaLN-Zero条件化、Patch嵌入、时间/类别嵌入);
- 掌握DDPM扩散框架的“加噪-去噪”核心逻辑;
- 实现从模型搭建、训练到采样生成的全流程代码落地。
二、核心原理剖析
2.1 DDPM扩散框架基础
DDPM(Denoising Diffusion Probabilistic Models)是扩散模型的经典框架,核心分为两个阶段:
- 前向扩散(加噪):从真实图像x0x_0x0出发,逐步加入高斯噪声,得到含噪图像xtx_txt,满足公式:
xt=αˉtx0+1−αˉtϵx_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilonxt=αˉtx0+1−αˉtϵ
其中αˉt\bar{\alpha}_tαˉt为累乘的噪声系数,ϵ\epsilonϵ为高斯噪声。 - 反向扩散(去噪):训练模型学习从含噪图像xtx_txt中预测噪声ϵ\epsilonϵ,并逐步去噪还原出真实图像。
2.2 DiT模型核心设计
DiT(Diffusion Transformer)将Transformer架构与扩散模型结合,核心改进包括:
- Patch嵌入:将图像切分为固定大小的Patch,映射到高维特征空间,替代卷积操作;
- 条件化嵌入:融合时间步嵌入(Timestep Embedder)和类别嵌入(Label Embedder),实现条件化生成;
- adaLN-Zero调制:通过自适应层归一化,将时间/类别条件融入Transformer Block,实现条件化特征调制;
- Classifier-Free引导(CFG):通过随机丢弃类别标签,提升生成模型的可控性。
三、代码全解析
3.1 环境依赖
# 核心依赖
pip install torch torchvision tqdm matplotlib timm numpy
3.2 核心模块拆解
3.2.1 位置编码模块
位置编码是Transformer的核心组件,本文采用2D正弦余弦位置编码,为每个Patch赋予空间位置信息:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # 生成2D网格
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
核心作用:为7×7的MNIST Patch(28×28图像切分为4×4 Patch)赋予唯一的位置特征,弥补Transformer的位置无关性。
3.2.2 DiTBlock核心模块
DiTBlock是模型的核心计算单元,集成了多头注意力、MLP和adaLN-Zero调制:
class DiTBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=False)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
# adaLN-Zero调制层:融合时间/类别条件
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
# 将条件向量c分解为6个调制参数
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
# 注意力层:条件化调制
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
# MLP层:条件化调制
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
核心亮点:通过adaLN_modulation将时间/类别条件(c)转换为调制参数,对注意力层和MLP层进行自适应调整,实现条件化生成。
3.2.3 DDPM扩散框架封装
将DiT模型与DDPM扩散流程封装,实现训练和采样的统一管理:
class DDPM(nn.Module):
def __init__(self, nn_model, betas=(1e-4, 0.02), n_T=400, device="cpu", drop_prob=0.1):
super().__init__()
self.nn_model = nn_model.to(device)
# 预计算扩散系数并注册为缓冲区(不参与训练)
for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.device = device
self.drop_prob = drop_prob
self.loss_mse = nn.MSELoss()
def forward(self, x, c):
"""训练阶段:随机加噪+噪声预测"""
_ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device) # 随机采样时间步
noise = torch.randn_like(x) # 生成高斯噪声
# 前向扩散:生成含噪图像x_t
x_t = self.sqrtab[_ts, None, None, None] * x + self.sqrtmab[_ts, None, None, None] * noise
# CFG引导:随机丢弃类别标签
context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)
# DiT模型预测噪声
eps_pred = self.nn_model(x_t, _ts, c, context_mask)
# MSE损失:预测噪声 vs 真实噪声
return self.loss_mse(noise, eps_pred)
def sample(self, n_sample, size=(1,28,28), guide_w=0.0):
"""采样阶段:从噪声逐步去噪生成图像"""
x_i = torch.randn(n_sample, *size).to(self.device) # 初始纯噪声
c_i = torch.arange(0,10).to(self.device).repeat(int(n_sample/10)) # 生成10类标签
# CFG引导:双倍batch实现有/无标签引导
context_mask = torch.zeros_like(c_i).to(self.device)
c_i = c_i.repeat(2)
context_mask = context_mask.repeat(2)
context_mask[n_sample:] = 1.
x_i_store = [] # 保存采样过程(用于生成GIF)
# 反向扩散:从T到1逐步去噪
for i in range(self.n_T, 0, -1):
t_is = torch.tensor([i]*n_sample).to(self.device).repeat(2)
z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0. # 最后一步不加噪声
x_i = x_i.repeat(2,1,1,1)
eps = self.nn_model(x_i, t_is, c_i, context_mask)
# CFG融合:eps = 无引导 + w*(有引导 - 无引导)
eps_cond, eps_uncond = eps[:n_sample], eps[n_sample:]
eps = eps_uncond + guide_w * (eps_cond - eps_uncond)
# 反向扩散更新公式
x_i = self.oneover_sqrta[i] * (x_i[:n_sample] - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
if i%20==0 or i==self.n_T or i<8:
x_i_store.append(x_i.detach().cpu().numpy())
return x_i, np.array(x_i_store)
核心逻辑:
- 训练阶段:随机采样时间步,生成含噪图像,训练DiT模型预测噪声;
- 采样阶段:从纯噪声出发,利用训练好的模型逐步预测并去除噪声,最终生成图像;
- CFG引导:通过双倍batch实现“有标签”和“无标签”的预测结果融合,提升生成可控性。
3.3 训练与采样全流程
def train_mnist_dit(total_epoch=10, batch_size=64, base_save_dir='./data/diffusion_dit_mnist/'):
# 1. 超参数配置
n_T = 400
device = "cuda:0" if torch.cuda.is_available() else "cpu"
n_classes = 10
lrate = 1e-4
ws_test = [0.0, 0.5, 2.0] # CFG引导权重
# 2. 初始化DiT模型
dit_model = DiT(
input_size=28, patch_size=4, in_channels=1,
hidden_size=384, depth=12, num_heads=6,
class_dropout_prob=0.1, num_classes=10, learn_sigma=False
)
# 3. 初始化DDPM框架
ddpm = DDPM(nn_model=dit_model, n_T=n_T, device=device, drop_prob=0.1)
print(f"模型总参数量:{sum(p.numel() for p in ddpm.parameters()):,}") # 约2260万
# 4. 加载MNIST数据集
tf = transforms.Compose([transforms.ToTensor()])
dataset = MNIST("./data", train=True, download=True, transform=tf)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# 5. 优化器配置
optim = torch.optim.Adam(ddpm.parameters(), lr=lrate, betas=(0.9, 0.999))
# 6. 训练循环
os.makedirs(base_save_dir, exist_ok=True)
for ep in range(total_epoch):
ddpm.train()
optim.param_groups[0]['lr'] = lrate * (1 - ep / total_epoch) # 学习率衰减
pbar = tqdm(dataloader, desc=f"Epoch {ep} Loss: ---")
loss_ema = None
for x, c in pbar:
optim.zero_grad()
x, c = x.to(device), c.to(device)
loss = ddpm(x, c)
loss.backward()
optim.step()
loss_ema = loss.item() if loss_ema is None else 0.95*loss_ema + 0.05*loss.item()
pbar.set_description(f"Epoch {ep} Loss: {loss_ema:.4f}")
# 7. 每轮训练后采样验证
ddpm.eval()
with torch.no_grad():
n_sample = 40
for w in ws_test:
x_gen, x_gen_store = ddpm.sample(n_sample, size=(1,28,28), guide_w=w)
# 保存生成图像
x_all = torch.cat([x_gen, x_real]) # 生成图像+真实图像对比
grid = make_grid(x_all*-1 + 1, nrow=10)
save_image(grid, os.path.join(base_save_dir, f"image_ep{ep}_w{w}.png"))
# 保存采样过程GIF
if ep % 2 == 0 or ep == total_epoch-1:
fig, axs = plt.subplots(nrows=4, ncols=10, figsize=(8,3))
ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store], interval=200, frames=x_gen_store.shape[0])
ani.save(os.path.join(base_save_dir, f"gif_ep{ep}_w{w}.gif"), dpi=100, writer=PillowWriter(fps=5))
# 保存模型权重
torch.save(ddpm.state_dict(), os.path.join(base_save_dir, f"model_{ep}.pth"))
四、实验结果与分析
4.1 参数量分析
模型总参数量约2260万,核心分布如下:
| 模块 | 参数量 | 占比 |
|---|---|---|
| 12层DiTBlock | 21.9M | 97.2% |
| TimestepEmbedder | 0.237M | 1.05% |
| SimpleHead/FinalLayer | 0.775M | 3.43% |
| PatchEmbed/LabelEmbedder | 0.01M | 0.04% |
4.2 生成效果
- 训练轮次:建议至少训练10轮,8轮后模型开始收敛,生成数字清晰可辨;
- CFG引导权重:
- w=0.0(无引导):生成数字随机性强,部分类别混淆;
- w=2.0(强引导):生成数字类别精准,边缘清晰;
- 采样效率:400步采样在单GPU上约10秒/次,CPU上约1分钟/次。
4.3 关键调优技巧
- 学习率衰减:采用线性衰减策略,训练后期降低学习率,提升收敛稳定性;
- 零初始化调制层:adaLN调制层权重初始化为0,保证扩散模型初始训练稳定;
- Patch尺寸选择:28×28图像选择4×4 Patch,切分为7×7 Patch,无冗余,计算效率最高。
五、总结与拓展
本文完整实现了基于DiT的MNIST扩散模型,从原理到代码拆解了模型的核心设计与实现细节。该模型兼顾了Transformer的全局注意力优势和扩散模型的生成能力,是入门生成式AI的优质案例。
5.1 拓展方向
- 模型轻量化:降低hidden_size(如64)和depth(如2),参数量可降至100万级,适合CPU部署;
- 加速采样:采用DDIM采样策略,将采样步数从400步降至50步,提升采样效率;
- 迁移到其他数据集:修改Patch尺寸、输入通道数和类别数,可适配CIFAR-10、FashionMNIST等数据集;
- 多条件生成:融合更多条件(如手写风格、笔画粗细),实现更精细的可控生成。
5.2 核心收获
- 理解DiT模型中“Patch嵌入+条件化调制”的核心设计;
- 掌握DDPM框架“加噪-去噪”的数学原理与代码实现;
- 熟悉Classifier-Free引导在扩散模型中的应用;
- 具备从0到1搭建Transformer类扩散模型的能力。
六、完整代码获取
本文完整代码可以私信获取,包含模型训练、采样、结果可视化全流程,可直接运行。
附:常见问题解答
Q1:为什么MNIST简单任务需要2260万参数?
A1:生成模型需拟合像素级分布,且Transformer架构本身参数密集(12层Block占97%参数量),若追求轻量化可降低hidden_size和depth。
Q2:训练过程中损失不下降怎么办?
A2:检查学习率(建议1e-4)、Batch Size(建议64)、初始化策略(adaLN调制层零初始化),并确保数据归一化到[0,1]。
Q3:采样结果全是噪声?
A3:模型未收敛(需增加训练轮次),或CFG引导权重设置过高,建议先验证w=0.0的采样结果。
作者简介:查无此人,专注于生成式AI、Transformer、扩散模型研究,擅长将复杂算法落地为可运行的实战代码。
声明:本文仅用于学术交流,代码可自由修改和复用,引用请注明出处。
更多推荐



所有评论(0)