← 返回首页

PyTorch分布式训练实战:从单卡到8卡的踩坑全记录

2022-04-12 | 分布式训练 PyTorch DDP
2022年初,团队拿到了一台8卡A100服务器。我负责把原来单卡跑3天的训练任务改成分布式,目标是缩短到1天内。听起来简单,实际踩了无数坑。这篇文章记录完整的迁移过程。

分布式训练的两种模式

数据并行(Data Parallel)

最常用的方式:每张卡都有完整的模型副本,各自处理不同的数据批次,然后同步梯度。

模型并行(Model Parallel)

模型太大单卡放不下时,把模型拆分到多张卡上。实现复杂,这篇文章不涉及。

DDP vs DataParallel

特性 nn.DataParallel DDP
进程模型 单进程多线程 多进程
GIL影响
通信效率 低(所有梯度汇总到GPU0) 高(Ring AllReduce)
扩展性 单机 单机/多机
推荐度 ❌ 不推荐 ✅ 推荐
⚠️ 不要用DataParallel! 它看起来简单(只需要一行代码),但性能差很多。我们实测8卡A100,DataParallel只有2.8x加速,而DDP有7.2x。

DDP完整代码模板

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    
    # 初始化进程组
    dist.init_process_group(
        backend='nccl',  # GPU用nccl,CPU用gloo
        rank=rank,
        world_size=world_size
    )
    
    # 设置当前进程使用的GPU
    torch.cuda.set_device(rank)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

def train(rank, world_size, epochs=10):
    """训练函数,每个进程都会执行"""
    setup(rank, world_size)
    
    # 1. 创建模型并移到对应GPU
    model = MyModel().to(rank)
    
    # 2. 用DDP包装模型
    model = DDP(model, device_ids=[rank])
    
    # 3. 创建分布式采样器
    dataset = MyDataset()
    sampler = DistributedSampler(
        dataset, 
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    # 4. 创建DataLoader(注意不要在这里shuffle)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,  # 用sampler而不是shuffle
        num_workers=4,
        pin_memory=True
    )
    
    # 5. 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    # 6. 训练循环
    for epoch in range(epochs):
        # 重要:每个epoch要设置sampler的epoch
        sampler.set_epoch(epoch)
        
        model.train()
        for batch_idx, (data, target) in enumerate(dataloader):
            data = data.to(rank)
            target = target.to(rank)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            # 只在rank 0打印日志
            if rank == 0 and batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        # 只在rank 0保存checkpoint
        if rank == 0:
            torch.save(model.module.state_dict(), f"checkpoint_epoch{epoch}.pt")
    
    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    
    # 使用spawn启动多进程
    import torch.multiprocessing as mp
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

更简单的启动方式:torchrun

# train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun会自动设置这些环境变量
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    
    # 初始化
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    
    # 模型
    model = MyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    
    # ... 训练代码 ...
    
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

启动命令:

# 单机8卡
torchrun --nproc_per_node=8 train_ddp.py

# 多机(2台机器,每台8卡)
# 机器1(master):
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=12355 train_ddp.py

# 机器2:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 \
    --master_addr=192.168.1.1 --master_port=12355 train_ddp.py

踩坑记录

坑1:DistributedSampler没设置epoch

# ❌ 错误:每个epoch数据顺序相同
for epoch in range(epochs):
    for data in dataloader:
        ...

# ✅ 正确:每个epoch打乱顺序
for epoch in range(epochs):
    sampler.set_epoch(epoch)  # 关键!
    for data in dataloader:
        ...

坑2:保存模型时用model.module

# ❌ 错误:保存的是DDP wrapper
torch.save(model.state_dict(), "model.pt")
# 加载时会报错:key mismatch

# ✅ 正确:保存内部模型
torch.save(model.module.state_dict(), "model.pt")

坑3:batch size的理解

# 设置 batch_size=32, 8卡
# 每卡处理32个样本
# 全局有效batch size = 32 * 8 = 256

# 如果要保持和单卡相同的有效batch size:
# 单卡batch_size=256
# 8卡应该设置batch_size=32

坑4:学习率要不要调?

# Linear Scaling Rule (来自Facebook):
# 有效batch size变大N倍,学习率也要变大N倍

# 但我的实践经验:
# 1. 先保持lr不变试试
# 2. 如果训练不稳定,适当调小lr
# 3. 如果收敛太慢,适当调大lr
# 4. 用warmup总是好的

坑5:随机种子同步

# 确保所有进程的随机性一致(模型初始化等)
def set_seed(seed, rank):
    torch.manual_seed(seed + rank)
    torch.cuda.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

# 但数据shuffle要不同(DistributedSampler自动处理)

坑6:进程hang住不动

# 症状:训练到某处卡住,没有任何输出
# 原因:某个进程出错/提前结束,其他进程等待同步

# 调试方法1:设置超时
dist.init_process_group(backend='nccl', timeout=timedelta(minutes=30))

# 调试方法2:加环境变量看详细日志
# NCCL_DEBUG=INFO torchrun ...

性能优化

梯度通信优化

# 默认:反向传播完成后才开始AllReduce
# 优化:边计算边通信(overlap)

model = DDP(
    model, 
    device_ids=[local_rank],
    broadcast_buffers=False,  # BN等buffer不广播
    gradient_as_bucket_view=True,  # 减少内存拷贝
    static_graph=True  # 如果计算图固定,启用优化
)

混合精度+DDP

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in dataloader:
    optimizer.zero_grad()
    
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

数据加载优化

dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=8,        # 增加worker数
    pin_memory=True,      # 锁页内存
    prefetch_factor=2,    # 预加载
    persistent_workers=True  # worker复用
)

加速比测试

GPU数量 吞吐量(samples/s) 加速比 效率
1 256 1.0x 100%
2 498 1.95x 97%
4 972 3.80x 95%
8 1843 7.20x 90%

8卡达到7.2x加速比,效率90%,符合预期。损失主要来自梯度同步通信。

参考资料

总结

  1. 用DDP不用DataParallel
  2. 用torchrun启动最方便
  3. 记得sampler.set_epoch()
  4. 保存模型用model.module
  5. batch size和lr要考虑调整

更新记录:
2022-04-12: 初版发布
2022-08-20: 补充torchrun用法
2023-03-15: 更新性能优化技巧