← 返回首页

长文本处理实战:从4K到128K上下文的血泪史

2025-01-15 | LLM 长文本 位置编码 RoPE
2024年Q3,我们接到一个法律文档分析项目,要求模型能处理单份合同最长50万字。当时手上的模型只支持4K上下文。这篇文章记录了我们把上下文窗口从4K扩展到128K的完整过程,包括各种位置编码方案的对比、显存优化技巧,以及一些至今没完全搞懂的问题。

背景:4K上下文完全不够用

项目需求很明确:分析一份完整的商业合同,回答用户关于条款的问题。问题是:

我们当时用的是基于Llama 2架构的中文模型,原生只支持4096 token。直接截断文档显然不行——用户问的可能恰好是被截掉的部分。

初步尝试:分块检索

第一反应是用RAG:把文档切块,检索相关段落,只把相关部分送给模型。但法律文档有个特点——条款之间高度交叉引用

# 典型的合同条款交叉引用
"根据本协议第5.2条的规定,乙方应在第3.1条约定的
付款日期前完成第7.4条所列的交付义务,否则按照
第12.3条承担违约责任。"

用户问"违约责任是什么",你检索到了12.3条,但不知道5.2、3.1、7.4条的内容,回答依然不完整。

经过两周的尝试,我们决定:必须让模型能直接处理长文本。

位置编码:长文本的核心问题

为什么位置编码是瓶颈?

Transformer的self-attention本身是位置无关的。位置信息完全依赖position encoding注入。问题在于:

我们的模型用的是RoPE(Rotary Position Embedding),这是目前最主流的方案。理解RoPE是解决长文本问题的关键。

RoPE原理:旋转位置编码

RoPE的核心思想:用旋转矩阵编码位置信息,使得两个位置的注意力分数只依赖于它们的相对距离。

import torch
import math

def apply_rotary_pos_emb(q, k, cos, sin):
    """
    应用旋转位置编码
    
    Args:
        q, k: [batch, n_heads, seq_len, head_dim]
        cos, sin: [seq_len, head_dim]
    """
    # 将q, k分成两半
    q1, q2 = q[..., ::2], q[..., 1::2]
    k1, k2 = k[..., ::2], k[..., 1::2]
    
    # 旋转变换
    # q_rot = q * cos + rotate_half(q) * sin
    q_rot = torch.cat([
        q1 * cos - q2 * sin,
        q2 * cos + q1 * sin
    ], dim=-1)
    
    k_rot = torch.cat([
        k1 * cos - k2 * sin,
        k2 * cos + k1 * sin
    ], dim=-1)
    
    return q_rot, k_rot


def precompute_rope_cache(dim: int, max_seq_len: int, base: float = 10000.0):
    """
    预计算RoPE的cos和sin缓存
    
    关键参数:
    - dim: 头的维度
    - max_seq_len: 最大序列长度
    - base: 频率基数,默认10000
    """
    # 计算频率: theta_i = base^(-2i/d)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    
    # 位置序列
    t = torch.arange(max_seq_len).float()
    
    # 计算角度: freqs[pos, i] = pos * theta_i
    freqs = torch.outer(t, inv_freq)  # [max_seq_len, dim/2]
    
    # 复制一份,因为要作用在整个head_dim上
    freqs = torch.cat([freqs, freqs], dim=-1)  # [max_seq_len, dim]
    
    cos_cache = freqs.cos()
    sin_cache = freqs.sin()
    
    return cos_cache, sin_cache

RoPE为什么外推能力差?

虽然RoPE编码的是相对位置,但它有一个关键问题:不同频率的分量对长度的敏感度不同。

def analyze_rope_frequencies(dim=64, base=10000.0, max_pos=8192):
    """分析RoPE各频率分量的波长"""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    
    # 波长 = 2π / 频率
    wavelengths = 2 * math.pi / inv_freq
    
    print("RoPE频率分析:")
    print(f"  维度0-1的波长: {wavelengths[0]:.1f} (约{wavelengths[0]:.0f}个位置一个周期)")
    print(f"  维度{dim//2-2}-{dim//2-1}的波长: {wavelengths[-1]:.1f}")
    print(f"  训练长度4096时:")
    print(f"    - 低频分量: 完成约{4096/wavelengths[-1]:.2f}个周期")
    print(f"    - 高频分量: 完成约{4096/wavelengths[0]:.2f}个周期")

# 输出:
# RoPE频率分析:
#   维度0-1的波长: 6.3 (约6个位置一个周期)
#   维度62-63的波长: 62831.9
#   训练长度4096时:
#     - 低频分量: 完成约0.07个周期
#     - 高频分量: 完成约651.90个周期
⚠️ 关键发现: 低频分量在训练长度内只完成了不到0.1个周期!当推理长度超过训练长度时,这些分量会进入从未见过的角度范围,导致注意力模式崩溃。

方案一:Linear Position Interpolation (PI)

核心思想

最直观的想法:如果4K训练,想支持16K,就把位置压缩4倍。本来位置16000对应角度θ,现在让它对应θ/4,这样所有角度都在训练见过的范围内。

def linear_scaling_rope(dim: int, max_seq_len: int, base: float = 10000.0, 
                        scale: float = 1.0):
    """
    线性位置插值
    
    scale > 1: 压缩位置,支持更长序列
    例如 scale=4 可以让4K模型支持16K
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    
    # 关键:位置除以scale
    t = torch.arange(max_seq_len).float() / scale
    
    freqs = torch.outer(t, inv_freq)
    freqs = torch.cat([freqs, freqs], dim=-1)
    
    return freqs.cos(), freqs.sin()

# 使用示例
# 原模型训练长度4096,想支持32K
scale = 32768 / 4096  # = 8
cos, sin = linear_scaling_rope(dim=64, max_seq_len=32768, scale=scale)

实测效果

序列长度 原始RoPE PPL PI (scale=4) PPL 说明
4096 5.2 5.4 略有下降
8192 87.3 6.1 显著改善
16384 1203.5 7.8 可用
32768 OOM 12.4 质量下降明显

PI的问题:高频分量被过度压缩,损失了精细的位置区分能力。位置1和位置2的区别被压缩成位置0.125和位置0.25的区别。

方案二:NTK-aware Interpolation

核心思想

不要统一压缩所有频率,而是有选择地处理:高频分量保持不变(它们本来就周期短,外推问题不大),低频分量做更大的调整。

def ntk_aware_rope(dim: int, max_seq_len: int, base: float = 10000.0,
                   scale: float = 1.0):
    """
    NTK-aware位置编码
    
    关键:调整base而不是调整位置
    新的base = base * scale^(dim/(dim-2))
    """
    # 调整base,使得低频分量的波长变长
    new_base = base * (scale ** (dim / (dim - 2)))
    
    inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
    
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)
    freqs = torch.cat([freqs, freqs], dim=-1)
    
    return freqs.cos(), freqs.sin()

# NTK的效果:
# - 高频分量(dim=0,1): 波长从6.3变化不大
# - 低频分量(dim=62,63): 波长从62831变成约250000
# 这样低频分量在32K内也不会外推太远

实测效果

序列长度 PI PPL NTK PPL
4096 5.4 5.3
8192 6.1 5.6
16384 7.8 6.2
32768 12.4 7.5

NTK在长序列上效果更好,但需要微调才能充分发挥效果。

方案三:YaRN (Yet another RoPE extensioN)

我们最终采用的方案

YaRN是目前效果最好的RoPE扩展方法,结合了NTK和注意力缩放。

import torch
import math

class YaRNRotaryEmbedding(torch.nn.Module):
    """
    YaRN: Yet another RoPE extensioN
    
    核心改进:
    1. 对不同频率分量使用不同的缩放策略
    2. 添加注意力logits的温度缩放
    """
    
    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 4096,
        base: float = 10000.0,
        scale: float = 1.0,
        original_max_position_embeddings: int = 4096,
        extrapolation_factor: float = 1.0,
        attn_factor: float = 1.0,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scale = scale
        self.original_max_position_embeddings = original_max_position_embeddings
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        
        self._build_cache()
        
    def _build_cache(self):
        # 计算每个频率分量应该如何插值
        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
        
        # 计算高低频分界点
        low = max(
            math.floor(
                self.dim * math.log(self.original_max_position_embeddings / 
                                   (self.beta_fast * 2 * math.pi)) /
                (2 * math.log(self.base))
            ), 0
        )
        high = min(
            math.ceil(
                self.dim * math.log(self.original_max_position_embeddings / 
                                   (self.beta_slow * 2 * math.pi)) /
                (2 * math.log(self.base))
            ), self.dim - 1
        )
        
        # 计算平滑的插值权重
        inv_freq = torch.zeros(self.dim // 2)
        for i in range(self.dim // 2):
            if i < low:
                # 高频:不插值
                inv_freq[i] = inv_freq_extrapolation[i]
            elif i > high:
                # 低频:完全插值
                inv_freq[i] = inv_freq_interpolation[i]
            else:
                # 中频:线性混合
                t = (i - low) / (high - low)
                inv_freq[i] = (1 - t) * inv_freq_extrapolation[i] + \
                              t * inv_freq_interpolation[i]
        
        self.register_buffer("inv_freq", inv_freq)
        
        # 预计算cos/sin
        t = torch.arange(self.max_position_embeddings).float()
        freqs = torch.outer(t, self.inv_freq)
        freqs = torch.cat([freqs, freqs], dim=-1)
        
        # 注意力缩放因子
        self.register_buffer("attn_scaling", 
                            torch.tensor(0.1 * math.log(self.scale) + 1.0))
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())
        
    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len],
            self.attn_scaling
        )
💡 YaRN的关键洞察:
1. 分频处理:高频分量外推,低频分量插值,中频分量混合
2. 注意力缩放:长序列时降低attention logits的scale,防止分布过于尖锐
3. 参数自动计算:根据目标长度自动确定插值策略

YaRN效果对比

序列长度 原始RoPE PI NTK YaRN
4096 5.2 5.4 5.3 5.2
16384 1203.5 7.8 6.2 5.5
32768 - 12.4 7.5 5.9
65536 - 24.7 11.2 6.4
131072 - - 18.9 7.2

显存优化:长序列的另一个瓶颈

问题分析

Attention的显存复杂度是O(n²),32K序列的attention矩阵需要:

# 32K x 32K x 4字节(float32) x batch_size x n_heads
# = 32768 * 32768 * 4 * 1 * 32
# = 137GB (单batch!)

# 即使用float16也需要68GB,远超单卡显存

解决方案1:Flash Attention

from flash_attn import flash_attn_func

def flash_attention_forward(q, k, v, causal=True):
    """
    使用Flash Attention 2
    
    优势:
    1. 显存O(n)而不是O(n²)
    2. 速度快2-4倍
    """
    # q, k, v: [batch, seq_len, n_heads, head_dim]
    # 注意维度顺序和标准attention不同!
    
    output = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=None,  # 自动计算
        causal=causal
    )
    return output

# 显存对比 (batch=1, heads=32, dim=128)
# 序列长度  |  标准Attention  |  Flash Attention
# 8192      |     8.2 GB      |     1.1 GB
# 32768     |     OOM         |     4.3 GB
# 131072    |     OOM         |     17.2 GB

解决方案2:KV Cache优化

推理时,每生成一个token都要存储之前所有token的K和V。对于长序列,这是巨大的显存开销。

class PagedKVCache:
    """
    分页KV Cache (类似vLLM的实现)
    
    思想:不预分配全部显存,按需分配page
    """
    
    def __init__(
        self,
        n_layers: int,
        n_heads: int,
        head_dim: int,
        page_size: int = 16,
        max_pages: int = 8192,
        dtype=torch.float16,
        device='cuda'
    ):
        self.page_size = page_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head_dim = head_dim
        
        # 预分配page池
        self.k_cache = torch.zeros(
            max_pages, n_layers, n_heads, page_size, head_dim,
            dtype=dtype, device=device
        )
        self.v_cache = torch.zeros(
            max_pages, n_layers, n_heads, page_size, head_dim,
            dtype=dtype, device=device
        )
        
        # 空闲page列表
        self.free_pages = list(range(max_pages))
        
        # 每个序列使用的page表
        self.page_tables = {}  # seq_id -> list of page indices
        
    def allocate(self, seq_id: int, num_tokens: int):
        """为序列分配page"""
        num_pages = (num_tokens + self.page_size - 1) // self.page_size
        
        if len(self.free_pages) < num_pages:
            raise RuntimeError("KV Cache空间不足")
        
        pages = [self.free_pages.pop() for _ in range(num_pages)]
        self.page_tables[seq_id] = pages
        return pages
    
    def append(self, seq_id: int, layer: int, k: torch.Tensor, v: torch.Tensor):
        """追加KV到cache"""
        pages = self.page_tables[seq_id]
        # ... 实际的写入逻辑
        
    def get(self, seq_id: int, layer: int):
        """获取序列的完整KV"""
        pages = self.page_tables[seq_id]
        # ... 实际的读取逻辑
⚠️ 踩坑记录: 我们一开始用的是连续KV Cache,结果发现长序列场景下显存碎片严重。换成PagedKV Cache后,同样的显存能多服务40%的并发请求。

实际部署配置

最终方案

经过三个月的迭代,我们的配置是:

# 模型配置
model_config:
  base_model: "Llama-2-13B-chat"
  rope_scaling:
    type: "yarn"
    factor: 32  # 4K -> 128K
    original_max_position_embeddings: 4096
  
# 推理配置
inference_config:
  max_seq_len: 131072
  attention_implementation: "flash_attention_2"
  kv_cache_type: "paged"
  dtype: "bfloat16"
  
# 硬件配置
hardware:
  gpus: 4 x A100-80G
  tensor_parallel: 4

性能数据

指标 数值
最大上下文长度 128K tokens
prefill吞吐(128K) ~12K tokens/s
decode吞吐 ~45 tokens/s
首token延迟(128K) ~11秒
显存占用 ~290GB (4卡)

Needle in Haystack测试

"大海捞针"是评估长文本能力的标准测试:在长文档中随机位置插入一个事实,看模型能否准确找到。

import random

def needle_in_haystack_test(model, context_lengths, depths):
    """
    大海捞针测试
    
    context_lengths: 测试的上下文长度列表
    depths: 针插入的深度比例 (0.0-1.0)
    """
    results = {}
    
    # 生成干扰文本
    haystack = load_paul_graham_essays()  # 用Paul Graham的文章作为干扰
    
    # 针
    needle = "The special magic number is: 7429185."
    question = "What is the special magic number?"
    expected = "7429185"
    
    for ctx_len in context_lengths:
        results[ctx_len] = {}
        
        for depth in depths:
            # 计算插入位置
            insert_pos = int(ctx_len * depth)
            
            # 构造测试文本
            prefix = haystack[:insert_pos]
            suffix = haystack[insert_pos:ctx_len]
            test_context = prefix + needle + suffix
            
            # 测试
            response = model.generate(
                f"Context: {test_context}\n\nQuestion: {question}"
            )
            
            # 判断是否正确
            success = expected in response
            results[ctx_len][depth] = success
            
            print(f"Length={ctx_len}, Depth={depth:.1%}: "
                  f"{'✓' if success else '✗'}")
    
    return results

# 我们的测试结果
# Length=4K:  所有depth都成功
# Length=16K: depth<0.9成功,末尾有时失败
# Length=64K: depth 0.3-0.7成功率高,首尾较差
# Length=128K: 中间位置80%成功,首尾约50%
⚠️ 发现的问题: 即使用了YaRN,128K时首尾位置的召回率明显低于中间。我们猜测是因为attention在极长序列下的分布问题,但还没完全搞清楚。

还没解决的问题

1. Lost in the Middle

不只是我们,很多长文本模型都有这个问题:中间位置的信息比首尾更容易被"遗忘"。有论文专门研究过这个现象,但目前没有完美解决方案。

2. 推理速度

128K context的首token延迟11秒,用户体验不好。我们尝试过:

3. 质量degradation

虽然PPL看起来不错,但主观评测发现128K时的回答质量明显不如短文本。特别是需要综合多处信息的问题,经常漏掉一些细节。

参考资料

总结

把上下文从4K扩展到128K,我们学到的经验:

  1. YaRN是目前效果最好的RoPE扩展方案,但需要理解原理才能调好参数
  2. 显存优化和位置编码同等重要,Flash Attention + Paged KV Cache是标配
  3. 长度≠质量,128K的PPL可以很低,但实际任务效果可能不如RAG
  4. 评测要全面,不只看PPL,还要做Needle in Haystack等任务测试

对于我们的法律文档场景,最终采用的是混合方案:短文档(<32K)直接处理,长文档先分块再用128K模型处理关键部分。这样在质量和成本之间取得了平衡。

更新记录:
2025-01-15: 初版发布
2025-01-22: 补充了Needle in Haystack测试结果
2025-02-01: 更新了生产环境性能数据