VAE

Variational Autoencoder (VAE)

1. 简介 (Introduction)

VAE (Variational Autoencoder) 是一种生成模型 (Generative Model),由 Kingma 和 Welling 于 2013 年提出。与传统的自编码器 (Autoencoder) 不同,VAE 的目标不仅仅是重建输入,而是学习输入数据的潜在分布 (Latent Distribution),从而能够生成新的数据样本。

  • Autoencoder: 学习一个确定的压缩表示 (Compressed Representation)。主要用于降维、去噪。
  • VAE: 学习潜在变量的概率分布 (Probability Distribution)。主要用于生成新样本。

2. 数学基础 (Mathematical Foundation)

2.1 目标

假设我们有一组数据 $X = \{x^{(i)}\}_{i=1}^N$,我们希望对数据生成过程进行建模,即学习联合概率 $P(x, z)$,其中 $z$ 是不可观测的潜在变量 (Latent Variable)。

2.2 变分推断 (Variational Inference)

我们需要计算后验概率 $P(z|x)$,但根据贝叶斯公式 $P(z|x) = \frac{P(x|z)P(z)}{P(x)}$,分母 $P(x) = \int P(x|z)P(z) dz$ 通常是难以计算的 (Intractable)。

因此,我们引入一个近似分布 $Q(z|x)$ (即 Encoder),并试图让它尽可能接近真实的后验分布 $P(z|x)$。衡量两个分布相似度的指标是 KL 散度 (Kullback-Leibler Divergence)。

2.3 ELBO (Evidence Lower Bound)

我们的目标是最大化观测数据的对数似然 $\log P(x)$。通过推导,可以得到:

$$ \log P(x) - D_{KL}[Q(z|x) || P(z|x)] = \mathbb{E}_{z \sim Q}[\log P(x|z)] - D_{KL}[Q(z|x) || P(z)] $$

由于 $D_{KL} \ge 0$,所以:

$$ \log P(x) \ge \mathbb{E}_{z \sim Q}[\log P(x|z)] - D_{KL}[Q(z|x) || P(z)] $$

右边这一项被称为 ELBO (Evidence Lower Bound)。最大化 $\log P(x)$ 等价于最大化 ELBO。

ELBO 由两部分组成:

  1. Reconstruction Term $\mathbb{E}_{z \sim Q}[\log P(x|z)]$:希望 Decoder 能从 $z$ 很好地重建 $x$。
  2. Regularization Term $- D_{KL}[Q(z|x) || P(z)]$:希望 Encoder 输出的分布 $Q(z|x)$ 接近先验分布 $P(z)$ (通常假设为标准正态分布 $\mathcal{N}(0, I)$)。

2.4 重参数化技巧 (Reparameterization Trick)

为了能够对 $z$ 进行采样并进行反向传播,VAE 使用了重参数化技巧。 直接从 $\mathcal{N}(\mu, \sigma^2)$ 采样 $z$ 是不可导的操作。 我们将 $z$ 表示为:

$$ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) $$

这样,随机性转移到了 $\epsilon$ 上,而 $\mu$ 和 $\sigma$ 则是网络的可学习参数,可以进行梯度下降。

3. 网络架构 (Architecture)

3.1 Encoder (Inference Network)

  • 输入:数据 $x$
  • 输出:潜在分布的参数,通常是均值 $\mu$ 和对数方差 $\log \sigma^2$ (预测 log var 是为了数值稳定性)。
  • $Q(z|x) = \mathcal{N}(z; \mu(x), \sigma^2(x))$

3.2 Decoder (Generative Network)

  • 输入:潜在变量 $z$ (通过重参数化采样得到)
  • 输出:重建数据 $\hat{x}$ 或其分布参数。
  • $P(x|z)$

4. 损失函数 (Loss Function)

总 Loss = Reconstruction Loss + KL Divergence Loss

$$ L = L_{recon} + \beta \cdot L_{KL} $$

4.1 Reconstruction Loss

取决于输入数据的假设分布:

  • 如果假设 $P(x|z)$ 是高斯分布 (如实数值图像),使用 MSE (Mean Squared Error)
  • 如果假设 $P(x|z)$ 是伯努利分布 (如二值化图像),使用 BCE (Binary Cross Entropy)

4.2 KL Divergence Loss

假设先验 $P(z) \sim \mathcal{N}(0, I)$,后验 $Q(z|x) \sim \mathcal{N}(\mu, \sigma^2)$。两个高斯分布之间的 KL 散度有解析解:

$$ D_{KL} = -\frac{1}{2} \sum_{j=1}^J (1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2) $$

5. PyTorch 实现 (Implementation)

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc_mu(h1), self.fc_logvar(h1)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3)) # 假设输出在 [0, 1] 之间
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss Function
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term (BCE for MNIST)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    
    # KL Divergence term
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return BCE + KLD

6. 优缺点 (Pros & Cons)

  • 优点:
    • 拥有良好的数学理论基础。
    • 训练相对 GAN 更稳定,容易收敛。
    • 学习到的潜在空间 (Latent Space) 比较连续和平滑,适合做插值 (Interpolation)。
  • 缺点:
    • 生成的图像通常比较模糊 (Blurry),不如 GAN 清晰。这是因为 MSE/BCE 损失倾向于取平均值,且变分近似引入了噪声。

7. 变体 (Variants)

  • CVAE (Conditional VAE):
    • 在 Encoder 和 Decoder 中都加入条件信息 $c$ (如类别标签)。
    • 允许控制生成指定类别的样本。
  • Beta-VAE:
    • 在 Loss 中给 KL 项加一个权重系数 $\beta$ ($\beta > 1$)。
    • 促进潜在变量的解耦 (Disentanglement),每个维度控制不同的语义特征。
  • VQ-VAE (Vector Quantized VAE):
    • 使用离散的潜在变量 (Discrete Latent Variables)。
    • 通过 Codebook 进行量化,解决了 VAE 生成模糊的问题,生成的图像质量很高。