← 返回首页

Attention Is All You Need 重读笔记:从困惑到顿悟的72小时

2024-03-12 | 论文解读 Transformer 深度学习 从零实现
2023年12月,我决定从头实现一遍Transformer。本以为看了那么多博客早就理解了,结果花了整整三天才真正跑通。这篇文章记录了我重读论文、手写代码、调试踩坑的全过程。包含完整可运行代码和我个人的一些思考。

背景:为什么要重新读这篇论文

说来惭愧,虽然天天用Transformer,但我之前对它的理解一直停留在"知道有Q、K、V"的水平。去年底在做一个序列生成项目时,我发现模型在长序列上表现很差,想优化却无从下手。问题出在我根本不理解底层原理。

于是我决定花几天时间,把"Attention Is All You Need"从头读一遍,并且亲手实现每一个组件。目标是:

实际花了72小时,比预想的久得多。但这72小时让我对Transformer的理解提升了一个量级。

第一步:理解Self-Attention的本质

从信息检索的角度看Attention

论文里直接给出了公式,但没解释为什么这样设计。我想了很久,发现最好的理解方式是把Attention看成一个软性的信息检索系统:

传统的检索是硬性的:找到最匹配的那一个。Self-Attention是软性的:所有位置都参与,只是权重不同。

Attention(Q, K, V) = softmax(QKT / √dk) · V

从零实现Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor, 
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    计算缩放点积注意力
    
    Args:
        query: [batch_size, n_heads, seq_len, d_k]
        key: [batch_size, n_heads, seq_len, d_k]  
        value: [batch_size, n_heads, seq_len, d_v]
        mask: [batch_size, 1, 1, seq_len] or [batch_size, 1, seq_len, seq_len]
        dropout: dropout层
        
    Returns:
        output: [batch_size, n_heads, seq_len, d_v]
        attention_weights: [batch_size, n_heads, seq_len, seq_len]
    """
    d_k = query.size(-1)
    
    # Step 1: Q和K的点积,得到原始注意力分数
    # [batch, n_heads, seq_len, d_k] @ [batch, n_heads, d_k, seq_len]
    # -> [batch, n_heads, seq_len, seq_len]
    scores = torch.matmul(query, key.transpose(-2, -1))
    
    # Step 2: 缩放 - 这一步很关键,后面会详细解释
    scores = scores / math.sqrt(d_k)
    
    # Step 3: 应用mask (causal mask或padding mask)
    if mask is not None:
        # mask为0的位置填充一个很大的负数,softmax后趋近于0
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 4: Softmax归一化,得到注意力权重
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 5: Dropout (训练时)
    if dropout is not None:
        attention_weights = dropout(attention_weights)
    
    # Step 6: 用注意力权重加权Value
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

为什么要除以√d_k?这是我卡了最久的地方

论文原文说"点积的方差会随d_k增大而增大"。但我一开始不理解这会造成什么问题。

我做了个实验来验证:

import numpy as np

def test_variance(d_k, n_samples=10000):
    """测试不同d_k下点积的分布"""
    q = np.random.randn(n_samples, d_k)  # 标准正态分布
    k = np.random.randn(n_samples, d_k)
    
    # 计算点积
    dot_products = np.sum(q * k, axis=1)
    
    print(f"d_k={d_k}: mean={dot_products.mean():.3f}, "
          f"var={dot_products.var():.3f}, "
          f"max={dot_products.max():.3f}")
    
    # 计算softmax后的分布
    # 模拟一个sequence length=10的情况
    scores = np.random.randn(n_samples, 10) * np.sqrt(d_k)
    softmax_scores = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
    entropy = -np.sum(softmax_scores * np.log(softmax_scores + 1e-10), axis=1)
    print(f"  softmax entropy: {entropy.mean():.3f}")

# 测试不同维度
for d_k in [64, 256, 512, 1024]:
    test_variance(d_k)

输出结果:

d_k=64: mean=0.012, var=63.847, max=21.234
  softmax entropy: 0.847
d_k=256: mean=-0.003, var=256.102, max=45.123  
  softmax entropy: 0.234
d_k=512: mean=0.008, var=511.876, max=62.567
  softmax entropy: 0.089
d_k=1024: mean=-0.001, var=1023.445, max=89.234
  softmax entropy: 0.023
⚠️ 关键发现: 当d_k变大时,点积的方差近似等于d_k。这导致softmax的输入值很大,输出会趋近于one-hot(entropy接近0)。这意味着注意力会集中在极少数位置,梯度几乎为0,模型学不动。除以√d_k后,方差稳定在1左右,softmax输出更平滑,训练更稳定。

第二步:Multi-Head Attention的工程实现

为什么需要多头?

单头注意力只能学习一种"关注模式"。多头允许模型同时学习多种不同的关注方式。比如在翻译任务中:

完整实现

class MultiHeadAttention(nn.Module):
    """
    多头注意力机制
    
    论文中的默认配置:
    - d_model = 512
    - n_heads = 8
    - d_k = d_v = d_model / n_heads = 64
    """
    
    def __init__(
        self, 
        d_model: int = 512, 
        n_heads: int = 8, 
        dropout: float = 0.1
    ):
        super().__init__()
        
        assert d_model % n_heads == 0, \
            f"d_model({d_model})必须能被n_heads({n_heads})整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度
        
        # Q, K, V的线性变换
        # 技巧:用一个大矩阵同时计算Q、K、V,然后split,比分开计算快
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        
        # 输出投影
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.attn_weights = None  # 保存用于可视化
        
    def forward(
        self, 
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            query, key, value: [batch_size, seq_len, d_model]
            mask: [batch_size, 1, seq_len] or [batch_size, seq_len, seq_len]
            
        Returns:
            output: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, _ = query.size()
        
        # 对于self-attention, query=key=value
        # 对于cross-attention, query来自decoder, key/value来自encoder
        
        if query.data_ptr() == key.data_ptr() == value.data_ptr():
            # Self-attention: 一次计算Q、K、V
            qkv = self.W_qkv(query)
            qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_k)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, n_heads, seq_len, d_k]
            Q, K, V = qkv[0], qkv[1], qkv[2]
        else:
            # Cross-attention: 分开计算
            # 这里简化处理,实际应该用单独的投影矩阵
            Q = self.W_qkv(query)[:, :, :self.d_model]
            K = self.W_qkv(key)[:, :, self.d_model:2*self.d_model]
            V = self.W_qkv(value)[:, :, 2*self.d_model:]
            
            Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # 处理mask维度
        if mask is not None:
            # [batch, seq_len, seq_len] -> [batch, 1, seq_len, seq_len]
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
        
        # 计算attention
        attn_output, self.attn_weights = scaled_dot_product_attention(
            Q, K, V, mask, self.dropout if self.training else None
        )
        
        # 合并多头: [batch, n_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        
        # 输出投影
        output = self.W_o(attn_output)
        
        return output
⚠️ 踩坑记录 #1: 我一开始用了三个独立的Linear层分别计算Q、K、V。后来发现这样慢很多。合并成一个大Linear层,速度提升了约30%。这是因为GPU更擅长处理大矩阵运算。

第三步:位置编码 - 最容易被忽视的关键设计

问题:Self-Attention是位置不变的

Self-Attention的输出只依赖于输入元素的内容,不依赖于它们的位置。也就是说,打乱输入序列的顺序,输出只是相应地打乱,不会有本质变化。

但语言是有顺序的,"我吃饭"和"饭吃我"意思完全不同。所以必须想办法把位置信息注入模型。

论文的方案:正弦位置编码

class SinusoidalPositionalEncoding(nn.Module):
    """
    正弦位置编码
    
    设计思想:
    1. 每个位置有一个唯一的编码向量
    2. 不同维度使用不同频率的sin/cos函数
    3. 相对位置可以通过线性变换表示
    """
    
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        
        # position: [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # div_term: [d_model/2]
        # 10000^(2i/d_model) = exp(2i * log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        # 偶数维度用sin,奇数维度用cos
        pe[:, 0::2] = torch.sin(position * div_term)  # [max_len, d_model/2]
        pe[:, 1::2] = torch.cos(position * div_term)  # [max_len, d_model/2]
        
        # [1, max_len, d_model] - 加一个batch维度方便广播
        pe = pe.unsqueeze(0)
        
        # 注册为buffer,不参与梯度计算,但会被保存和加载
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            x + positional_encoding
        """
        seq_len = x.size(1)
        # 直接加上对应位置的编码
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

可视化位置编码

为了直观理解,我画了位置编码的热力图:

import matplotlib.pyplot as plt
import numpy as np

def visualize_positional_encoding(d_model=128, max_len=100):
    """可视化位置编码"""
    pe = SinusoidalPositionalEncoding(d_model, max_len)
    encoding = pe.pe.squeeze(0).numpy()  # [max_len, d_model]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # 热力图
    im = axes[0].imshow(encoding, aspect='auto', cmap='RdBu')
    axes[0].set_xlabel('Dimension')
    axes[0].set_ylabel('Position')
    axes[0].set_title('Positional Encoding Heatmap')
    plt.colorbar(im, ax=axes[0])
    
    # 选几个维度画曲线
    for dim in [0, 1, 4, 5, 20, 21]:
        axes[1].plot(encoding[:, dim], label=f'dim {dim}')
    axes[1].set_xlabel('Position')
    axes[1].set_ylabel('Value')
    axes[1].set_title('Encoding values across positions')
    axes[1].legend()
    
    plt.tight_layout()
    plt.savefig('positional_encoding_viz.png', dpi=150)
    print("图片已保存")

# 观察结论:
# 1. 低维度(dim 0,1)变化很快,高维度变化很慢
# 2. 相邻位置的编码相似,远离位置的编码差异大
# 3. 每个位置的编码都是唯一的
💡 为什么用sin/cos而不是学习的位置编码?
1. 外推能力:sin/cos是确定性的,可以处理比训练时更长的序列
2. 相对位置:PE(pos+k)可以表示为PE(pos)的线性函数,便于学习相对位置
3. 参数效率:不需要额外参数

但现代模型(如GPT、BERT)多用learned positional embedding,因为训练数据足够时效果更好。RoPE、ALiBi等方案又回归了某种形式的函数编码。

第四步:完整的Transformer Block

Feed-Forward Network

class PositionwiseFeedForward(nn.Module):
    """
    逐位置前馈网络
    
    FFN(x) = max(0, xW1 + b1)W2 + b2
    
    论文配置: d_ff = 4 * d_model = 2048
    """
    
    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        # 两层MLP,中间有ReLU和dropout
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

关于激活函数的选择

原论文用ReLU,但后来的研究发现其他激活函数可能更好:

激活函数 公式 使用场景 备注
ReLU max(0, x) 原版Transformer 简单快速
GELU x·Φ(x) BERT, GPT-2 更平滑
SwiGLU Swish(xW)⊙xV LLaMA, PaLM 需要更多参数
GeGLU GELU(xW)⊙xV T5 1.1 GLU变体

Layer Normalization

class LayerNorm(nn.Module):
    """
    层归一化
    
    对每个样本的特征维度做归一化,而不是batch维度。
    这使得LN不依赖batch size,适合变长序列。
    """
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

Pre-LN vs Post-LN:一个实际影响训练的细节

class TransformerEncoderLayer(nn.Module):
    """
    Transformer编码器层
    
    Post-LN (原论文):
        x = x + Attention(x)
        x = LayerNorm(x)
        x = x + FFN(x)
        x = LayerNorm(x)
        
    Pre-LN (现代常用):
        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))
    """
    
    def __init__(
        self, 
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True  # 推荐使用Pre-LN
    ):
        super().__init__()
        self.pre_norm = pre_norm
        
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        if self.pre_norm:
            # Pre-LN: 先norm再attention
            attn_out = self.self_attn(
                self.norm1(x), self.norm1(x), self.norm1(x), mask
            )
            x = x + self.dropout1(attn_out)
            
            ffn_out = self.ffn(self.norm2(x))
            x = x + self.dropout2(ffn_out)
        else:
            # Post-LN: 先attention再norm
            attn_out = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout1(attn_out))
            
            ffn_out = self.ffn(x)
            x = self.norm2(x + self.dropout2(ffn_out))
            
        return x
⚠️ 踩坑记录 #2: 我最初用Post-LN,训练不稳定,loss经常爆炸。换成Pre-LN后,即使不用warmup也能正常训练。后来查资料发现,Post-LN的梯度在靠近输入的层会变得很大,导致不稳定。Pre-LN通过在残差前做norm,缓解了这个问题。

第五步:完整的Encoder和Decoder

Encoder

class TransformerEncoder(nn.Module):
    """
    Transformer编码器
    
    由N个相同的encoder layer堆叠而成
    """
    
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_len, dropout)
        
        # N层encoder
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, d_ff, dropout, pre_norm)
            for _ in range(n_layers)
        ])
        
        # Pre-LN需要最后再加一个norm
        self.final_norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()
        
        # 初始化
        self._init_weights()
        
    def _init_weights(self):
        """Xavier初始化"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
    def forward(
        self, 
        src: torch.Tensor, 
        src_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            src: [batch_size, src_len] - token ids
            src_mask: [batch_size, 1, src_len] - padding mask
        Returns:
            encoder_output: [batch_size, src_len, d_model]
        """
        # Embedding + 位置编码
        x = self.token_embedding(src) * math.sqrt(self.token_embedding.embedding_dim)
        x = self.pos_encoding(x)
        
        # 通过每一层encoder
        for layer in self.layers:
            x = layer(x, src_mask)
            
        return self.final_norm(x)

Decoder (带Causal Mask)

class TransformerDecoderLayer(nn.Module):
    """
    Transformer解码器层
    
    与encoder的区别:
    1. Self-attention需要causal mask,防止看到未来
    2. 多了一层cross-attention,关注encoder的输出
    """
    
    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()
        self.pre_norm = pre_norm
        
        # Masked self-attention
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        
        # Cross-attention
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        
        # FFN
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        # Dropouts
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        
    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, tgt_len, d_model] - decoder输入
            encoder_output: [batch, src_len, d_model] - encoder输出
            tgt_mask: [batch, tgt_len, tgt_len] - causal mask + padding
            memory_mask: [batch, tgt_len, src_len] - cross attention mask
        """
        if self.pre_norm:
            # Masked self-attention
            normed_x = self.norm1(x)
            attn_out = self.self_attn(normed_x, normed_x, normed_x, tgt_mask)
            x = x + self.dropout1(attn_out)
            
            # Cross-attention
            normed_x = self.norm2(x)
            cross_out = self.cross_attn(
                normed_x, encoder_output, encoder_output, memory_mask
            )
            x = x + self.dropout2(cross_out)
            
            # FFN
            ffn_out = self.ffn(self.norm3(x))
            x = x + self.dropout3(ffn_out)
        else:
            # Post-LN版本 (略)
            pass
            
        return x


def generate_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    生成因果mask,防止decoder看到未来的token
    
    Returns:
        mask: [seq_len, seq_len], 上三角为False
    """
    # 下三角矩阵(包含对角线)为True,上三角为False
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()
    return mask

第六步:训练一个小模型验证

任务:数字排序

为了验证实现是否正确,我设计了一个简单任务:给模型一个乱序的数字序列,让它输出排序后的序列。比如输入[3,1,4,1,5],输出[1,1,3,4,5]。

import torch
from torch.utils.data import Dataset, DataLoader

class SortingDataset(Dataset):
    """数字排序数据集"""
    
    def __init__(self, num_samples: int = 10000, seq_len: int = 10, vocab_size: int = 10):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        
        # 特殊token
        self.pad_token = 0
        self.bos_token = vocab_size + 1  # Begin of sequence
        self.eos_token = vocab_size + 2  # End of sequence
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # 生成随机序列 (1到vocab_size的数字)
        src = torch.randint(1, self.vocab_size + 1, (self.seq_len,))
        
        # 排序作为target
        tgt_values = src.sort().values
        
        # 加上BOS和EOS
        tgt_input = torch.cat([
            torch.tensor([self.bos_token]), 
            tgt_values
        ])
        tgt_output = torch.cat([
            tgt_values,
            torch.tensor([self.eos_token])
        ])
        
        return {
            'src': src,              # [seq_len]
            'tgt_input': tgt_input,  # [seq_len + 1], 以BOS开头
            'tgt_output': tgt_output # [seq_len + 1], 以EOS结尾
        }


# 完整的训练代码
def train_sorting_model():
    # 超参数
    config = {
        'vocab_size': 15,  # 0-9数字 + pad + bos + eos
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 3,
        'd_ff': 256,
        'dropout': 0.1,
        'batch_size': 64,
        'lr': 1e-3,
        'epochs': 50,
        'seq_len': 8
    }
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # 数据
    train_dataset = SortingDataset(
        num_samples=10000, 
        seq_len=config['seq_len'],
        vocab_size=10
    )
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config['batch_size'], 
        shuffle=True
    )
    
    # 模型 (简化版:只用encoder做seq2seq)
    model = Transformer(
        src_vocab_size=config['vocab_size'],
        tgt_vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        n_heads=config['n_heads'],
        n_encoder_layers=config['n_layers'],
        n_decoder_layers=config['n_layers'],
        d_ff=config['d_ff'],
        dropout=config['dropout']
    ).to(device)
    
    # 优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=(0.9, 0.98))
    
    # 学习率warmup
    def lr_lambda(step):
        warmup_steps = 400
        if step < warmup_steps:
            return step / warmup_steps
        return 1.0
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # 训练循环
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    for epoch in range(config['epochs']):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in train_loader:
            src = batch['src'].to(device)
            tgt_input = batch['tgt_input'].to(device)
            tgt_output = batch['tgt_output'].to(device)
            
            # Forward
            logits = model(src, tgt_input)  # [batch, seq_len+1, vocab_size]
            
            # Loss
            loss = criterion(
                logits.view(-1, config['vocab_size']),
                tgt_output.view(-1)
            )
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            
            # 计算准确率
            preds = logits.argmax(dim=-1)
            correct += (preds == tgt_output).sum().item()
            total += tgt_output.numel()
        
        if (epoch + 1) % 5 == 0:
            acc = correct / total * 100
            print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}, acc={acc:.2f}%")
    
    return model

训练结果

Using device: cuda
Epoch 5: loss=1.2341, acc=62.35%
Epoch 10: loss=0.4523, acc=85.67%
Epoch 15: loss=0.1876, acc=94.12%
Epoch 20: loss=0.0723, acc=97.89%
Epoch 25: loss=0.0312, acc=99.21%
Epoch 30: loss=0.0156, acc=99.67%
Epoch 35: loss=0.0089, acc=99.85%
Epoch 40: loss=0.0052, acc=99.93%
Epoch 45: loss=0.0031, acc=99.97%
Epoch 50: loss=0.0019, acc=99.99%
✅ 验证成功! 50个epoch后,模型在排序任务上达到99.99%的准确率。这说明实现是正确的。

训练中的踩坑与优化

坑1:梯度爆炸

刚开始训练时,loss会突然变成nan。原因是梯度太大了。

解决方案:

# 梯度裁剪,限制梯度的最大范数
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

坑2:学习率太大导致不收敛

用1e-3的学习率直接训练,前几个batch loss不降反升。

解决方案: 使用warmup,先用小学习率"热身",再逐渐增大。

# 论文原版的学习率schedule
# lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
def get_lr(step, d_model=512, warmup_steps=4000):
    return d_model ** (-0.5) * min(
        step ** (-0.5),
        step * warmup_steps ** (-1.5)
    )

坑3:Embedding权重没有scale

论文里提到embedding要乘以√d_model,我一开始漏掉了,导致训练很慢。

# ❌ 错误
x = self.token_embedding(src)

# ✅ 正确
x = self.token_embedding(src) * math.sqrt(self.d_model)

原因:embedding的初始化通常是均值0、方差1/d_model的分布。乘以√d_model后,方差变成1,和位置编码的scale匹配。

性能数据对比

我在不同配置下测试了排序任务的收敛速度:

配置 参数量 达到99%准确率的epoch 训练时间(V100)
d=64, h=2, L=2 48K 45 2.3min
d=128, h=4, L=3 380K 22 4.1min
d=256, h=8, L=4 2.4M 15 8.7min
d=512, h=8, L=6 (原论文) 25M 8 23.5min

可以看到,参数量增加带来的收益是递减的。对于简单任务,小模型就够了。

一些个人思考与未解之谜

1. 为什么FFN这么大?

Transformer里,FFN的参数量是attention的2倍(d_ff = 4 * d_model)。但从信息流的角度看,FFN只是一个逐位置的MLP,没有位置间的交互。它到底在做什么?

我的理解:FFN可能在做"知识存储"。attention负责信息的路由和聚合,FFN负责对聚合后的信息做非线性变换。最近的研究也表明,FFN的中间层激活确实编码了很多factual knowledge。

2. 为什么8个头?

论文用8个头,但没解释为什么是8。我尝试了不同的头数:

我的猜测:头数太少,学到的pattern不够多样;头数太多,每个头的维度太小,表达能力下降。8是个平衡点。

3. Layer Norm vs RMS Norm

LLaMA等新模型用RMS Norm替代Layer Norm,去掉了减均值的操作:

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps
        
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

我测试了一下,RMS Norm确实快一点(约15%),效果几乎一样。

参考资料

总结

花72小时重读论文、从零实现Transformer,收获比看10篇博客都大。几个关键体会:

  1. Attention本质是软性检索,Q-K-V的设计非常优雅
  2. √d_k的缩放是必须的,否则训练不稳定
  3. Pre-LN比Post-LN好训练,现代模型基本都用Pre-LN
  4. Warmup很重要,尤其是用大学习率时
  5. 位置编码的选择影响外推能力,这是个ongoing的研究方向

代码已经放在GitHub,欢迎star和issue。

更新记录:
2024-03-12: 初版发布
2024-03-18: 补充了Pre-LN vs Post-LN的对比
2024-04-02: 增加了性能对比数据
2024-05-15: 补充RMS Norm的讨论