问题回溯

记一次DDP 分布式训练调试过程

1. 问题现象回溯

在进行 DDP (Distributed Data Parallel) 多卡分布式训练时,模型在相同的 Checkpoint 起点下,训练出的结果在推理时出现人脸颜色异常(发绿或色彩偏差)。而相同的参数和 Checkpoint 在单机单卡训练下表现完全正常。


2. 核心原因分析 (Root Causes)

经过多轮排查和隔离实验(包括“单卡 DDP”模式复现),锁定了以下三个核心原因:

A. 数据流绝对重复:DistributedSampler 缺少 set_epoch

  • 深层原因:在非分布式训练中,RandomSampler 会在每个 Epoch 自动生成全新随机序列。但在 DDP 中,DistributedSampler 的 shuffling 依赖于 epoch 种子。如果代码中没有显式调用 sampler.set_epoch(epoch),Sampler 默认 epoch=0
  • 后果:模型在整个训练生命周期内,每个 Epoch 看到的数据顺序完全一模一样。这种绝对的统计重复会导致 Discriminator 极速过拟合特定顺序组合,进而引发 GAN 的模式崩溃 (Mode Collapse),表现为生成图像色彩失真。
  • 修复:在循环开头添加 dataloader.sampler.set_epoch(idx)(使用累加的迭代序号作为伪 Epoch)。

B. 冻结权重的状态破坏:误用 convert_sync_batchnorm

  • 深层原因:主模型(Generator/Discriminator)使用的是 InstanceNorm,无需同步。但模型中包含冻结的 SyncNet (用于 SyncLoss),其内部含有 BatchNorm
  • 后果:Trainer 强行调用 convert_sync_batchnorm 会把这些已冻结 (eval()) 的 BN 替换为分布式同步 BN。在分布式环境下,这会破坏 BN 原有的 running_meanrunning_var 统计值,导致计算出的梯度方向错误。
  • 修复:禁止对包含冻结 BN 的异构网络进行全局 SyncBN 转换。

C. VAE 随机性相关:RNG 种子同步冲突

  • 深层原因:所有进程初始种子完全相同。在执行 VAE 的 reparameterize (重参数化) 时,不同进程处理不同数据,却使用了完全相同的随机噪声 eps
  • 后果:引入了不该存在的随机性相关性,干扰了 Latent Space 的解耦和稳定性。
  • 修复:为不同 Rank 的进程设置 base_seed + rank 的独立种子。

3. DDP 开发经验与避坑指南

写 DDP 相关的代码时,务必关注以下几点:

1. 采样器的生命周期

  • 规则:只要用了 DistributedSampler,就必须在每个 Epoch 开始前调用 set_epoch
  • 注意:如果是基于 Iteration 的训练(无明确 Epoch 循环),需要根据 total_steps // len(dataloader) 或 enumerate 序号来构造一个变化的 seed 给 set_epoch

2. 梯度同步设置 (find_unused_parameters)

  • 建议:如果模型内部有条件分支(例如 Warmup 期间不用 GAN Loss,或者部分分支根据输入跳过),必须在 DDP 包装时设置 find_unused_parameters=True。否则会导致进程挂起或梯度更新异常。

3. BN 的处理

  • 规则:先 convert_sync_batchnormDDP 包装。
  • 避坑:如果 Generator 使用 InstanceNorm 而只有 Discriminator 使用 BatchNorm,或者模型中混有冻结的 Pretrained 模型,不要盲目使用全局转换,应该针对性地对特定模块进行转换。

4. 损失值的 Reduction

  • 规则:DDP 会自动在 backward() 时同步梯度。
  • 注意:在打印 Log 或记录 Loss 时,如果是自己在各 Rank 算的 Loss,记得求平均;如果是 DDP 内部计算的,通常已经处理好。

5. 学习率 (LR) 缩放

  • 规则:有效 Batch Size = batch_size_per_gpu * world_size
  • 建议:理论上 LR 应该随有效 Batch 线性增加或按平方根缩放。但在 GAN 训练中,如果发现不稳定,优先尝试保持单卡 LR 或仅微调,因为 Discriminator 的收敛平衡对 Batch 极其敏感。