DiT
DiT: Scalable Diffusion Models with Transformers
核心创新
DiT (Peebles & Xie, ICCV 2023) 的核心思想非常简单但强力:用标准 Transformer 替换 UNet 作为扩散模型的骨干网络 (Backbone)。 这一改动使得 Diffusion Model 能够享受到 Transformer 强大的可扩展性 (Scalability),为后来的 Sora 等视频生成模型奠定了基础。
1. 为什么用 Transformer? (Motivation)
在 DiT 之前,大多数 Diffusion Model (如 SD 1.4/1.5) 都使用 UNet 及其变体。
- UNet 的局限: 包含大量针对特定分辨率设计的归纳偏置 (Inductive Bias),如下采样、上采样、由卷积构成的残差块等。这使得模型难以轻松扩展参数量或适应多模态输入。
- Transformer 的优势:
- Scalability: 遵循 Scaling Laws,参数量越大,训练得到的 Loss 越低,生成质量越高。
- 通用性: 处理 Token 序列,对输入形式(图片、视频、音频)不敏感。
2. DiT 架构详解 (Architecture)
DiT 并不是直接在像素空间操作,而是像 Stable Diffusion 一样,工作在 Latent Space (LDM)。
2.1 Patchify (图像块化)
输入是从 VAE 编码器提取的 Latent 特征 $z \in \mathbb{R}^{H \times W \times C}$(对于 $256 \times 256$ 的图像,经过 SD VAE 后通常为 $32 \times 32 \times 4$)。
空间切块 (Spatial Patching):
- 设定 Patch 大小为 $p \times p$(DiT 论文中常用 $p=2, 4, 8$)。
- 将特征图划分为 $N = (H/p) \times (W/p)$ 个不重叠的图像块。
- 每个图像块的原始维度为 $p \times p \times C$。
线性映射 (Linear Projection):
- 数学本质:将每个 Patch 展平为一维向量 $v \in \mathbb{R}^{p^2C}$,通过一个可学习的权重矩阵 $\mathbf{W} \in \mathbb{R}^{p^2C \times D}$ 映射到 Transformer 的隐藏维度 $D$。
- 卷积实现:在代码中,这通常通过一个
stride=p且kernel_size=p的二维卷积高效完成: $$\text{Tokens} = \text{Conv2d}(z, \text{out\_channels}=D, \text{kernel}=p, \text{stride}=p)$$ - 输出维度:卷积后的特征图空间维度被展平,得到形状为 $\mathbb{R}^{N \times D}$ 的 Token 序列。
2D 正余弦位置编码 (2D Sine-Cosine Positional Embeddings): 由于 Transformer 具有置换不变性,必须引入位置信息。DiT 采用了 2D 绝对正余弦编码(非学习型):
- 坐标网格:首先为 $N$ 个 Token 生成一个 $(x, y)$ 坐标网格,坐标范围通常设为 $[0, \sqrt{N}-1]$。
- 分量计算:将总维度 $D$ 分成两部分(各 $D/2$),分别用于编码 $x$ 坐标和 $y$ 坐标。
- 对 $x$ 和 $y$ 分别应用标准的一维正余弦公式,生成对应的特征向量。
- 公式参考:$PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}})$。
- 拼接与注入:将 $x$ 和 $y$ 的特征向量拼接(Concatenate)成完整的 $D$ 维向量 $\mathbf{P}_i$。
- 注入方式:在送入第一个 DiT Block 之前,将该位置编码直接逐元素相加到 Token 序列上: $$z_{input} = \text{Tokens} + \mathbf{P}$$
- 核心作用:让模型知道序列中每个 Token 在原始二维网格中的相对位置,这对于图像生成中的空间结构保持至关重要。
2.2 DiT Blocks: 条件注入与 adaLN-Zero
DiT Block 的核心在于如何将扩散过程的时间步 $t$ 和类别信息 $c$ 高效地融入 Transformer 结构中。
1. 条件 Embedding 的构造
在进入 DiT Blocks 之前,条件信息首先被转化为一个统一的条件向量 $y$:
- 时间步 $t$:使用标准的正弦位置编码(Frequency Embedding),捕捉时间步的频率特征,再通过一个两层 MLP 映射。
- 类别标签 $c$ (原始 DiT):在原版论文中,DiT 主要在 ImageNet 上训练,因此 $c$ 是一个 0-999 的整数索引。
- 通过一个可学习的
LabelEmbedder(本质是 Embedding Lookup Table)将其映射为向量。 - CFG 支持:为了支持 Classifier-Free Guidance,通常会预留一个额外的
null类别(如第 1001 个),在训练时以一定概率随机替换真实标签。
- 通过一个可学习的
- 融合方式:将 $t$ 和 $c$ 的 Embedding 向量相加,得到最终的条件表征 $y \in \mathbb{R}^D$。
2. adaLN-Zero (Adaptive Layer Norm Zero)
这是 DiT 最关键的设计,它通过调节 LayerNorm 的参数来实现条件的精准注入。
控制参数的产生: 每个 DiT Block 内部都有一个由两层 SiLU + Linear 组成的 MLP,它接收条件向量 $y$,并为当前 Block 生成 6 个专门的控制参数:
$$(\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2) = \text{MLP}(y)$$具体注入流程:
- 对 Self-Attention 层进行调制:
- 输入先经过
LayerNorm,然后利用 $(\gamma_1, \beta_1)$ 进行缩放和偏移: $$x_{mod} = \text{LayerNorm}(x) \cdot (1 + \gamma_1) + \beta_1$$ - 计算 Attention 结果后,乘以门控因子 $\alpha_1$ 再进行残差连接: $$x = x + \alpha_1 \cdot \text{MultiHeadAttention}(x_{mod})$$
- 输入先经过
- 对 Feedforward (FFN) 层进行调制:
- 使用 $(\gamma_2, \beta_2)$ 对 FFN 前的输入进行类似的调制。
- 计算 FFN 结果后,乘以门控因子 $\alpha_2$ 进行残差连接: $$x = x + \alpha_2 \cdot \text{FFN}(x_{mod2})$$
“Zero” 初始化的奥秘:
- 初始化策略:DiT 将回归出这 6 个参数的 MLP 的最后一层线性层初始化为 全 0。
- 效果:这意味着在训练刚开始时,所有的 $\gamma, \beta, \alpha$ 均为 0。此时每一个 DiT Block 表现为恒等映射 (Identity Function),即 $x_{out} = x_{in}$。
- 意义:这极大地缓解了深层网络训练初期的不稳定性,允许模型从一个简单的起点平滑地学习复杂的扩散分布,被证明比传统的交叉注意力 (Cross-Attention) 更高效。
2.3 整体流程
Latent $\to$ Patchify $\to$ Tokens $\to$ [DiT Blocks with adaLN-Zero] $\to$ Unpatchify $\to$ Predicted Noise/Velocities.
2.4 从类别引导到文本引导的演进 (Evolution to T2I)
随着 Diffusion 模型从单一类别生成转向复杂的文本生成图像 (Text-to-Image),条件注入方式发生了显著变化:
输入形式的转变:
- DiT (Original): 输入是离散的 Class ID(如 248 代表“哈士奇”)。
- SD3 / Sora: 输入是连续的 Text Embeddings(由 CLIP 或 T5 编码器提取的稠密特征向量向量)。
注入机制的增强 (MM-DiT): 在 Stable Diffusion 3 (SD3) 中,DiT 架构进一步演进为 MM-DiT (Multi-Modal DiT):
- 双流处理:文本和图像各有各的 Embedding 和位置编码。
- 联合注意力 (Joint Attention):文本 Token 和图像 Token 被拼接在一起送入 Transformer。这意味着文本不再仅仅是通过 adaLN 注入的全局参数,而是参与到全局的 Self-Attention 中,与图像像素进行细粒度的交互。
- 效果:极大地提升了模型对复杂 Prompt(如文字排版、多物体关系)的理解和还原能力。
Sora 的时空统一: Sora 将视频帧切分为“时空块 (Spacetime Patches)”,通过类似的文本 Embedding 引导,在更高维度的序列上应用 DiT,证明了该架构在处理多模态数据上的极致扩展性。
3. 实验结论 (Key Findings)
- Scaling Laws: 增加模型大小 (N-Layers, Hidden Size) 或减少 Patch Size,FID (生成质量指标) 会显著下降。
- Gflops 决定性能: 模型性能与训练时的计算量 (Gflops) 呈强相关,而与具体的参数结构关系较小。
4. 影响与后续 (Impact)
- Sora: OpenAI 的 Sora 本质上就是一个视频版本的 DiT (Video DiT)。它将视频看作时空块 (Spacetime Patches) 的序列,用 Transformer 统一处理。
- Stable Diffusion 3 (SD3): 采用了基于 DiT 的架构 (MMDiT),分别处理文本和图像 Token,实现了极高的文本遵循度。