PyTorch分布式训练实战:从单卡到8卡的踩坑全记录
2022年初,团队拿到了一台8卡A100服务器。我负责把原来单卡跑3天的训练任务改成分布式,目标是缩短到1天内。听起来简单,实际踩了无数坑。这篇文章记录完整的迁移过程。
分布式训练的两种模式
数据并行(Data Parallel)
最常用的方式:每张卡都有完整的模型副本,各自处理不同的数据批次,然后同步梯度。
模型并行(Model Parallel)
模型太大单卡放不下时,把模型拆分到多张卡上。实现复杂,这篇文章不涉及。
DDP vs DataParallel
| 特性 | nn.DataParallel | DDP |
|---|---|---|
| 进程模型 | 单进程多线程 | 多进程 |
| GIL影响 | 有 | 无 |
| 通信效率 | 低(所有梯度汇总到GPU0) | 高(Ring AllReduce) |
| 扩展性 | 单机 | 单机/多机 |
| 推荐度 | ❌ 不推荐 | ✅ 推荐 |
⚠️ 不要用DataParallel! 它看起来简单(只需要一行代码),但性能差很多。我们实测8卡A100,DataParallel只有2.8x加速,而DDP有7.2x。
DDP完整代码模板
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
def setup(rank, world_size):
"""初始化分布式环境"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# 初始化进程组
dist.init_process_group(
backend='nccl', # GPU用nccl,CPU用gloo
rank=rank,
world_size=world_size
)
# 设置当前进程使用的GPU
torch.cuda.set_device(rank)
def cleanup():
"""清理分布式环境"""
dist.destroy_process_group()
def train(rank, world_size, epochs=10):
"""训练函数,每个进程都会执行"""
setup(rank, world_size)
# 1. 创建模型并移到对应GPU
model = MyModel().to(rank)
# 2. 用DDP包装模型
model = DDP(model, device_ids=[rank])
# 3. 创建分布式采样器
dataset = MyDataset()
sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)
# 4. 创建DataLoader(注意不要在这里shuffle)
dataloader = DataLoader(
dataset,
batch_size=32,
sampler=sampler, # 用sampler而不是shuffle
num_workers=4,
pin_memory=True
)
# 5. 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 6. 训练循环
for epoch in range(epochs):
# 重要:每个epoch要设置sampler的epoch
sampler.set_epoch(epoch)
model.train()
for batch_idx, (data, target) in enumerate(dataloader):
data = data.to(rank)
target = target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 只在rank 0打印日志
if rank == 0 and batch_idx % 100 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
# 只在rank 0保存checkpoint
if rank == 0:
torch.save(model.module.state_dict(), f"checkpoint_epoch{epoch}.pt")
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
# 使用spawn启动多进程
import torch.multiprocessing as mp
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
更简单的启动方式:torchrun
# train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
# torchrun会自动设置这些环境变量
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
# 初始化
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
# 模型
model = MyModel().to(local_rank)
model = DDP(model, device_ids=[local_rank])
# ... 训练代码 ...
dist.destroy_process_group()
if __name__ == "__main__":
main()
启动命令:
# 单机8卡
torchrun --nproc_per_node=8 train_ddp.py
# 多机(2台机器,每台8卡)
# 机器1(master):
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
--master_addr=192.168.1.1 --master_port=12355 train_ddp.py
# 机器2:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 \
--master_addr=192.168.1.1 --master_port=12355 train_ddp.py
踩坑记录
坑1:DistributedSampler没设置epoch
# ❌ 错误:每个epoch数据顺序相同
for epoch in range(epochs):
for data in dataloader:
...
# ✅ 正确:每个epoch打乱顺序
for epoch in range(epochs):
sampler.set_epoch(epoch) # 关键!
for data in dataloader:
...
坑2:保存模型时用model.module
# ❌ 错误:保存的是DDP wrapper
torch.save(model.state_dict(), "model.pt")
# 加载时会报错:key mismatch
# ✅ 正确:保存内部模型
torch.save(model.module.state_dict(), "model.pt")
坑3:batch size的理解
# 设置 batch_size=32, 8卡
# 每卡处理32个样本
# 全局有效batch size = 32 * 8 = 256
# 如果要保持和单卡相同的有效batch size:
# 单卡batch_size=256
# 8卡应该设置batch_size=32
坑4:学习率要不要调?
# Linear Scaling Rule (来自Facebook):
# 有效batch size变大N倍,学习率也要变大N倍
# 但我的实践经验:
# 1. 先保持lr不变试试
# 2. 如果训练不稳定,适当调小lr
# 3. 如果收敛太慢,适当调大lr
# 4. 用warmup总是好的
坑5:随机种子同步
# 确保所有进程的随机性一致(模型初始化等)
def set_seed(seed, rank):
torch.manual_seed(seed + rank)
torch.cuda.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)
# 但数据shuffle要不同(DistributedSampler自动处理)
坑6:进程hang住不动
# 症状:训练到某处卡住,没有任何输出
# 原因:某个进程出错/提前结束,其他进程等待同步
# 调试方法1:设置超时
dist.init_process_group(backend='nccl', timeout=timedelta(minutes=30))
# 调试方法2:加环境变量看详细日志
# NCCL_DEBUG=INFO torchrun ...
性能优化
梯度通信优化
# 默认:反向传播完成后才开始AllReduce
# 优化:边计算边通信(overlap)
model = DDP(
model,
device_ids=[local_rank],
broadcast_buffers=False, # BN等buffer不广播
gradient_as_bucket_view=True, # 减少内存拷贝
static_graph=True # 如果计算图固定,启用优化
)
混合精度+DDP
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for data, target in dataloader:
optimizer.zero_grad()
with autocast():
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
数据加载优化
dataloader = DataLoader(
dataset,
batch_size=32,
sampler=sampler,
num_workers=8, # 增加worker数
pin_memory=True, # 锁页内存
prefetch_factor=2, # 预加载
persistent_workers=True # worker复用
)
加速比测试
| GPU数量 | 吞吐量(samples/s) | 加速比 | 效率 |
|---|---|---|---|
| 1 | 256 | 1.0x | 100% |
| 2 | 498 | 1.95x | 97% |
| 4 | 972 | 3.80x | 95% |
| 8 | 1843 | 7.20x | 90% |
8卡达到7.2x加速比,效率90%,符合预期。损失主要来自梯度同步通信。
参考资料
总结
- 用DDP不用DataParallel
- 用torchrun启动最方便
- 记得sampler.set_epoch()
- 保存模型用model.module
- batch size和lr要考虑调整
更新记录:
2022-04-12: 初版发布
2022-08-20: 补充torchrun用法
2023-03-15: 更新性能优化技巧