Kirito's Blog · 文章

旋转位置编码 RoPE 深度解析:从数学原理到工程实现

2026-03-19 · 9 分钟 ·标签:RoPE位置编码TransformerLLM数学推导

从复数域推导到高维旋转矩阵,从相对位置性质证明到长度外推技术,系统拆解 RoPE 的每一个细节。

Transformer 的自注意力机制本身是置换不变的——打乱输入顺序,输出跟着打乱,但数值完全不变。要让模型感知序列中"谁在前、谁在后",必须额外注入位置信息。

位置编码的设计经历了几代演进:

方案 代表工作 核心思路 局限
绝对正弦编码 Vaswani et al. 2017 用固定频率的 sin/cos 信号叠加到 embedding 无法直接表达相对位置
可学习绝对编码 BERT, GPT-2 学一张位置嵌入表 固定长度,外推困难
相对位置偏置 Shaw et al. 2018, T5 在注意力分数中加可学习偏置 需要额外参数,且实现复杂
ALiBi Press et al. 2022 按距离线性衰减注意力分数 不编码位置语义,仅引入衰减
RoPE Su et al. 2021 在复数域对 Q/K 做旋转,内积自动蕴含相对位置 本文详解

RoPE(Rotary Position Embedding)被 LLaMA、Qwen、Mistral、Gemma 等几乎所有主流开源 LLM 采用,已成为事实标准。本文从数学推导出发,彻底讲清 RoPE 的原理、性质、工程实现与长度外推技术。

1. 从复数域出发的推导

1.1 二维情形:核心直觉

RoPE 的灵感来自一个简洁的观察:在复数域中,乘以一个单位复数就是旋转

设 $q, k \in \mathbb{R}^2$,将它们视为复数 $q = q_0 + q_1 i$,$k = k_0 + k_1 i$。若 token 分别在位置 $m$ 和 $n$,定义带位置信息的编码:

$$ \tilde{q}_m = q \cdot e^{im\theta}, \quad \tilde{k}_n = k \cdot e^{in\theta} $$

其中 $\theta$ 是预先确定的角频率。

注意力分数需要计算 $\tilde{q}_m$ 与 $\tilde{k}_n$ 的内积。在复数域中:

$$ \operatorname{Re}\bigl[\tilde{q}_m \cdot \overline{\tilde{k}_n}\bigr] = \operatorname{Re}\bigl[q \cdot e^{im\theta} \cdot \bar{k} \cdot e^{-in\theta}\bigr] = \operatorname{Re}\bigl[q\bar{k} \cdot e^{i(m-n)\theta}\bigr] $$

结果只依赖 相对位置 $m - n$,与绝对位置无关。这是 RoPE 最核心的性质。

1.2 旋转的几何意义

将上述复数乘法展开为实数矩阵运算:

$$ \tilde{q}_m = \begin{bmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{bmatrix} \begin{bmatrix} q_0 \\ q_1 \end{bmatrix} $$

这就是一个标准的 二维旋转矩阵 $R_{m\theta}$。乘以 $e^{im\theta}$ 等价于在二维平面上旋转角度 $m\theta$。

提示

RoPE 的名字由此而来:Rotary(旋转)Position Embedding。位置信息不是"加"上去的,而是通过旋转"编织"进向量的方向中。

2. 推广到高维:分块旋转矩阵

2.1 分组策略

实际 Transformer 的注意力头维度 $d_h$ 远大于 2(通常 64 或 128)。RoPE 的做法是:将 $d_h$ 维向量按相邻两维分组,共 $d_h/2$ 组,每组独立做二维旋转。

设 $q \in \mathbb{R}^{d_h}$,分组后对第 $i$ 组(维度 $2i, 2i+1$),旋转角度为 $m\theta_i$:

$$ \begin{bmatrix} \tilde{q}_{2i} \\ \tilde{q}_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix}, \quad i = 0, 1, \ldots, \frac{d_h}{2} - 1 $$

2.2 完整旋转矩阵的结构

将所有分组组合起来,完整的旋转矩阵 $R_m \in \mathbb{R}^{d_h \times d_h}$ 是一个块对角矩阵

$$ R_m = \operatorname{diag}\!\Bigl( R(m\theta_0),\; R(m\theta_1),\; \ldots,\; R(m\theta_{d_h/2-1}) \Bigr) $$

其中每个 $R(m\theta_i)$ 是 $2\times2$ 旋转矩阵。这个结构有两个重要性质:

  1. 稀疏性:$d_h \times d_h$ 矩阵中只有 $2d_h$ 个非零元素,计算量为 $O(d_h)$ 而非 $O(d_h^2)$
  2. 正交性:$R_m^T R_m = I$,旋转不改变向量的模长,信息无损

2.3 角频率的设计:为什么是 $10000^{-2i/d_h}$

每组的角频率定义为:

$$ \theta_i = b^{-2i/d_h}, \quad b = 10000 $$

  • 低维组($i$ 小):$\theta_i$ 大,旋转快,捕捉短距离依赖
  • 高维组($i$ 大):$\theta_i$ 小,旋转慢,捕捉长距离依赖

这与 Transformer 原始正弦编码的频率设计一脉相承。直觉上,不同频率组成的"频谱"让模型能同时感知局部和全局的位置关系。

底数 $b = 10000$ 的选择保证了在常见训练长度(2K–8K tokens)下,最低频维度的旋转周期足够覆盖整个序列,而最高频维度能区分相邻 token。

提示

底数 $b$ 是长度外推技术的核心调节旋钮。后文会看到,NTK-Aware Scaling 本质上就是增大 $b$。

3. 核心性质的严格证明

3.1 相对位置性质

命题:对任意 $q, k \in \mathbb{R}^{d_h}$,有 $\langle R_m q,\; R_n k \rangle = \langle q,\; R_{m-n} k \rangle$。

证明

$$ \langle R_m q,\; R_n k \rangle = (R_m q)^T (R_n k) = q^T R_m^T R_n k $$

由于每个 $2\times2$ 旋转块满足 $R(\alpha)^T R(\beta) = R(\beta - \alpha)$:

$$ R_m^T R_n = \operatorname{diag}\!\bigl( R((n-m)\theta_0),\; R((n-m)\theta_1),\; \ldots \bigr) = R_{n-m} $$

因此:

$$ q^T R_m^T R_n k = q^T R_{n-m} k = \langle q,\; R_{n-m} k \rangle $$

由于 $R_{n-m} = R_{-(m-n)}$,而旋转方向取反等价于 $R_{m-n}^T$(但内积中只需差值), 最终注意力分数只取决于相对位置 $m - n$。 $\blacksquare$

3.2 远程衰减性质

RoPE 具有一个隐含的远程衰减特性。对随机初始化的 $q, k$(各维度独立同分布),可以分析 $\mathbb{E}[\langle R_m q,\; R_n k \rangle]$ 随 $|m-n|$ 的变化。

设 $\Delta = m - n$,每个二维分组的贡献为:

$$ q_{2i} k_{2i} \cos(\Delta \theta_i) + q_{2i+1} k_{2i+1} \cos(\Delta \theta_i) - q_{2i} k_{2i+1} \sin(\Delta \theta_i) + q_{2i+1} k_{2i} \sin(\Delta \theta_i) $$

对所有组求和后,由于各 $\theta_i$ 不同,余弦项在大 $|\Delta|$ 时互相抵消,产生类似于多频干涉衰减的效果。虽然这种衰减不是单调的(存在振荡),但整体趋势是远处的贡献更弱。

注意

远程衰减是统计性质(期望意义),不是逐样本保证。模型经过训练后,特定注意力头可能学会关注远距离位置。

3.3 信息保持

由于 $R_m$ 是正交矩阵:

  • $\|R_m q\| = \|q\|$:旋转不改变向量长度
  • $R_m$ 可逆($R_m^{-1} = R_m^T$):位置编码不丢失任何信息
  • 不同位置的编码不会坍缩到相同的子空间

这些性质保证了 RoPE 是一种"无损"的位置注入方式,与直接相加的绝对编码不同(加法编码会干扰原始语义向量的方向和模长)。

4. RoPE 的等价计算形式

在实际代码中,我们不会显式构造 $d_h \times d_h$ 的旋转矩阵再做矩阵乘法。有两种高效的等价形式。

4.1 逐元素乘法形式

将旋转展开:

$$ \tilde{q}_{2i} = q_{2i}\cos(m\theta_i) - q_{2i+1}\sin(m\theta_i) $$

$$ \tilde{q}_{2i+1} = q_{2i}\sin(m\theta_i) + q_{2i+1}\cos(m\theta_i) $$

预计算 $\cos$ 和 $\sin$ 缓存后,只需 4 次逐元素乘法 + 2 次加法,计算量为 $O(d_h)$。

4.2 复数乘法形式

将 $(q_{2i}, q_{2i+1})$ 视为复数 $q_{2i} + q_{2i+1}i$,则旋转等价于乘以 $e^{im\theta_i} = \cos(m\theta_i) + i\sin(m\theta_i)$。

在支持复数运算的框架中(如 PyTorch),可以:

Python
1# 将实数张量视为复数2q_complex = torch.view_as_complex(q.reshape(*q.shape[:-1], -1, 2))3# 旋转 = 复数乘法4freqs_complex = torch.polar(torch.ones_like(freqs), freqs) # e^{iθ}5q_rotated = torch.view_as_real(q_complex * freqs_complex).flatten(-2)

两种形式在数学上完全等价,选择哪种取决于框架支持和性能。

5. 工程实现详解(PyTorch)

5.1 频率缓存构建

Python
1import torch2 3def build_rope_cache(4 max_seq_len: int,5 head_dim: int,6 base: float = 10000.0,7 device: torch.device = None,8 dtype: torch.dtype = torch.float32,9):10 """预计算 cos/sin 缓存。11 12 返回:13 cos: [max_seq_len, head_dim // 2]14 sin: [max_seq_len, head_dim // 2]15 """16 half = head_dim // 217 # 角频率: θ_i = base^{-2i/d}18 inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim))19 # 位置 × 角频率 -> 角度矩阵20 pos = torch.arange(max_seq_len, device=device, dtype=dtype)21 angles = torch.outer(pos, inv_freq) # [max_seq_len, half]22 return angles.cos(), angles.sin()
提示

缓存只需构建一次,推理时按 pos_ids 索引即可。显存开销为 $2 \times T_{\max} \times d_h/2$ 个浮点数,对 $T_{\max}=8192, d_h=128$ 仅约 4 MB(float32)。

5.2 Interleave 方式 vs Half-Split 方式

两种主流实现在维度分组方式上不同:

Interleave(交错排列)——原始论文方式,LLaMA 采用:

将相邻两维 $(0,1), (2,3), \ldots$ 配对。

Python
1def apply_rope_interleave(x, cos, sin, pos_ids):2 """3 x: [B, H, T, D] 其中 D = head_dim4 cos, sin: [max_T, D//2]5 pos_ids: [B, T]6 """7 cos_t = cos[pos_ids].unsqueeze(1) # [B, 1, T, D//2]8 sin_t = sin[pos_ids].unsqueeze(1)9 10 x_even = x[..., 0::2] # [B, H, T, D//2]11 x_odd = x[..., 1::2]12 13 out = torch.empty_like(x)14 out[..., 0::2] = x_even * cos_t - x_odd * sin_t15 out[..., 1::2] = x_even * sin_t + x_odd * cos_t16 return out

Half-Split(前后半分)——GPT-NeoX / HuggingFace 常用:

将前半维 $(0, \ldots, d_h/2-1)$ 与后半维 $(d_h/2, \ldots, d_h-1)$ 配对。

Python
1def apply_rope_half_split(x, cos, sin, pos_ids):2 """3 x: [B, H, T, D]4 """5 cos_t = cos[pos_ids].unsqueeze(1)6 sin_t = sin[pos_ids].unsqueeze(1)7 8 half = x.shape[-1] // 29 x_first = x[..., :half] # [B, H, T, D//2]10 x_second = x[..., half:]11 12 out_first = x_first * cos_t - x_second * sin_t13 out_second = x_first * sin_t + x_second * cos_t14 return torch.cat([out_first, out_second], dim=-1)
注意

两种方式不兼容。加载预训练权重时必须确认模型使用的是哪种分组,否则位置编码会完全错乱。LLaMA 系列使用 interleave,GPT-NeoX 系列使用 half-split。

5.3 复数视图实现

Python
1def apply_rope_complex(x, cos, sin, pos_ids):2 """利用 PyTorch 的复数运算实现 RoPE。"""3 cos_t = cos[pos_ids].unsqueeze(1)4 sin_t = sin[pos_ids].unsqueeze(1)5 6 # 构造旋转复数 e^{iθ} = cos + i·sin7 freqs = torch.complex(cos_t, sin_t) # [B, 1, T, D//2]8 9 # 将 x 视为复数10 x_complex = torch.view_as_complex(11 x.float().reshape(*x.shape[:-1], -1, 2)12 ) # [B, H, T, D//2]13 14 # 复数乘法 = 旋转15 x_rotated = torch.view_as_real(x_complex * freqs).flatten(-2)16 return x_rotated.type_as(x)

这种方式代码最简洁,但需要注意:

  1. view_as_complex 要求输入为 float32(bfloat16 不支持复数运算)
  2. 需要额外的类型转换开销
  3. 语义等价于 interleave 方式

5.4 与 LLaMA 源码的对照

LLaMA 官方实现(modeling_llama.py)中的核心逻辑:

Python
1# HuggingFace Transformers 中的 LlamaRotaryEmbedding2class LlamaRotaryEmbedding(nn.Module):3 def __init__(self, dim, max_position_embeddings=2048, base=10000):4 super().__init__()5 inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))6 self.register_buffer("inv_freq", inv_freq)7 self.max_seq_len_cached = max_position_embeddings8 9 def forward(self, x, position_ids):10 # 构建频率矩阵并返回 cos, sin11 freqs = torch.outer(position_ids.float(), self.inv_freq)12 emb = torch.cat((freqs, freqs), dim=-1) # 注意这里复制了一遍13 return emb.cos(), emb.sin()

注意 HuggingFace 版本将 freqs 复制拼接了一遍(torch.cat((freqs, freqs), dim=-1)),因此 cos/sin 的维度是 $[T, d_h]$ 而非 $[T, d_h/2]$。这是为了在 apply_rotary_pos_emb 中直接做逐元素乘法,无需手动拆分偶奇维。

6. 长度外推:超出训练长度后怎么办

6.1 原始 RoPE 为什么会失效

假设模型在 $L_{\text{train}} = 4096$ 长度上训练。当推理时序列长度为 $L > L_{\text{train}}$,会出现训练时从未见过的角度值 $m\theta_i$(当 $m > L_{\text{train}}$)。

具体来说,高频维度($\theta_i$ 大)在训练范围内已经转了很多圈,超出后模式类似;但低频维度($\theta_i$ 小)在训练范围内只转了不到一圈,超出后进入全新的角度区域,模型无法正确泛化。

6.2 位置插值(PI, Position Interpolation)

最直接的方法:将位置索引线性缩放到训练范围内。

$$ m' = m \cdot \frac{L_{\text{train}}}{L} $$

等价于将 RoPE 频率整体降低:

$$ \theta_i' = \theta_i \cdot \frac{L_{\text{train}}}{L} $$

优点:简单有效,所有角度都落在训练范围内。

缺点:所有频率被均匀压缩,高频维度的分辨率下降,短距离位置关系变得模糊。需要少量微调来恢复性能。

6.3 NTK-Aware Scaling

核心思想:不均匀地缩放频率——低频维度多缩放(它们是外推的瓶颈),高频维度少缩放(保持短距离分辨率)。

实现方式是增大底数 $b$:

$$ b' = b \cdot \alpha^{d_h/(d_h-2)}, \quad \alpha = \frac{L}{L_{\text{train}}} $$

由此每个维度的角频率变为:

$$ \theta_i' = (b')^{-2i/d_h} $$

高维组($i$ 大,低频)的 $\theta_i'$ 被大幅压缩,低维组($i$ 小,高频)几乎不变。

Python
1def build_rope_cache_ntk(max_seq_len, head_dim, base=10000.0, train_len=4096):2 """NTK-Aware RoPE 缓存。"""3 if max_seq_len > train_len:4 alpha = max_seq_len / train_len5 base = base * alpha ** (head_dim / (head_dim - 2))6 7 return build_rope_cache(max_seq_len, head_dim, base=base)
提示

NTK-Aware Scaling 的名字来自神经正切核(Neural Tangent Kernel)理论中对高频/低频分量不同缩放需求的分析。

6.4 YaRN(Yet another RoPE extensioN method)

YaRN 进一步优化,将频率维度分为三个区域:

区域 频率特征 处理方式
高频区 $\theta_i$ 大,训练内已转多圈 不缩放(保持原样)
中频区 $\theta_i$ 中等 线性插值过渡
低频区 $\theta_i$ 小,训练内不足一圈 做 PI 缩放

分界由"波长"决定:

$$ \lambda_i = \frac{2\pi}{\theta_i} $$

设两个阈值 $\lambda_{\text{low}}, \lambda_{\text{high}}$,对应的缩放因子:

$$ s_i = \begin{cases} 1 & \text{if } \lambda_i < \lambda_{\text{low}} \\ \frac{1 - r_i}{1 - \alpha^{-1}} & \text{if } \lambda_{\text{low}} \le \lambda_i \le \lambda_{\text{high}} \\ \alpha^{-1} & \text{if } \lambda_i > \lambda_{\text{high}} \end{cases} $$

其中 $r_i$ 是 $\lambda_i$ 在 $[\lambda_{\text{low}}, \lambda_{\text{high}}]$ 中的线性插值比例,$\alpha = L / L_{\text{train}}$。

YaRN 还额外引入了一个注意力缩放因子来补偿长序列中注意力分数的统计变化。

6.5 Dynamic NTK

在推理时根据当前序列的实际长度动态调整底数:

$$ b'(l) = b \cdot \left(\frac{l}{L_{\text{train}}}\right)^{d_h/(d_h-2)}, \quad l > L_{\text{train}} $$

当 $l \le L_{\text{train}}$ 时不缩放。这种方式无需预设目标长度,适合流式场景。

6.6 方法对比

方法 需要微调 短距离保持 长距离外推 实现复杂度
直接外推
位置插值 (PI) 少量 下降
NTK-Aware 少量/无
YaRN 少量/无 最好 最好
Dynamic NTK 较好

7. RoPE 变体与相关工作

7.1 ALiBi:另一条路

ALiBi(Attention with Linear Biases)不编码位置到向量中,而是直接在注意力分数上加一个与距离成正比的负偏置:

$$ a_{ij} = q_i^T k_j - r \cdot |i - j| $$

其中 $r$ 是每个头的固定斜率。

ALiBi 与 RoPE 的对比:

  • ALiBi 更简单(无需 cos/sin 缓存),但只编码"距离"而非"位置"
  • ALiBi 的线性衰减是硬编码的,RoPE 的衰减模式由训练学得
  • 在长上下文任务中,RoPE + 外推技术通常优于 ALiBi

7.2 xPos:带显式衰减的 RoPE

xPos 在 RoPE 基础上引入一个与位置相关的衰减因子:

$$ \tilde{q}_m = R_m q \cdot \gamma^m, \quad \tilde{k}_n = R_n k \cdot \gamma^{-n} $$

注意力分数变为 $\langle q, R_{m-n} k \rangle \cdot \gamma^{m-n}$,显式引入了指数衰减。这使得长距离的注意力信号逐渐减弱,有助于训练稳定性。

7.3 多维 RoPE

在视觉 Transformer(ViT)中,输入的 patch 具有二维空间位置 $(x, y)$。多维 RoPE 将 $d_h$ 维度均分给每个空间维度:

  • 前 $d_h/2$ 维使用 $x$ 坐标的旋转
  • 后 $d_h/2$ 维使用 $y$ 坐标的旋转

$$ R_{x,y} = \operatorname{diag}\bigl(R_x^{(0:d_h/2)},\; R_y^{(d_h/2:d_h)}\bigr) $$

这种思路可以推广到 3D(视频)甚至更高维度的位置编码。

7.4 RoPE 在 MLA 中的应用

DeepSeek-V2 提出的 Multi-head Latent Attention (MLA) 将 KV 投影到低秩空间以压缩 KV Cache。但 RoPE 的位置依赖性与低秩压缩冲突——旋转后的 K 无法被有效压缩。

MLA 的解决方案是将每个头的维度分为两部分:

  • 位置无关部分:参与低秩压缩,不加 RoPE
  • 位置相关部分:独立的小维度(如 $d_{\text{rope}} = 64$),单独加 RoPE

$$ k = [k_{\text{compress}};\; k_{\text{rope}}], \quad \tilde{k} = [k_{\text{compress}};\; R_n k_{\text{rope}}] $$

这种分离设计同时保留了 RoPE 的位置编码能力和 KV Cache 的压缩效率。

8. 数值分析与可视化

8.1 频率谱分析

对于 $d_h = 128$,$b = 10000$,角频率分布为:

$$ \theta_i = 10000^{-2i/128}, \quad i = 0, 1, \ldots, 63 $$

Python
1import torch2 3head_dim = 1284base = 10000.05inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))6 7# 最高频率 (i=0): θ_0 = 1.0 rad/pos → 周期 ≈ 6.3 positions8# 最低频率 (i=63): θ_63 ≈ 0.0001 rad/pos → 周期 ≈ 62832 positions9print(f"最高频率: {inv_freq[0]:.4f}, 周期: {2 * 3.14159 / inv_freq[0]:.1f} positions")10print(f"最低频率: {inv_freq[-1]:.6f}, 周期: {2 * 3.14159 / inv_freq[-1]:.0f} positions")

这意味着:

  • $i=0$ 的维度对每约 6 个 token 的间隔即可完成一次完整旋转,对短距离变化极其敏感
  • $i=63$ 的维度需要约 63000 个 token 才能转一圈,对长距离结构缓慢变化

8.2 注意力分数的位置衰减

固定随机 $q, k$,计算不同相对位置 $\Delta$ 下的注意力分数:

Python
1def rope_attention_by_distance(head_dim=128, max_dist=512, base=10000.0):2 """计算 RoPE 注意力分数随相对距离的变化。"""3 cos_cache, sin_cache = build_rope_cache(max_dist + 1, head_dim, base)4 5 torch.manual_seed(42)6 q = torch.randn(head_dim)7 k = torch.randn(head_dim)8 9 scores = []10 for delta in range(max_dist + 1):11 # 对 q 施加位置 delta 的旋转,k 保持位置 012 q_rot = q.clone()13 q_even, q_odd = q[0::2], q[1::2]14 q_rot[0::2] = q_even * cos_cache[delta] - q_odd * sin_cache[delta]15 q_rot[1::2] = q_even * sin_cache[delta] + q_odd * cos_cache[delta]16 scores.append(torch.dot(q_rot, k).item())17 18 return scores

结果呈现明显的多频振荡衰减模式:在距离 0 附近值最大,随距离增加在振荡中逐渐衰减。这正是不同频率 cos 分量叠加干涉的效果。

8.3 底数对衰减的影响

底数 $b$ 最低频周期 4K 长度覆盖率 适用场景
10000 ~63K 充分 标准训练长度 (≤8K)
100000 ~200K 充分 中等长度外推 (~32K)
1000000 ~630K 充分 长上下文 (~128K)

增大底数使所有频率降低,等效拉长了位置编码的"视野",但也降低了短距离分辨率——这正是 NTK-Aware 和 YaRN 方法需要精细调控的原因。

9. 工程最佳实践与常见陷阱

9.1 精度问题:float32 vs bfloat16

RoPE 的 cos/sin 计算对数值精度敏感:

Python
1# 错误做法:在 bfloat16 下计算角频率2inv_freq_bf16 = 1.0 / (10000 ** (torch.arange(0, 128, 2, dtype=torch.bfloat16) / 128))3 4# 正确做法:始终在 float32 下预计算,应用时再转换5inv_freq_fp32 = 1.0 / (10000 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
注意

bfloat16 的尾数只有 7 位(float32 有 23 位)。在大位置索引下($m > 10000$),bfloat16 的 $\cos(m\theta)$ 误差可能导致注意力模式严重偏移。始终在 float32 下计算 cos/sin 缓存,然后将结果转换到目标精度。

9.2 缓存策略

推荐做法

Python
1class RotaryEmbedding(torch.nn.Module):2 def __init__(self, head_dim, max_seq_len, base=10000.0):3 super().__init__()4 # 在初始化时预计算,注册为 buffer(随模型移动设备)5 cos, sin = build_rope_cache(max_seq_len, head_dim, base)6 self.register_buffer("cos_cached", cos, persistent=False)7 self.register_buffer("sin_cached", sin, persistent=False)8 9 def forward(self, x, pos_ids):10 return apply_rope_interleave(x, self.cos_cached, self.sin_cached, pos_ids)
  • 使用 register_buffer 而非普通张量,确保 .to(device).half() 时自动跟随
  • persistent=False 避免序列化到 checkpoint(可随时重建)

9.3 与 KV Cache 的配合

在推理时使用 KV Cache,每个 decode 步骤只处理 1 个新 token:

Python
1# decode 阶段: 只对新 token 的 Q, K 做 RoPE2# pos_ids 是当前步的绝对位置(标量),而非从 0 开始3new_q = apply_rope(q_new, cos, sin, pos_ids=current_pos)4new_k = apply_rope(k_new, cos, sin, pos_ids=current_pos)5 6# K 存入 cache(已经带有 RoPE)7kv_cache.append(new_k, new_v)8 9# 注意力计算使用 cache 中已旋转的 K10attn = scaled_dot_product(new_q, kv_cache.keys, kv_cache.values)

关键点:K 在存入 cache 前就应用了 RoPE,后续使用时无需重新旋转。这保证了 $\langle R_m q, R_n k \rangle$ 中 $m$ 和 $n$ 分别是查询和键各自的绝对位置。

9.4 常见 Bug 清单

Bug 表现 排查方式
Interleave/Half-Split 搞混 加载预训练权重后 PPL 暴涨 对比原始模型代码的拆分方式
cos/sin 缓存维度错误 形状广播报错或静默错误 打印 cos.shape,确认是 [T, D//2]
位置索引从 1 开始 所有位置偏移 1,短序列影响明显 确认 pos_ids 从 0 开始
忘记对 K 也做 RoPE 注意力分数失去位置依赖 检查 Q 和 K 是否都经过旋转
bfloat16 下计算频率 长序列性能退化 检查 inv_freq 的 dtype
多卡推理时 buffer 未同步 不同 GPU 上缓存不一致 使用 register_buffer 而非手动管理