← 返回首页

GPU显存优化实战:在24GB显卡上训练原本需要80GB的模型

2022-06-30 | GPU优化 显存 训练技巧
2022年初,我们要训练一个3B参数的视觉模型。理论显存需求约80GB,但手上只有4张RTX 3090(24GB)。经过两周的优化,最终在单卡上跑通了训练。这篇文章记录了从OOM到成功的完整过程。

显存都去哪了?先搞清楚问题

显存占用的四大块

# 以3B参数模型为例,fp32训练

# 1. 模型参数
params = 3e9 * 4  # 3B参数 * 4字节(fp32) = 12GB

# 2. 梯度
gradients = 3e9 * 4  # 与参数同样大小 = 12GB

# 3. 优化器状态 (AdamW)
# momentum: 3B * 4 = 12GB
# variance: 3B * 4 = 12GB
optimizer_states = 12 + 12  # = 24GB

# 4. 激活值 (前向传播的中间结果,用于反向传播)
# 取决于batch size, sequence length, 模型架构
# 假设: batch=8, seq=512, hidden=2048, layers=24
activations = estimate_activations(...)  # 可能 20-40GB

# 总计: 12 + 12 + 24 + 30 ≈ 78GB
# 单卡24GB,差了3倍多!

实测显存占用

import torch

def get_gpu_memory():
    """获取当前GPU显存使用情况"""
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    return f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB"

# 加载模型后
print(get_gpu_memory())  # Allocated: 11.2GB

# forward后
print(get_gpu_memory())  # Allocated: 28.5GB (激活值)

# backward后
print(get_gpu_memory())  # Allocated: 35.8GB (梯度)

# optimizer.step后
print(get_gpu_memory())  # OOM!

优化一:混合精度训练(AMP)

原理

用fp16代替fp32进行计算,显存减半,速度更快。但要注意数值稳定性。

import torch
from torch.cuda.amp import autocast, GradScaler

# 创建scaler处理梯度缩放
scaler = GradScaler()

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()
        
        optimizer.zero_grad()
        
        # 自动混合精度上下文
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # 缩放loss,防止fp16梯度下溢
        scaler.scale(loss).backward()
        
        # unscale梯度,检查inf/nan,更新参数
        scaler.step(optimizer)
        scaler.update()

效果

指标 FP32 AMP
模型显存 12GB 6GB
激活值显存 30GB 15GB
训练速度 1x 1.5x
模型效果 基准 基本无损
⚠️ 踩坑: 某些操作在fp16下会数值不稳定,比如大矩阵的softmax。PyTorch的autocast会自动把这些操作保持在fp32,但自定义op要注意。

优化二:梯度检查点(Gradient Checkpointing)

原理

正常训练时,前向传播的所有激活值都保存下来用于反向传播。梯度检查点只保存部分"检查点",反向传播时重新计算中间激活值。用时间换空间。

from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class TransformerBlock(nn.Module):
    def __init__(self, ...):
        self.attn = MultiHeadAttention(...)
        self.ffn = FeedForward(...)
        self.norm1 = nn.LayerNorm(...)
        self.norm2 = nn.LayerNorm(...)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class Model(nn.Module):
    def __init__(self, num_layers=24, use_checkpoint=True):
        self.layers = nn.ModuleList([
            TransformerBlock(...) for _ in range(num_layers)
        ])
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x):
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # 使用checkpoint,不保存这一层的激活值
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

更细粒度的checkpoint

# 对于特别大的层,可以在层内部也加checkpoint
class LargeFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.act = nn.GELU()
    
    def forward(self, x):
        def ffn_forward(x):
            return self.w2(self.act(self.w1(x)))
        
        if self.training:
            return checkpoint(ffn_forward, x, use_reentrant=False)
        return ffn_forward(x)

效果

Checkpoint策略 激活值显存 训练速度
不使用 30GB 1x
每层checkpoint 8GB 0.7x
每2层checkpoint 15GB 0.85x

优化三:梯度累积

原理

batch size受显存限制时,通过累积多个小batch的梯度,模拟大batch训练。

accumulation_steps = 8  # 累积8步
effective_batch_size = per_device_batch * accumulation_steps  # 2 * 8 = 16

model.train()
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
    inputs, targets = batch
    
    with autocast():
        outputs = model(inputs.cuda())
        loss = criterion(outputs, targets.cuda())
        # 除以累积步数,保持loss scale一致
        loss = loss / accumulation_steps
    
    scaler.scale(loss).backward()
    
    # 每accumulation_steps步更新一次
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

注意事项

# BatchNorm的问题
# BN统计量是在单个小batch上计算的,不是effective batch
# 解决方案1: 用LayerNorm或GroupNorm代替
# 解决方案2: 使用SyncBatchNorm(多卡情况)

# learning rate的调整
# 有些说法认为大batch需要调大lr,比如linear scaling rule
# 但我的实践是:保持lr不变,或者根据实际效果微调

优化四:优化器状态卸载到CPU

DeepSpeed ZeRO-Offload

# deepspeed_config.json
{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 16
    }
}

# 使用
import deepspeed

model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json"
)

for batch in dataloader:
    loss = model(batch)
    model.backward(loss)
    model.step()

效果

配置 优化器显存 训练速度
全部在GPU 24GB 1x
Offload到CPU 0GB 0.6x
Offload + NVMe 0GB 0.4x
💡 什么时候用Offload: 显存实在不够时的最后手段。速度损失明显,但起码能跑起来。

优化五:激活值压缩

GACT:激活值量化

# 保存激活值时用int8,反向传播时还原
# 论文: GACT: Activation Compressed Training

class QuantizedActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # 量化到int8
        scale = x.abs().max() / 127
        x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
        
        # 保存量化后的激活值和scale
        ctx.save_for_backward(x_int8)
        ctx.scale = scale
        
        return x  # 前向输出不变
    
    @staticmethod
    def backward(ctx, grad_output):
        x_int8, = ctx.saved_tensors
        # 反量化
        x = x_int8.float() * ctx.scale
        # ... 计算梯度
        return grad_output

# 可以节省约75%的激活值显存
# 但实现复杂,且有精度损失

组合优化:最终方案

我们的配置

# 目标: 3B模型在单卡24GB上训练

# 1. 混合精度: 模型参数和激活值减半
# 2. 梯度检查点: 激活值从15GB降到4GB  
# 3. 梯度累积: batch size 2,累积8步
# 4. 优化器: 8-bit Adam (bitsandbytes)

import bitsandbytes as bnb

model = MyModel().cuda()
model.gradient_checkpointing_enable()

# 8-bit Adam,优化器状态减少75%
optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999)
)

scaler = GradScaler()
accumulation_steps = 8

for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        with autocast():
            loss = model(batch) / accumulation_steps
        
        scaler.scale(loss).backward()
        
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

显存breakdown

项目 优化前 优化后
模型参数 12GB 6GB (fp16)
梯度 12GB 6GB (fp16)
优化器状态 24GB 6GB (8-bit)
激活值 30GB 4GB (checkpoint)
总计 78GB 22GB

从78GB优化到22GB,单卡3090(24GB)可以跑了!

监控和调试工具

实时显存监控

import torch
import threading
import time

class GPUMemoryMonitor:
    def __init__(self, interval=1.0):
        self.interval = interval
        self.running = False
        self.peak_memory = 0
        
    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._monitor)
        self.thread.start()
        
    def stop(self):
        self.running = False
        self.thread.join()
        return self.peak_memory
        
    def _monitor(self):
        while self.running:
            current = torch.cuda.memory_allocated() / 1024**3
            self.peak_memory = max(self.peak_memory, current)
            time.sleep(self.interval)

# 使用
monitor = GPUMemoryMonitor()
monitor.start()

# ... 训练代码 ...

peak = monitor.stop()
print(f"Peak GPU memory: {peak:.2f}GB")

显存快照

# PyTorch 2.0+ 的显存快照功能
torch.cuda.memory._record_memory_history()

# ... 训练代码 ...

# 导出快照
snapshot = torch.cuda.memory._snapshot()
with open("memory_snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)

# 用官方工具可视化分析
# https://pytorch.org/memory_viz

找出显存泄漏

def check_memory_leak():
    """检查是否有tensor没被释放"""
    gc.collect()
    torch.cuda.empty_cache()
    
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                print(f"Tensor: {obj.shape}, {obj.dtype}, {obj.device}")
        except:
            pass

# 每个epoch结束后调用
check_memory_leak()

常见OOM问题排查

问题1:显存碎片

# 症状:显存显示还有空间,但分配失败
# 原因:显存碎片化,没有连续的大块空间

# 解决:
torch.cuda.empty_cache()  # 清理缓存
# 或者设置环境变量
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

问题2:eval时OOM

# 症状:训练OK,但eval时OOM
# 原因:eval时batch size可能更大,或者没关闭梯度

# 解决:
model.eval()
with torch.no_grad():  # 一定要加!
    outputs = model(inputs)

问题3:DataLoader占显存

# 症状:数据加载后显存飙升
# 原因:pin_memory把数据放到了锁页内存

# 解决:
dataloader = DataLoader(
    dataset,
    pin_memory=False,  # 如果显存紧张,关掉
    num_workers=4
)

参考资料

总结

GPU显存优化的核心思路:

  1. 先分析:搞清楚显存被谁占了
  2. 混合精度:必开,几乎无成本
  3. 梯度检查点:激活值太大时用,用时间换空间
  4. 梯度累积:batch size受限时用
  5. 优化器优化:8-bit Adam或Offload
  6. 监控调试:随时知道显存去哪了

更新记录:
2022-06-30: 初版发布
2022-12-15: 补充8-bit Adam
2023-05-20: 更新PyTorch 2.0显存分析工具
2023-11-10: 补充显存碎片问题