VAE
Variational Autoencoder (VAE)
1. 简介 (Introduction)
VAE (Variational Autoencoder) 是一种生成模型 (Generative Model),由 Kingma 和 Welling 于 2013 年提出。与传统的自编码器 (Autoencoder) 不同,VAE 的目标不仅仅是重建输入,而是学习输入数据的潜在分布 (Latent Distribution),从而能够生成新的数据样本。
- Autoencoder: 学习一个确定的压缩表示 (Compressed Representation)。主要用于降维、去噪。
- VAE: 学习潜在变量的概率分布 (Probability Distribution)。主要用于生成新样本。
2. 数学基础 (Mathematical Foundation)
2.1 目标
假设我们有一组数据 $X = \{x^{(i)}\}_{i=1}^N$,我们希望对数据生成过程进行建模,即学习联合概率 $P(x, z)$,其中 $z$ 是不可观测的潜在变量 (Latent Variable)。
2.2 变分推断 (Variational Inference)
我们需要计算后验概率 $P(z|x)$,但根据贝叶斯公式 $P(z|x) = \frac{P(x|z)P(z)}{P(x)}$,分母 $P(x) = \int P(x|z)P(z) dz$ 通常是难以计算的 (Intractable)。
因此,我们引入一个近似分布 $Q(z|x)$ (即 Encoder),并试图让它尽可能接近真实的后验分布 $P(z|x)$。衡量两个分布相似度的指标是 KL 散度 (Kullback-Leibler Divergence)。
2.3 ELBO (Evidence Lower Bound)
我们的目标是最大化观测数据的对数似然 $\log P(x)$。通过推导,可以得到:
$$ \log P(x) - D_{KL}[Q(z|x) || P(z|x)] = \mathbb{E}_{z \sim Q}[\log P(x|z)] - D_{KL}[Q(z|x) || P(z)] $$由于 $D_{KL} \ge 0$,所以:
$$ \log P(x) \ge \mathbb{E}_{z \sim Q}[\log P(x|z)] - D_{KL}[Q(z|x) || P(z)] $$右边这一项被称为 ELBO (Evidence Lower Bound)。最大化 $\log P(x)$ 等价于最大化 ELBO。
ELBO 由两部分组成:
- Reconstruction Term $\mathbb{E}_{z \sim Q}[\log P(x|z)]$:希望 Decoder 能从 $z$ 很好地重建 $x$。
- Regularization Term $- D_{KL}[Q(z|x) || P(z)]$:希望 Encoder 输出的分布 $Q(z|x)$ 接近先验分布 $P(z)$ (通常假设为标准正态分布 $\mathcal{N}(0, I)$)。
2.4 重参数化技巧 (Reparameterization Trick)
为了能够对 $z$ 进行采样并进行反向传播,VAE 使用了重参数化技巧。 直接从 $\mathcal{N}(\mu, \sigma^2)$ 采样 $z$ 是不可导的操作。 我们将 $z$ 表示为:
$$ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) $$这样,随机性转移到了 $\epsilon$ 上,而 $\mu$ 和 $\sigma$ 则是网络的可学习参数,可以进行梯度下降。
3. 网络架构 (Architecture)
3.1 Encoder (Inference Network)
- 输入:数据 $x$
- 输出:潜在分布的参数,通常是均值 $\mu$ 和对数方差 $\log \sigma^2$ (预测 log var 是为了数值稳定性)。
- $Q(z|x) = \mathcal{N}(z; \mu(x), \sigma^2(x))$
3.2 Decoder (Generative Network)
- 输入:潜在变量 $z$ (通过重参数化采样得到)
- 输出:重建数据 $\hat{x}$ 或其分布参数。
- $P(x|z)$
4. 损失函数 (Loss Function)
总 Loss = Reconstruction Loss + KL Divergence Loss
$$ L = L_{recon} + \beta \cdot L_{KL} $$4.1 Reconstruction Loss
取决于输入数据的假设分布:
- 如果假设 $P(x|z)$ 是高斯分布 (如实数值图像),使用 MSE (Mean Squared Error)。
- 如果假设 $P(x|z)$ 是伯努利分布 (如二值化图像),使用 BCE (Binary Cross Entropy)。
4.2 KL Divergence Loss
假设先验 $P(z) \sim \mathcal{N}(0, I)$,后验 $Q(z|x) \sim \mathcal{N}(\mu, \sigma^2)$。两个高斯分布之间的 KL 散度有解析解:
$$ D_{KL} = -\frac{1}{2} \sum_{j=1}^J (1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2) $$5. PyTorch 实现 (Implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
super(VAE, self).__init__()
# Encoder
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, input_dim)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc_mu(h1), self.fc_logvar(h1)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h3 = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h3)) # 假设输出在 [0, 1] 之间
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# Loss Function
def loss_function(recon_x, x, mu, logvar):
# Reconstruction term (BCE for MNIST)
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
# KL Divergence term
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD6. 优缺点 (Pros & Cons)
- 优点:
- 拥有良好的数学理论基础。
- 训练相对 GAN 更稳定,容易收敛。
- 学习到的潜在空间 (Latent Space) 比较连续和平滑,适合做插值 (Interpolation)。
- 缺点:
- 生成的图像通常比较模糊 (Blurry),不如 GAN 清晰。这是因为 MSE/BCE 损失倾向于取平均值,且变分近似引入了噪声。
7. 变体 (Variants)
- CVAE (Conditional VAE):
- 在 Encoder 和 Decoder 中都加入条件信息 $c$ (如类别标签)。
- 允许控制生成指定类别的样本。
- Beta-VAE:
- 在 Loss 中给 KL 项加一个权重系数 $\beta$ ($\beta > 1$)。
- 促进潜在变量的解耦 (Disentanglement),每个维度控制不同的语义特征。
- VQ-VAE (Vector Quantized VAE):
- 使用离散的潜在变量 (Discrete Latent Variables)。
- 通过 Codebook 进行量化,解决了 VAE 生成模糊的问题,生成的图像质量很高。