长文本处理实战:从4K到128K上下文的血泪史
2024年Q3,我们接到一个法律文档分析项目,要求模型能处理单份合同最长50万字。当时手上的模型只支持4K上下文。这篇文章记录了我们把上下文窗口从4K扩展到128K的完整过程,包括各种位置编码方案的对比、显存优化技巧,以及一些至今没完全搞懂的问题。
背景:4K上下文完全不够用
项目需求很明确:分析一份完整的商业合同,回答用户关于条款的问题。问题是:
- 一份普通的商业合同约2-3万字
- 复杂的并购协议可达20万字
- 加上用户问题和格式化prompt,最长需要50万token
我们当时用的是基于Llama 2架构的中文模型,原生只支持4096 token。直接截断文档显然不行——用户问的可能恰好是被截掉的部分。
初步尝试:分块检索
第一反应是用RAG:把文档切块,检索相关段落,只把相关部分送给模型。但法律文档有个特点——条款之间高度交叉引用。
# 典型的合同条款交叉引用
"根据本协议第5.2条的规定,乙方应在第3.1条约定的
付款日期前完成第7.4条所列的交付义务,否则按照
第12.3条承担违约责任。"
用户问"违约责任是什么",你检索到了12.3条,但不知道5.2、3.1、7.4条的内容,回答依然不完整。
经过两周的尝试,我们决定:必须让模型能直接处理长文本。
位置编码:长文本的核心问题
为什么位置编码是瓶颈?
Transformer的self-attention本身是位置无关的。位置信息完全依赖position encoding注入。问题在于:
- 绝对位置编码(如原版Transformer的sin/cos):训练时没见过的位置,推理时效果未知
- 学习的位置编码(如BERT、GPT):直接受训练长度限制
我们的模型用的是RoPE(Rotary Position Embedding),这是目前最主流的方案。理解RoPE是解决长文本问题的关键。
RoPE原理:旋转位置编码
RoPE的核心思想:用旋转矩阵编码位置信息,使得两个位置的注意力分数只依赖于它们的相对距离。
import torch
import math
def apply_rotary_pos_emb(q, k, cos, sin):
"""
应用旋转位置编码
Args:
q, k: [batch, n_heads, seq_len, head_dim]
cos, sin: [seq_len, head_dim]
"""
# 将q, k分成两半
q1, q2 = q[..., ::2], q[..., 1::2]
k1, k2 = k[..., ::2], k[..., 1::2]
# 旋转变换
# q_rot = q * cos + rotate_half(q) * sin
q_rot = torch.cat([
q1 * cos - q2 * sin,
q2 * cos + q1 * sin
], dim=-1)
k_rot = torch.cat([
k1 * cos - k2 * sin,
k2 * cos + k1 * sin
], dim=-1)
return q_rot, k_rot
def precompute_rope_cache(dim: int, max_seq_len: int, base: float = 10000.0):
"""
预计算RoPE的cos和sin缓存
关键参数:
- dim: 头的维度
- max_seq_len: 最大序列长度
- base: 频率基数,默认10000
"""
# 计算频率: theta_i = base^(-2i/d)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# 位置序列
t = torch.arange(max_seq_len).float()
# 计算角度: freqs[pos, i] = pos * theta_i
freqs = torch.outer(t, inv_freq) # [max_seq_len, dim/2]
# 复制一份,因为要作用在整个head_dim上
freqs = torch.cat([freqs, freqs], dim=-1) # [max_seq_len, dim]
cos_cache = freqs.cos()
sin_cache = freqs.sin()
return cos_cache, sin_cache
RoPE为什么外推能力差?
虽然RoPE编码的是相对位置,但它有一个关键问题:不同频率的分量对长度的敏感度不同。
def analyze_rope_frequencies(dim=64, base=10000.0, max_pos=8192):
"""分析RoPE各频率分量的波长"""
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# 波长 = 2π / 频率
wavelengths = 2 * math.pi / inv_freq
print("RoPE频率分析:")
print(f" 维度0-1的波长: {wavelengths[0]:.1f} (约{wavelengths[0]:.0f}个位置一个周期)")
print(f" 维度{dim//2-2}-{dim//2-1}的波长: {wavelengths[-1]:.1f}")
print(f" 训练长度4096时:")
print(f" - 低频分量: 完成约{4096/wavelengths[-1]:.2f}个周期")
print(f" - 高频分量: 完成约{4096/wavelengths[0]:.2f}个周期")
# 输出:
# RoPE频率分析:
# 维度0-1的波长: 6.3 (约6个位置一个周期)
# 维度62-63的波长: 62831.9
# 训练长度4096时:
# - 低频分量: 完成约0.07个周期
# - 高频分量: 完成约651.90个周期
方案一:Linear Position Interpolation (PI)
核心思想
最直观的想法:如果4K训练,想支持16K,就把位置压缩4倍。本来位置16000对应角度θ,现在让它对应θ/4,这样所有角度都在训练见过的范围内。
def linear_scaling_rope(dim: int, max_seq_len: int, base: float = 10000.0,
scale: float = 1.0):
"""
线性位置插值
scale > 1: 压缩位置,支持更长序列
例如 scale=4 可以让4K模型支持16K
"""
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# 关键:位置除以scale
t = torch.arange(max_seq_len).float() / scale
freqs = torch.outer(t, inv_freq)
freqs = torch.cat([freqs, freqs], dim=-1)
return freqs.cos(), freqs.sin()
# 使用示例
# 原模型训练长度4096,想支持32K
scale = 32768 / 4096 # = 8
cos, sin = linear_scaling_rope(dim=64, max_seq_len=32768, scale=scale)
实测效果
| 序列长度 | 原始RoPE PPL | PI (scale=4) PPL | 说明 |
|---|---|---|---|
| 4096 | 5.2 | 5.4 | 略有下降 |
| 8192 | 87.3 | 6.1 | 显著改善 |
| 16384 | 1203.5 | 7.8 | 可用 |
| 32768 | OOM | 12.4 | 质量下降明显 |
PI的问题:高频分量被过度压缩,损失了精细的位置区分能力。位置1和位置2的区别被压缩成位置0.125和位置0.25的区别。
方案二:NTK-aware Interpolation
核心思想
不要统一压缩所有频率,而是有选择地处理:高频分量保持不变(它们本来就周期短,外推问题不大),低频分量做更大的调整。
def ntk_aware_rope(dim: int, max_seq_len: int, base: float = 10000.0,
scale: float = 1.0):
"""
NTK-aware位置编码
关键:调整base而不是调整位置
新的base = base * scale^(dim/(dim-2))
"""
# 调整base,使得低频分量的波长变长
new_base = base * (scale ** (dim / (dim - 2)))
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).float()
freqs = torch.outer(t, inv_freq)
freqs = torch.cat([freqs, freqs], dim=-1)
return freqs.cos(), freqs.sin()
# NTK的效果:
# - 高频分量(dim=0,1): 波长从6.3变化不大
# - 低频分量(dim=62,63): 波长从62831变成约250000
# 这样低频分量在32K内也不会外推太远
实测效果
| 序列长度 | PI PPL | NTK PPL |
|---|---|---|
| 4096 | 5.4 | 5.3 |
| 8192 | 6.1 | 5.6 |
| 16384 | 7.8 | 6.2 |
| 32768 | 12.4 | 7.5 |
NTK在长序列上效果更好,但需要微调才能充分发挥效果。
方案三:YaRN (Yet another RoPE extensioN)
我们最终采用的方案
YaRN是目前效果最好的RoPE扩展方法,结合了NTK和注意力缩放。
import torch
import math
class YaRNRotaryEmbedding(torch.nn.Module):
"""
YaRN: Yet another RoPE extensioN
核心改进:
1. 对不同频率分量使用不同的缩放策略
2. 添加注意力logits的温度缩放
"""
def __init__(
self,
dim: int,
max_position_embeddings: int = 4096,
base: float = 10000.0,
scale: float = 1.0,
original_max_position_embeddings: int = 4096,
extrapolation_factor: float = 1.0,
attn_factor: float = 1.0,
beta_fast: int = 32,
beta_slow: int = 1,
):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scale = scale
self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self._build_cache()
def _build_cache(self):
# 计算每个频率分量应该如何插值
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
# 计算高低频分界点
low = max(
math.floor(
self.dim * math.log(self.original_max_position_embeddings /
(self.beta_fast * 2 * math.pi)) /
(2 * math.log(self.base))
), 0
)
high = min(
math.ceil(
self.dim * math.log(self.original_max_position_embeddings /
(self.beta_slow * 2 * math.pi)) /
(2 * math.log(self.base))
), self.dim - 1
)
# 计算平滑的插值权重
inv_freq = torch.zeros(self.dim // 2)
for i in range(self.dim // 2):
if i < low:
# 高频:不插值
inv_freq[i] = inv_freq_extrapolation[i]
elif i > high:
# 低频:完全插值
inv_freq[i] = inv_freq_interpolation[i]
else:
# 中频:线性混合
t = (i - low) / (high - low)
inv_freq[i] = (1 - t) * inv_freq_extrapolation[i] + \
t * inv_freq_interpolation[i]
self.register_buffer("inv_freq", inv_freq)
# 预计算cos/sin
t = torch.arange(self.max_position_embeddings).float()
freqs = torch.outer(t, self.inv_freq)
freqs = torch.cat([freqs, freqs], dim=-1)
# 注意力缩放因子
self.register_buffer("attn_scaling",
torch.tensor(0.1 * math.log(self.scale) + 1.0))
self.register_buffer("cos_cached", freqs.cos())
self.register_buffer("sin_cached", freqs.sin())
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[-2]
return (
self.cos_cached[:seq_len],
self.sin_cached[:seq_len],
self.attn_scaling
)
1. 分频处理:高频分量外推,低频分量插值,中频分量混合
2. 注意力缩放:长序列时降低attention logits的scale,防止分布过于尖锐
3. 参数自动计算:根据目标长度自动确定插值策略
YaRN效果对比
| 序列长度 | 原始RoPE | PI | NTK | YaRN |
|---|---|---|---|---|
| 4096 | 5.2 | 5.4 | 5.3 | 5.2 |
| 16384 | 1203.5 | 7.8 | 6.2 | 5.5 |
| 32768 | - | 12.4 | 7.5 | 5.9 |
| 65536 | - | 24.7 | 11.2 | 6.4 |
| 131072 | - | - | 18.9 | 7.2 |
显存优化:长序列的另一个瓶颈
问题分析
Attention的显存复杂度是O(n²),32K序列的attention矩阵需要:
# 32K x 32K x 4字节(float32) x batch_size x n_heads
# = 32768 * 32768 * 4 * 1 * 32
# = 137GB (单batch!)
# 即使用float16也需要68GB,远超单卡显存
解决方案1:Flash Attention
from flash_attn import flash_attn_func
def flash_attention_forward(q, k, v, causal=True):
"""
使用Flash Attention 2
优势:
1. 显存O(n)而不是O(n²)
2. 速度快2-4倍
"""
# q, k, v: [batch, seq_len, n_heads, head_dim]
# 注意维度顺序和标准attention不同!
output = flash_attn_func(
q, k, v,
dropout_p=0.0,
softmax_scale=None, # 自动计算
causal=causal
)
return output
# 显存对比 (batch=1, heads=32, dim=128)
# 序列长度 | 标准Attention | Flash Attention
# 8192 | 8.2 GB | 1.1 GB
# 32768 | OOM | 4.3 GB
# 131072 | OOM | 17.2 GB
解决方案2:KV Cache优化
推理时,每生成一个token都要存储之前所有token的K和V。对于长序列,这是巨大的显存开销。
class PagedKVCache:
"""
分页KV Cache (类似vLLM的实现)
思想:不预分配全部显存,按需分配page
"""
def __init__(
self,
n_layers: int,
n_heads: int,
head_dim: int,
page_size: int = 16,
max_pages: int = 8192,
dtype=torch.float16,
device='cuda'
):
self.page_size = page_size
self.n_layers = n_layers
self.n_heads = n_heads
self.head_dim = head_dim
# 预分配page池
self.k_cache = torch.zeros(
max_pages, n_layers, n_heads, page_size, head_dim,
dtype=dtype, device=device
)
self.v_cache = torch.zeros(
max_pages, n_layers, n_heads, page_size, head_dim,
dtype=dtype, device=device
)
# 空闲page列表
self.free_pages = list(range(max_pages))
# 每个序列使用的page表
self.page_tables = {} # seq_id -> list of page indices
def allocate(self, seq_id: int, num_tokens: int):
"""为序列分配page"""
num_pages = (num_tokens + self.page_size - 1) // self.page_size
if len(self.free_pages) < num_pages:
raise RuntimeError("KV Cache空间不足")
pages = [self.free_pages.pop() for _ in range(num_pages)]
self.page_tables[seq_id] = pages
return pages
def append(self, seq_id: int, layer: int, k: torch.Tensor, v: torch.Tensor):
"""追加KV到cache"""
pages = self.page_tables[seq_id]
# ... 实际的写入逻辑
def get(self, seq_id: int, layer: int):
"""获取序列的完整KV"""
pages = self.page_tables[seq_id]
# ... 实际的读取逻辑
实际部署配置
最终方案
经过三个月的迭代,我们的配置是:
# 模型配置
model_config:
base_model: "Llama-2-13B-chat"
rope_scaling:
type: "yarn"
factor: 32 # 4K -> 128K
original_max_position_embeddings: 4096
# 推理配置
inference_config:
max_seq_len: 131072
attention_implementation: "flash_attention_2"
kv_cache_type: "paged"
dtype: "bfloat16"
# 硬件配置
hardware:
gpus: 4 x A100-80G
tensor_parallel: 4
性能数据
| 指标 | 数值 |
|---|---|
| 最大上下文长度 | 128K tokens |
| prefill吞吐(128K) | ~12K tokens/s |
| decode吞吐 | ~45 tokens/s |
| 首token延迟(128K) | ~11秒 |
| 显存占用 | ~290GB (4卡) |
Needle in Haystack测试
"大海捞针"是评估长文本能力的标准测试:在长文档中随机位置插入一个事实,看模型能否准确找到。
import random
def needle_in_haystack_test(model, context_lengths, depths):
"""
大海捞针测试
context_lengths: 测试的上下文长度列表
depths: 针插入的深度比例 (0.0-1.0)
"""
results = {}
# 生成干扰文本
haystack = load_paul_graham_essays() # 用Paul Graham的文章作为干扰
# 针
needle = "The special magic number is: 7429185."
question = "What is the special magic number?"
expected = "7429185"
for ctx_len in context_lengths:
results[ctx_len] = {}
for depth in depths:
# 计算插入位置
insert_pos = int(ctx_len * depth)
# 构造测试文本
prefix = haystack[:insert_pos]
suffix = haystack[insert_pos:ctx_len]
test_context = prefix + needle + suffix
# 测试
response = model.generate(
f"Context: {test_context}\n\nQuestion: {question}"
)
# 判断是否正确
success = expected in response
results[ctx_len][depth] = success
print(f"Length={ctx_len}, Depth={depth:.1%}: "
f"{'✓' if success else '✗'}")
return results
# 我们的测试结果
# Length=4K: 所有depth都成功
# Length=16K: depth<0.9成功,末尾有时失败
# Length=64K: depth 0.3-0.7成功率高,首尾较差
# Length=128K: 中间位置80%成功,首尾约50%
还没解决的问题
1. Lost in the Middle
不只是我们,很多长文本模型都有这个问题:中间位置的信息比首尾更容易被"遗忘"。有论文专门研究过这个现象,但目前没有完美解决方案。
2. 推理速度
128K context的首token延迟11秒,用户体验不好。我们尝试过:
- Speculative decoding:小模型起草,大模型验证。效果有限,因为瓶颈在prefill
- Prompt caching:相同前缀的请求复用KV Cache。对我们场景帮助不大
- 并行prefill:把长context分到多卡并行处理。实现复杂,还在研究
3. 质量degradation
虽然PPL看起来不错,但主观评测发现128K时的回答质量明显不如短文本。特别是需要综合多处信息的问题,经常漏掉一些细节。
参考资料
- RoFormer: Enhanced Transformer with Rotary Position Embedding
- Extending Context Window of Large Language Models via Positional Interpolation
- YaRN: Efficient Context Window Extension
- Lost in the Middle: How Language Models Use Long Contexts
- Flash Attention GitHub
总结
把上下文从4K扩展到128K,我们学到的经验:
- YaRN是目前效果最好的RoPE扩展方案,但需要理解原理才能调好参数
- 显存优化和位置编码同等重要,Flash Attention + Paged KV Cache是标配
- 长度≠质量,128K的PPL可以很低,但实际任务效果可能不如RAG
- 评测要全面,不只看PPL,还要做Needle in Haystack等任务测试
对于我们的法律文档场景,最终采用的是混合方案:短文档(<32K)直接处理,长文档先分块再用128K模型处理关键部分。这样在质量和成本之间取得了平衡。
更新记录:
2025-01-15: 初版发布
2025-01-22: 补充了Needle in Haystack测试结果
2025-02-01: 更新了生产环境性能数据