Diffusion Transformer (DiT):用Transformer取代U-Net,成为扩散模型新基石
摘要:OpenAI的Sora模型采用Diffusion Transformer(DiT)技术取代传统U-Net架构,实现高质量视频生成。DiT通过将图像分块转换为Token序列,利用Transformer处理扩散过程,并创新性地使用自适应层归一化(adaLN)注入条件信息。其中adaLN-ZERO技术通过零初始化确保训练稳定性,使模型先学习核心去噪任务再逐步适应条件控制。该技术突破为扩散模型带来更
深入解析
Sora
背后的关键技术之一:《Scalable Diffusion Models with Transformers》
在AI生成领域,大家一定对OpenAI的Sora模型震撼不已。但大家是否好奇,是什么让Sora能生成如此连贯、高质量的视频?其背后的关键技术之一,便是我们今天要讨论的主角——Diffusion Transformer (DiT)。而这一切,都源于一篇名为《Scalable Diffusion Models with Transformers
》的开山论文。
这篇由William Peebles(现属OpenAI)和Saining Xie(纽约大学)在2023年发表的论文,做了一件在事后看来“理所当然”,但在当时极具颠覆性的事情:它用标准的Transformer架构,完全替换了扩散模型中赖以成名的U-Net主干网络。
这不仅是一个简单的模块替换,更是一次思想上的解放,为扩散模型打开了规模化(Scaling) 的新纪元。下面,就让我们一起深入这篇论文,看懂它的核心创新与伟大之处。
一、 背景:扩散模型的“旧王” U-Net
与其瓶颈
在DiT
出现之前,几乎所有高性能的扩散模型(如Stable Diffusion、DALL-E 2)都建立在U-Net
架构之上。
-
U-Net为何成功?
其编码器-解码器结构带有跳跃连接,非常适合捕捉图像的多尺度特征,既能关注高级语义,又能保留细节纹理,这与扩散模型“从噪声逐步构建图像”的去噪过程完美契合。
-
U-Net的瓶颈:
尽管有效,但U-Net更像一个为特定任务精心设计的“专家”,而非一个通用的基础模型。其架构相对固化,可扩展性(Scalability) 较差。人们发现,单纯地增加U-Net的深度或宽度,其性能(如生成图片的FID分数)并不会像Transformer在NLP领域那样稳定地随之提升。
与此同时,Transformer架构在几乎所有AI领域都展现了其无与伦比的规模化能力:模型越大、数据越多,性能几乎总能单调提升。
一个自然而然的问题被提了出来:Transformer的成功,能否复制到扩散模型上?
二、 DiT的核心创新:用Transformer重构扩散模型
答案是肯定的,其核心思想可以概括为:“Patchify and Transform.” —— 将图像分块,然后用Transformer处理。
需要注意的是:
-
高分辨率像素空间训练问题:
1.在高分辨率图像上直接训练扩散模型计算代价过高。
2.潜在扩散模型LDM
避免了这一问题,减少了计算成本和复杂性,并提高了效率。
-
LDMs 的两阶段方法:
-
学习一个自编码器,用学习到的编码器 E 将高维图像 x 压缩成低维潜表示
z (z = E(x))
; -
训练一个扩散模型处理表示 z 而不是图像 x (编码器 E 是冻结的)。然后通过从扩散模型中采样生成新的图像,使用解码器 D 解码为图像
x (x = D(z))
。
-
基于上面试验,所以 DiT
应用了潜在空间!
创新点一:Patchify —— 将图像转换为Transformer的“语言” (本质是借鉴了ViT) Transformer的输入是序列(Sequence)。如何让处理图像?DiT借鉴了Vision Transformer (ViT) 的思想:
-
输入图像(或在潜在扩散模型中的潜在表示)首先被切分成一个个Patch(图像块)。
-
每个Patch被拉平成一个向量,并通过一个线性投影层(
Linear Embedding
)转换为Token。 -
同时,像ViT一样,加入可学习的位置编码(Position Embedding)以保留空间信息。
就这样,图像被转换成了一系列Token,完美地送入了Transformer的“口中”。
创新点二:替换主干网络 —— 告别U-Net,拥抱Transformer 这是最直接的架构变革。DiT用一个标准的、堆叠的Transformer编码器,完全取代了原来扩散模型中的U-Net。去噪过程的所有计算,现在都在这个纯Transformer块内进行。
创新点三:自适应条件注入 —— 优雅地告诉模型“何时”做“什么” 扩散模型是条件生成模型,需要将时间步信息(t,即“现在在第几步去噪”) 和类别标签(c,即“要生成一只猫”) 等信息注入到网络中。DiT设计了一种极其优雅且高效的方式——自适应层归一化(Adaptive Layer Normalization, adaLN)。
-
具体来说,模型将时间步
t
和类别c
编码成一个向量。 -
这个向量通过一个小型网络(如MLP)来预测一组
scale
和shift
参数(γ 和 β)。 -
这组参数被用于动态调整Transformer块中层归一化(Layer Norm) 的操作。
这意味着,条件信息不是通过加法或Concatenation硬塞进去的,而是像调节旋钮一样,精细化地控制着每个特征层的分布,告诉模型:“现在是去噪中期,请生成狗的特征”。
先别着急,这里,怎么理解这scale
和 shift
呢?
为了理解 scale
和 shift
,我们必须先回顾标准的层归一化。对于一个输入向量 x
,LN 的操作如下:
-
计算均值和方差:在一个特征维度上计算
x
的均值 μ 和方差 σ²。 -
标准化:用均值和方差将
x
转换为均值为 0、方差为 1 的分布。x_normalized = (x - μ) / sqrt(σ² + ε)
(ε
是一个极小值,防止除以零) -
仿射变换 (Affine Transform):这是关键一步!对标准化后的数据施加一个可学习的
scale
(缩放,记为 γ) 和 shift (平移,记为 β)。y = x_normalized * (1 + γ) + β
这里的 γ 和 β 就是最初级的 scale
和 shift
。它们是模型需要学习的参数,作用是为网络提供灵活性。标准化过程可能会丢失一些重要的信息,而 γ 和 β 允许网络“重新校准”地学习如何缩放和平移数据,以保留对任务有用的特征。
创新点四:adaLN-ZERO —— 保证稳定训练的“神来之笔” ,这是论文中最巧妙的一个细节。作者发现,如果将adaLN预测出的scale
和shift
参数在训练初期初始化为零,会极大提升训练稳定性。
-
为什么? 零初始化意味着在训练开始时,条件注入机制是“关闭”的。模型会先专注于学习最核心的“去噪”任务,而不必一开始就费力地理解复杂的时间步和条件信息。
-
随着训练进行,网络再逐步学习如何利用这些条件来微调输出。这是一种“分阶段学习”的策略,被证明非常有效。
代码实现:(原始论文代码实现)
调节函数:
## 调节函数,很简单
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
Dit的前向传播:
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
## 输入的噪声
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
## 时间步
t = self.t_embedder(t) # (N, D)
## 类别
y = self.y_embedder(y, self.training) # (N, D)
## 时间步 + 类别
c = t + y # (N, D)
## 搭建DiT block
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
DiT block:
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
更多推荐
所有评论(0)