GPU显存优化实战:在24GB显卡上训练原本需要80GB的模型
2022年初,我们要训练一个3B参数的视觉模型。理论显存需求约80GB,但手上只有4张RTX 3090(24GB)。经过两周的优化,最终在单卡上跑通了训练。这篇文章记录了从OOM到成功的完整过程。
显存都去哪了?先搞清楚问题
显存占用的四大块
# 以3B参数模型为例,fp32训练
# 1. 模型参数
params = 3e9 * 4 # 3B参数 * 4字节(fp32) = 12GB
# 2. 梯度
gradients = 3e9 * 4 # 与参数同样大小 = 12GB
# 3. 优化器状态 (AdamW)
# momentum: 3B * 4 = 12GB
# variance: 3B * 4 = 12GB
optimizer_states = 12 + 12 # = 24GB
# 4. 激活值 (前向传播的中间结果,用于反向传播)
# 取决于batch size, sequence length, 模型架构
# 假设: batch=8, seq=512, hidden=2048, layers=24
activations = estimate_activations(...) # 可能 20-40GB
# 总计: 12 + 12 + 24 + 30 ≈ 78GB
# 单卡24GB,差了3倍多!
实测显存占用
import torch
def get_gpu_memory():
"""获取当前GPU显存使用情况"""
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
return f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB"
# 加载模型后
print(get_gpu_memory()) # Allocated: 11.2GB
# forward后
print(get_gpu_memory()) # Allocated: 28.5GB (激活值)
# backward后
print(get_gpu_memory()) # Allocated: 35.8GB (梯度)
# optimizer.step后
print(get_gpu_memory()) # OOM!
优化一:混合精度训练(AMP)
原理
用fp16代替fp32进行计算,显存减半,速度更快。但要注意数值稳定性。
import torch
from torch.cuda.amp import autocast, GradScaler
# 创建scaler处理梯度缩放
scaler = GradScaler()
model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
for batch in dataloader:
inputs, targets = batch
inputs = inputs.cuda()
targets = targets.cuda()
optimizer.zero_grad()
# 自动混合精度上下文
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# 缩放loss,防止fp16梯度下溢
scaler.scale(loss).backward()
# unscale梯度,检查inf/nan,更新参数
scaler.step(optimizer)
scaler.update()
效果
| 指标 | FP32 | AMP |
|---|---|---|
| 模型显存 | 12GB | 6GB |
| 激活值显存 | 30GB | 15GB |
| 训练速度 | 1x | 1.5x |
| 模型效果 | 基准 | 基本无损 |
⚠️ 踩坑: 某些操作在fp16下会数值不稳定,比如大矩阵的softmax。PyTorch的autocast会自动把这些操作保持在fp32,但自定义op要注意。
优化二:梯度检查点(Gradient Checkpointing)
原理
正常训练时,前向传播的所有激活值都保存下来用于反向传播。梯度检查点只保存部分"检查点",反向传播时重新计算中间激活值。用时间换空间。
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
class TransformerBlock(nn.Module):
def __init__(self, ...):
self.attn = MultiHeadAttention(...)
self.ffn = FeedForward(...)
self.norm1 = nn.LayerNorm(...)
self.norm2 = nn.LayerNorm(...)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class Model(nn.Module):
def __init__(self, num_layers=24, use_checkpoint=True):
self.layers = nn.ModuleList([
TransformerBlock(...) for _ in range(num_layers)
])
self.use_checkpoint = use_checkpoint
def forward(self, x):
for layer in self.layers:
if self.use_checkpoint and self.training:
# 使用checkpoint,不保存这一层的激活值
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x)
return x
更细粒度的checkpoint
# 对于特别大的层,可以在层内部也加checkpoint
class LargeFFN(nn.Module):
def __init__(self, d_model, d_ff):
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.act = nn.GELU()
def forward(self, x):
def ffn_forward(x):
return self.w2(self.act(self.w1(x)))
if self.training:
return checkpoint(ffn_forward, x, use_reentrant=False)
return ffn_forward(x)
效果
| Checkpoint策略 | 激活值显存 | 训练速度 |
|---|---|---|
| 不使用 | 30GB | 1x |
| 每层checkpoint | 8GB | 0.7x |
| 每2层checkpoint | 15GB | 0.85x |
优化三:梯度累积
原理
batch size受显存限制时,通过累积多个小batch的梯度,模拟大batch训练。
accumulation_steps = 8 # 累积8步
effective_batch_size = per_device_batch * accumulation_steps # 2 * 8 = 16
model.train()
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
inputs, targets = batch
with autocast():
outputs = model(inputs.cuda())
loss = criterion(outputs, targets.cuda())
# 除以累积步数,保持loss scale一致
loss = loss / accumulation_steps
scaler.scale(loss).backward()
# 每accumulation_steps步更新一次
if (i + 1) % accumulation_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
注意事项
# BatchNorm的问题
# BN统计量是在单个小batch上计算的,不是effective batch
# 解决方案1: 用LayerNorm或GroupNorm代替
# 解决方案2: 使用SyncBatchNorm(多卡情况)
# learning rate的调整
# 有些说法认为大batch需要调大lr,比如linear scaling rule
# 但我的实践是:保持lr不变,或者根据实际效果微调
优化四:优化器状态卸载到CPU
DeepSpeed ZeRO-Offload
# deepspeed_config.json
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"contiguous_gradients": true,
"overlap_comm": true
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}
# 使用
import deepspeed
model, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config="deepspeed_config.json"
)
for batch in dataloader:
loss = model(batch)
model.backward(loss)
model.step()
效果
| 配置 | 优化器显存 | 训练速度 |
|---|---|---|
| 全部在GPU | 24GB | 1x |
| Offload到CPU | 0GB | 0.6x |
| Offload + NVMe | 0GB | 0.4x |
💡 什么时候用Offload: 显存实在不够时的最后手段。速度损失明显,但起码能跑起来。
优化五:激活值压缩
GACT:激活值量化
# 保存激活值时用int8,反向传播时还原
# 论文: GACT: Activation Compressed Training
class QuantizedActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# 量化到int8
scale = x.abs().max() / 127
x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
# 保存量化后的激活值和scale
ctx.save_for_backward(x_int8)
ctx.scale = scale
return x # 前向输出不变
@staticmethod
def backward(ctx, grad_output):
x_int8, = ctx.saved_tensors
# 反量化
x = x_int8.float() * ctx.scale
# ... 计算梯度
return grad_output
# 可以节省约75%的激活值显存
# 但实现复杂,且有精度损失
组合优化:最终方案
我们的配置
# 目标: 3B模型在单卡24GB上训练
# 1. 混合精度: 模型参数和激活值减半
# 2. 梯度检查点: 激活值从15GB降到4GB
# 3. 梯度累积: batch size 2,累积8步
# 4. 优化器: 8-bit Adam (bitsandbytes)
import bitsandbytes as bnb
model = MyModel().cuda()
model.gradient_checkpointing_enable()
# 8-bit Adam,优化器状态减少75%
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
scaler = GradScaler()
accumulation_steps = 8
for epoch in range(num_epochs):
for i, batch in enumerate(dataloader):
with autocast():
loss = model(batch) / accumulation_steps
scaler.scale(loss).backward()
if (i + 1) % accumulation_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
显存breakdown
| 项目 | 优化前 | 优化后 |
|---|---|---|
| 模型参数 | 12GB | 6GB (fp16) |
| 梯度 | 12GB | 6GB (fp16) |
| 优化器状态 | 24GB | 6GB (8-bit) |
| 激活值 | 30GB | 4GB (checkpoint) |
| 总计 | 78GB | 22GB |
从78GB优化到22GB,单卡3090(24GB)可以跑了!
监控和调试工具
实时显存监控
import torch
import threading
import time
class GPUMemoryMonitor:
def __init__(self, interval=1.0):
self.interval = interval
self.running = False
self.peak_memory = 0
def start(self):
self.running = True
self.thread = threading.Thread(target=self._monitor)
self.thread.start()
def stop(self):
self.running = False
self.thread.join()
return self.peak_memory
def _monitor(self):
while self.running:
current = torch.cuda.memory_allocated() / 1024**3
self.peak_memory = max(self.peak_memory, current)
time.sleep(self.interval)
# 使用
monitor = GPUMemoryMonitor()
monitor.start()
# ... 训练代码 ...
peak = monitor.stop()
print(f"Peak GPU memory: {peak:.2f}GB")
显存快照
# PyTorch 2.0+ 的显存快照功能
torch.cuda.memory._record_memory_history()
# ... 训练代码 ...
# 导出快照
snapshot = torch.cuda.memory._snapshot()
with open("memory_snapshot.pickle", "wb") as f:
pickle.dump(snapshot, f)
# 用官方工具可视化分析
# https://pytorch.org/memory_viz
找出显存泄漏
def check_memory_leak():
"""检查是否有tensor没被释放"""
gc.collect()
torch.cuda.empty_cache()
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) and obj.is_cuda:
print(f"Tensor: {obj.shape}, {obj.dtype}, {obj.device}")
except:
pass
# 每个epoch结束后调用
check_memory_leak()
常见OOM问题排查
问题1:显存碎片
# 症状:显存显示还有空间,但分配失败
# 原因:显存碎片化,没有连续的大块空间
# 解决:
torch.cuda.empty_cache() # 清理缓存
# 或者设置环境变量
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
问题2:eval时OOM
# 症状:训练OK,但eval时OOM
# 原因:eval时batch size可能更大,或者没关闭梯度
# 解决:
model.eval()
with torch.no_grad(): # 一定要加!
outputs = model(inputs)
问题3:DataLoader占显存
# 症状:数据加载后显存飙升
# 原因:pin_memory把数据放到了锁页内存
# 解决:
dataloader = DataLoader(
dataset,
pin_memory=False, # 如果显存紧张,关掉
num_workers=4
)
参考资料
- PyTorch AMP Documentation
- Gradient Checkpointing
- DeepSpeed ZeRO-Offload
- GACT: Activation Compressed Training
总结
GPU显存优化的核心思路:
- 先分析:搞清楚显存被谁占了
- 混合精度:必开,几乎无成本
- 梯度检查点:激活值太大时用,用时间换空间
- 梯度累积:batch size受限时用
- 优化器优化:8-bit Adam或Offload
- 监控调试:随时知道显存去哪了
更新记录:
2022-06-30: 初版发布
2022-12-15: 补充8-bit Adam
2023-05-20: 更新PyTorch 2.0显存分析工具
2023-11-10: 补充显存碎片问题