深入解析Sora背后的关键技术之一:《Scalable Diffusion Models with Transformers》

代码链接:GitHub - facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

在AI生成领域,大家一定对OpenAI的Sora模型震撼不已。但大家是否好奇,是什么让Sora能生成如此连贯、高质量的视频?其背后的关键技术之一,便是我们今天要讨论的主角——Diffusion Transformer (DiT)。而这一切,都源于一篇名为《Scalable Diffusion Models with Transformers》的开山论文。

这篇由William Peebles(现属OpenAI)和Saining Xie(纽约大学)在2023年发表的论文,做了一件在事后看来“理所当然”,但在当时极具颠覆性的事情:它用标准的Transformer架构,完全替换了扩散模型中赖以成名的U-Net主干网络。

这不仅是一个简单的模块替换,更是一次思想上的解放,为扩散模型打开了规模化(Scaling) 的新纪元。下面,就让我们一起深入这篇论文,看懂它的核心创新与伟大之处。

一、 背景:扩散模型的“旧王” U-Net 与其瓶颈

DiT出现之前,几乎所有高性能的扩散模型(如Stable Diffusion、DALL-E 2)都建立在U-Net架构之上。

  • U-Net为何成功?

    其编码器-解码器结构带有跳跃连接,非常适合捕捉图像的多尺度特征,既能关注高级语义,又能保留细节纹理,这与扩散模型“从噪声逐步构建图像”的去噪过程完美契合。

  • U-Net的瓶颈:

    尽管有效,但U-Net更像一个为特定任务精心设计的“专家”,而非一个通用的基础模型。其架构相对固化,可扩展性(Scalability) 较差。人们发现,单纯地增加U-Net的深度或宽度,其性能(如生成图片的FID分数)并不会像Transformer在NLP领域那样稳定地随之提升。

与此同时,Transformer架构在几乎所有AI领域都展现了其无与伦比的规模化能力:模型越大、数据越多,性能几乎总能单调提升。

一个自然而然的问题被提了出来:Transformer的成功,能否复制到扩散模型上?

二、 DiT的核心创新:用Transformer重构扩散模型

答案是肯定的,其核心思想可以概括为:“Patchify and Transform.” —— 将图像分块,然后用Transformer处理。

需要注意的是:

  1. 高分辨率像素空间训练问题

         1.在高分辨率图像上直接训练扩散模型计算代价过高。

              2.潜在扩散模型LDM避免了这一问题,减少了计算成本和复杂性,并提高了效率。

  1. LDMs 的两阶段方法

    • 学习一个自编码器,用学习到的编码器 E 将高维图像 x 压缩成低维潜表示 z (z = E(x))

    • 训练一个扩散模型处理表示 z 而不是图像 x编码器 E 是冻结的)。然后通过从扩散模型中采样生成新的图像,使用解码器 D 解码为图像 x (x = D(z))

基于上面试验,所以 DiT 应用了潜在空间!

创新点一:Patchify —— 将图像转换为Transformer的“语言”本质是借鉴了ViT) Transformer的输入是序列(Sequence)。如何让处理图像?DiT借鉴了Vision Transformer (ViT) 的思想:

  1. 输入图像(或在潜在扩散模型中的潜在表示)首先被切分成一个个Patch(图像块)。

  2. 每个Patch被拉平成一个向量,并通过一个线性投影层(Linear Embedding)转换为Token。

  3. 同时,像ViT一样,加入可学习的位置编码(Position Embedding)以保留空间信息。

就这样,图像被转换成了一系列Token,完美地送入了Transformer的“口中”。

创新点二:替换主干网络 —— 告别U-Net,拥抱Transformer 这是最直接的架构变革。DiT用一个标准的、堆叠的Transformer编码器,完全取代了原来扩散模型中的U-Net。去噪过程的所有计算,现在都在这个纯Transformer块内进行。

创新点三:自适应条件注入 —— 优雅地告诉模型“何时”做“什么” 扩散模型是条件生成模型,需要将时间步信息(t,即“现在在第几步去噪”)类别标签(c,即“要生成一只猫”) 等信息注入到网络中。DiT设计了一种极其优雅且高效的方式——自适应层归一化(Adaptive Layer Normalization, adaLN)

  • 具体来说,模型将时间步t和类别c编码成一个向量。

  • 这个向量通过一个小型网络(如MLP)来预测一组scaleshift参数(γ 和 β)。

  • 这组参数被用于动态调整Transformer块中层归一化(Layer Norm) 的操作。

这意味着,条件信息不是通过加法或Concatenation硬塞进去的,而是像调节旋钮一样,精细化地控制着每个特征层的分布,告诉模型:“现在是去噪中期,请生成狗的特征”。

先别着急,这里,怎么理解这scaleshift呢?

为了理解 scaleshift,我们必须先回顾标准的层归一化。对于一个输入向量 x,LN 的操作如下:

  1. 计算均值和方差:在一个特征维度上计算 x 的均值 μ 和方差 σ²。

  2. 标准化:用均值和方差将 x 转换为均值为 0、方差为 1 的分布。 x_normalized = (x - μ) / sqrt(σ² + ε)ε 是一个极小值,防止除以零)

  3. 仿射变换 (Affine Transform):这是关键一步!对标准化后的数据施加一个可学习scale (缩放,记为 γ) 和 shift (平移,记为 β)。 y = x_normalized * (1 + γ) + β

这里的 γ 和 β 就是最初级的 scaleshift。它们是模型需要学习的参数,作用是为网络提供灵活性标准化过程可能会丢失一些重要的信息,而 γ 和 β 允许网络“重新校准”地学习如何缩放和平移数据,以保留对任务有用的特征。

创新点四:adaLN-ZERO —— 保证稳定训练的“神来之笔” ,这是论文中最巧妙的一个细节。作者发现,如果将adaLN预测出的scaleshift参数在训练初期初始化为零,会极大提升训练稳定性。

  • 为什么? 零初始化意味着在训练开始时,条件注入机制是“关闭”的。模型会先专注于学习最核心的“去噪”任务,而不必一开始就费力地理解复杂的时间步和条件信息。

  • 随着训练进行,网络再逐步学习如何利用这些条件来微调输出。这是一种“分阶段学习”的策略,被证明非常有效。

代码实现:(原始论文代码实现)

调节函数:

 ## 调节函数,很简单
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

Dit的前向传播:

     def forward(self, x, t, y):
         """
         Forward pass of DiT.
         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         t: (N,) tensor of diffusion timesteps
         y: (N,) tensor of class labels
         """
         ## 输入的噪声
         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
         ## 时间步
         t = self.t_embedder(t)                   # (N, D)
         ## 类别
         y = self.y_embedder(y, self.training)    # (N, D)
         ## 时间步 + 类别
         c = t + y                                # (N, D)
         ## 搭建DiT block
         for block in self.blocks:
             x = block(x, c)                      # (N, T, D)
         x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x)                   # (N, out_channels, H, W)
         return x

DiT block:

 class DiTBlock(nn.Module):
     """
     A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
     """
     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
         super().__init__()
         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         approx_gelu = lambda: nn.GELU(approximate="tanh")
         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             nn.Linear(hidden_size, 6 * hidden_size, bias=True)
         )
 ​
     def forward(self, x, c):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
         return x

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐