问题回溯
记一次DDP 分布式训练调试过程
1. 问题现象回溯
在进行 DDP (Distributed Data Parallel) 多卡分布式训练时,模型在相同的 Checkpoint 起点下,训练出的结果在推理时出现人脸颜色异常(发绿或色彩偏差)。而相同的参数和 Checkpoint 在单机单卡训练下表现完全正常。
2. 核心原因分析 (Root Causes)
经过多轮排查和隔离实验(包括“单卡 DDP”模式复现),锁定了以下三个核心原因:
A. 数据流绝对重复:DistributedSampler 缺少 set_epoch
- 深层原因:在非分布式训练中,
RandomSampler会在每个 Epoch 自动生成全新随机序列。但在 DDP 中,DistributedSampler的 shuffling 依赖于epoch种子。如果代码中没有显式调用sampler.set_epoch(epoch),Sampler 默认epoch=0。 - 后果:模型在整个训练生命周期内,每个 Epoch 看到的数据顺序完全一模一样。这种绝对的统计重复会导致 Discriminator 极速过拟合特定顺序组合,进而引发 GAN 的模式崩溃 (Mode Collapse),表现为生成图像色彩失真。
- 修复:在循环开头添加
dataloader.sampler.set_epoch(idx)(使用累加的迭代序号作为伪 Epoch)。
B. 冻结权重的状态破坏:误用 convert_sync_batchnorm
- 深层原因:主模型(Generator/Discriminator)使用的是
InstanceNorm,无需同步。但模型中包含冻结的SyncNet(用于 SyncLoss),其内部含有BatchNorm。 - 后果:Trainer 强行调用
convert_sync_batchnorm会把这些已冻结 (eval()) 的 BN 替换为分布式同步 BN。在分布式环境下,这会破坏 BN 原有的running_mean和running_var统计值,导致计算出的梯度方向错误。 - 修复:禁止对包含冻结 BN 的异构网络进行全局 SyncBN 转换。
C. VAE 随机性相关:RNG 种子同步冲突
- 深层原因:所有进程初始种子完全相同。在执行 VAE 的
reparameterize(重参数化) 时,不同进程处理不同数据,却使用了完全相同的随机噪声eps。 - 后果:引入了不该存在的随机性相关性,干扰了 Latent Space 的解耦和稳定性。
- 修复:为不同 Rank 的进程设置
base_seed + rank的独立种子。
3. DDP 开发经验与避坑指南
写 DDP 相关的代码时,务必关注以下几点:
1. 采样器的生命周期
- 规则:只要用了
DistributedSampler,就必须在每个 Epoch 开始前调用set_epoch。 - 注意:如果是基于 Iteration 的训练(无明确 Epoch 循环),需要根据
total_steps // len(dataloader)或 enumerate 序号来构造一个变化的 seed 给set_epoch。
2. 梯度同步设置 (find_unused_parameters)
- 建议:如果模型内部有条件分支(例如 Warmup 期间不用 GAN Loss,或者部分分支根据输入跳过),必须在 DDP 包装时设置
find_unused_parameters=True。否则会导致进程挂起或梯度更新异常。
3. BN 的处理
- 规则:先
convert_sync_batchnorm再DDP包装。 - 避坑:如果 Generator 使用
InstanceNorm而只有 Discriminator 使用BatchNorm,或者模型中混有冻结的 Pretrained 模型,不要盲目使用全局转换,应该针对性地对特定模块进行转换。
4. 损失值的 Reduction
- 规则:DDP 会自动在
backward()时同步梯度。 - 注意:在打印 Log 或记录 Loss 时,如果是自己在各 Rank 算的 Loss,记得求平均;如果是
DDP内部计算的,通常已经处理好。
5. 学习率 (LR) 缩放
- 规则:有效 Batch Size =
batch_size_per_gpu * world_size。 - 建议:理论上 LR 应该随有效 Batch 线性增加或按平方根缩放。但在 GAN 训练中,如果发现不稳定,优先尝试保持单卡 LR 或仅微调,因为 Discriminator 的收敛平衡对 Batch 极其敏感。