【实战】基于DiT的MNIST扩散模型:从原理到代码全解析

摘要

扩散模型作为当前生成式AI的核心技术,在图像生成领域展现出卓越性能,但传统U-Net架构的扩散模型存在结构复杂、部署成本高的问题。本文以经典的MNIST手写数字数据集为载体,详细讲解基于Transformer的DiT(Diffusion Model with Transformer)扩散模型的实现原理与代码细节,从模型架构设计、扩散过程建模到完整训练与采样流程,全方位拆解轻量化DiT模型的构建方法。实验表明,该模型仅需约2260万参数即可实现高质量的MNIST数字生成,兼顾性能与易用性,是入门扩散模型与Transformer结合的绝佳案例。

关键词:扩散模型;DiT;Transformer;MNIST;图像生成;PyTorch

一、研究背景与动机

扩散模型(Diffusion Model)通过模拟“加噪-去噪”的马尔可夫过程,能够生成高质量的图像数据,已成为GAN、VAE之外的主流生成模型。传统扩散模型多基于U-Net架构,依赖卷积操作提取空间特征,而近年来基于Transformer的DiT模型凭借全局注意力机制,在图像生成任务中展现出更优的特征建模能力。

MNIST数据集作为计算机视觉的“Hello World”,具有数据量小、任务简单的特点,非常适合作为扩散模型入门的实验载体。本文基于DiT架构实现面向MNIST的轻量化扩散模型,核心目标包括:

  1. 理解DiT模型的核心设计(adaLN-Zero条件化、Patch嵌入、时间/类别嵌入);
  2. 掌握DDPM扩散框架的“加噪-去噪”核心逻辑;
  3. 实现从模型搭建、训练到采样生成的全流程代码落地。

二、核心原理剖析

2.1 DDPM扩散框架基础

DDPM(Denoising Diffusion Probabilistic Models)是扩散模型的经典框架,核心分为两个阶段:

  • 前向扩散(加噪):从真实图像x0x_0x0出发,逐步加入高斯噪声,得到含噪图像xtx_txt,满足公式:
    xt=αˉtx0+1−αˉtϵx_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilonxt=αˉt x0+1αˉt ϵ
    其中αˉt\bar{\alpha}_tαˉt为累乘的噪声系数,ϵ\epsilonϵ为高斯噪声。
  • 反向扩散(去噪):训练模型学习从含噪图像xtx_txt中预测噪声ϵ\epsilonϵ,并逐步去噪还原出真实图像。

2.2 DiT模型核心设计

DiT(Diffusion Transformer)将Transformer架构与扩散模型结合,核心改进包括:

  1. Patch嵌入:将图像切分为固定大小的Patch,映射到高维特征空间,替代卷积操作;
  2. 条件化嵌入:融合时间步嵌入(Timestep Embedder)和类别嵌入(Label Embedder),实现条件化生成;
  3. adaLN-Zero调制:通过自适应层归一化,将时间/类别条件融入Transformer Block,实现条件化特征调制;
  4. Classifier-Free引导(CFG):通过随机丢弃类别标签,提升生成模型的可控性。

三、代码全解析

3.1 环境依赖

# 核心依赖
pip install torch torchvision tqdm matplotlib timm numpy

3.2 核心模块拆解

3.2.1 位置编码模块

位置编码是Transformer的核心组件,本文采用2D正弦余弦位置编码,为每个Patch赋予空间位置信息:

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # 生成2D网格
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed

核心作用:为7×7的MNIST Patch(28×28图像切分为4×4 Patch)赋予唯一的位置特征,弥补Transformer的位置无关性。

3.2.2 DiTBlock核心模块

DiTBlock是模型的核心计算单元,集成了多头注意力、MLP和adaLN-Zero调制:

class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # adaLN-Zero调制层:融合时间/类别条件
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # 将条件向量c分解为6个调制参数
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # 注意力层:条件化调制
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        # MLP层:条件化调制
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

核心亮点:通过adaLN_modulation将时间/类别条件(c)转换为调制参数,对注意力层和MLP层进行自适应调整,实现条件化生成。

3.2.3 DDPM扩散框架封装

将DiT模型与DDPM扩散流程封装,实现训练和采样的统一管理:

class DDPM(nn.Module):
    def __init__(self, nn_model, betas=(1e-4, 0.02), n_T=400, device="cpu", drop_prob=0.1):
        super().__init__()
        self.nn_model = nn_model.to(device)
        # 预计算扩散系数并注册为缓冲区(不参与训练)
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """训练阶段:随机加噪+噪声预测"""
        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)  # 随机采样时间步
        noise = torch.randn_like(x)  # 生成高斯噪声
        # 前向扩散:生成含噪图像x_t
        x_t = self.sqrtab[_ts, None, None, None] * x + self.sqrtmab[_ts, None, None, None] * noise
        # CFG引导:随机丢弃类别标签
        context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)
        # DiT模型预测噪声
        eps_pred = self.nn_model(x_t, _ts, c, context_mask)
        # MSE损失:预测噪声 vs 真实噪声
        return self.loss_mse(noise, eps_pred)

    def sample(self, n_sample, size=(1,28,28), guide_w=0.0):
        """采样阶段:从噪声逐步去噪生成图像"""
        x_i = torch.randn(n_sample, *size).to(self.device)  # 初始纯噪声
        c_i = torch.arange(0,10).to(self.device).repeat(int(n_sample/10))  # 生成10类标签
        # CFG引导:双倍batch实现有/无标签引导
        context_mask = torch.zeros_like(c_i).to(self.device)
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        x_i_store = []  # 保存采样过程(用于生成GIF)
        # 反向扩散:从T到1逐步去噪
        for i in range(self.n_T, 0, -1):
            t_is = torch.tensor([i]*n_sample).to(self.device).repeat(2)
            z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0.  # 最后一步不加噪声
            x_i = x_i.repeat(2,1,1,1)
            eps = self.nn_model(x_i, t_is, c_i, context_mask)
            # CFG融合:eps = 无引导 + w*(有引导 - 无引导)
            eps_cond, eps_uncond = eps[:n_sample], eps[n_sample:]
            eps = eps_uncond + guide_w * (eps_cond - eps_uncond)
            # 反向扩散更新公式
            x_i = self.oneover_sqrta[i] * (x_i[:n_sample] - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())
        return x_i, np.array(x_i_store)

核心逻辑

  • 训练阶段:随机采样时间步,生成含噪图像,训练DiT模型预测噪声;
  • 采样阶段:从纯噪声出发,利用训练好的模型逐步预测并去除噪声,最终生成图像;
  • CFG引导:通过双倍batch实现“有标签”和“无标签”的预测结果融合,提升生成可控性。

3.3 训练与采样全流程

def train_mnist_dit(total_epoch=10, batch_size=64, base_save_dir='./data/diffusion_dit_mnist/'):
    # 1. 超参数配置
    n_T = 400
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_classes = 10
    lrate = 1e-4
    ws_test = [0.0, 0.5, 2.0]  # CFG引导权重

    # 2. 初始化DiT模型
    dit_model = DiT(
        input_size=28, patch_size=4, in_channels=1,
        hidden_size=384, depth=12, num_heads=6,
        class_dropout_prob=0.1, num_classes=10, learn_sigma=False
    )

    # 3. 初始化DDPM框架
    ddpm = DDPM(nn_model=dit_model, n_T=n_T, device=device, drop_prob=0.1)
    print(f"模型总参数量:{sum(p.numel() for p in ddpm.parameters()):,}")  # 约2260万

    # 4. 加载MNIST数据集
    tf = transforms.Compose([transforms.ToTensor()])
    dataset = MNIST("./data", train=True, download=True, transform=tf)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # 5. 优化器配置
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate, betas=(0.9, 0.999))

    # 6. 训练循环
    os.makedirs(base_save_dir, exist_ok=True)
    for ep in range(total_epoch):
        ddpm.train()
        optim.param_groups[0]['lr'] = lrate * (1 - ep / total_epoch)  # 学习率衰减
        pbar = tqdm(dataloader, desc=f"Epoch {ep} Loss: ---")
        loss_ema = None

        for x, c in pbar:
            optim.zero_grad()
            x, c = x.to(device), c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            optim.step()
            loss_ema = loss.item() if loss_ema is None else 0.95*loss_ema + 0.05*loss.item()
            pbar.set_description(f"Epoch {ep} Loss: {loss_ema:.4f}")

        # 7. 每轮训练后采样验证
        ddpm.eval()
        with torch.no_grad():
            n_sample = 40
            for w in ws_test:
                x_gen, x_gen_store = ddpm.sample(n_sample, size=(1,28,28), guide_w=w)
                # 保存生成图像
                x_all = torch.cat([x_gen, x_real])  # 生成图像+真实图像对比
                grid = make_grid(x_all*-1 + 1, nrow=10)
                save_image(grid, os.path.join(base_save_dir, f"image_ep{ep}_w{w}.png"))
                # 保存采样过程GIF
                if ep % 2 == 0 or ep == total_epoch-1:
                    fig, axs = plt.subplots(nrows=4, ncols=10, figsize=(8,3))
                    ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store], interval=200, frames=x_gen_store.shape[0])
                    ani.save(os.path.join(base_save_dir, f"gif_ep{ep}_w{w}.gif"), dpi=100, writer=PillowWriter(fps=5))
            # 保存模型权重
            torch.save(ddpm.state_dict(), os.path.join(base_save_dir, f"model_{ep}.pth"))

四、实验结果与分析

4.1 参数量分析

模型总参数量约2260万,核心分布如下:

模块 参数量 占比
12层DiTBlock 21.9M 97.2%
TimestepEmbedder 0.237M 1.05%
SimpleHead/FinalLayer 0.775M 3.43%
PatchEmbed/LabelEmbedder 0.01M 0.04%

4.2 生成效果

  • 训练轮次:建议至少训练10轮,8轮后模型开始收敛,生成数字清晰可辨;
  • CFG引导权重
    • w=0.0(无引导):生成数字随机性强,部分类别混淆;
    • w=2.0(强引导):生成数字类别精准,边缘清晰;
  • 采样效率:400步采样在单GPU上约10秒/次,CPU上约1分钟/次。

4.3 关键调优技巧

  1. 学习率衰减:采用线性衰减策略,训练后期降低学习率,提升收敛稳定性;
  2. 零初始化调制层:adaLN调制层权重初始化为0,保证扩散模型初始训练稳定;
  3. Patch尺寸选择:28×28图像选择4×4 Patch,切分为7×7 Patch,无冗余,计算效率最高。

五、总结与拓展

本文完整实现了基于DiT的MNIST扩散模型,从原理到代码拆解了模型的核心设计与实现细节。该模型兼顾了Transformer的全局注意力优势和扩散模型的生成能力,是入门生成式AI的优质案例。

5.1 拓展方向

  1. 模型轻量化:降低hidden_size(如64)和depth(如2),参数量可降至100万级,适合CPU部署;
  2. 加速采样:采用DDIM采样策略,将采样步数从400步降至50步,提升采样效率;
  3. 迁移到其他数据集:修改Patch尺寸、输入通道数和类别数,可适配CIFAR-10、FashionMNIST等数据集;
  4. 多条件生成:融合更多条件(如手写风格、笔画粗细),实现更精细的可控生成。

5.2 核心收获

  1. 理解DiT模型中“Patch嵌入+条件化调制”的核心设计;
  2. 掌握DDPM框架“加噪-去噪”的数学原理与代码实现;
  3. 熟悉Classifier-Free引导在扩散模型中的应用;
  4. 具备从0到1搭建Transformer类扩散模型的能力。

六、完整代码获取

本文完整代码可以私信获取,包含模型训练、采样、结果可视化全流程,可直接运行。

附:常见问题解答

Q1:为什么MNIST简单任务需要2260万参数?
A1:生成模型需拟合像素级分布,且Transformer架构本身参数密集(12层Block占97%参数量),若追求轻量化可降低hidden_size和depth。

Q2:训练过程中损失不下降怎么办?
A2:检查学习率(建议1e-4)、Batch Size(建议64)、初始化策略(adaLN调制层零初始化),并确保数据归一化到[0,1]。

Q3:采样结果全是噪声?
A3:模型未收敛(需增加训练轮次),或CFG引导权重设置过高,建议先验证w=0.0的采样结果。


作者简介:查无此人,专注于生成式AI、Transformer、扩散模型研究,擅长将复杂算法落地为可运行的实战代码。

声明:本文仅用于学术交流,代码可自由修改和复用,引用请注明出处。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐