Kirito's Blog · 文章

从 Transformer 到 LLM:RoPE、KV Cache、FlashAttention 的工程化落地

2026-02-06 · 4 分钟 ·标签:LLMTransformer推理优化PyTorch系统工程

面向工程实战系统讲清 RoPE、KV Cache、FlashAttention:核心公式、复杂度分析、显存估算与可落地代码实现。

训练一个 Transformer 和让它在真实线上稳定跑起来,中间隔着一整套“系统工程”。 对自回归 LLM 而言,最关键的三件事通常是:

  1. 位置编码从绝对位置转向 RoPE(旋转位置编码)
  2. 推理阶段用 KV Cache 避免重复计算
  3. 注意力计算用 FlashAttention / SDP 内核 降低显存与 IO 瓶颈
LLM 推理工程总览
Prefill + Decode 流程:RoPE 处理 Q/K,KV Cache 按层追加,注意力内核优先使用 FlashAttention。

1. RoPE:把位置信息写进旋转变换

1.1 基本公式

设注意力头维度为 $d_h$,把向量按偶奇维两两分组。 对第 $i$ 对维度($2i,2i+1$),定义角频率:

$$ \theta_i = 10000^{-2i/d_h} $$

位置 $m$ 的旋转矩阵:

$$ R(m\theta_i)= \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} $$

对查询/键向量的该二维分量做旋转:

$$ \begin{bmatrix} q'_{2i} \\ q'_{2i+1} \end{bmatrix} = R(m\theta_i) \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix}, \quad \begin{bmatrix} k'_{2i} \\ k'_{2i+1} \end{bmatrix} = R(n\theta_i) \begin{bmatrix} k_{2i} \\ k_{2i+1} \end{bmatrix} $$

RoPE 的关键性质是注意力分数主要依赖相对位移:

$$ \langle R(m)q, R(n)k \rangle = \langle q, R(m-n)k \rangle $$

这使模型对相对位置关系更自然。

1.2 工程实现(PyTorch)

Python
1import torch2 3 4def build_rope_cache(seq_len: int, head_dim: int, device=None, dtype=torch.float32):5 # head_dim 必须是偶数6 half = head_dim // 27 freq = torch.arange(half, device=device, dtype=dtype)8 inv_freq = 1.0 / (10000 ** (2 * freq / head_dim)) # [half]9 10 pos = torch.arange(seq_len, device=device, dtype=dtype) # [T]11 angle = torch.outer(pos, inv_freq) # [T, half]12 13 cos = torch.cos(angle)14 sin = torch.sin(angle)15 return cos, sin # [T, half], [T, half]16 17 18def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, pos_ids: torch.Tensor):19 """20 x: [B, H, T, Dh]21 cos/sin: [max_T, Dh/2]22 pos_ids: [B, T]23 """24 bsz, n_heads, seq_len, head_dim = x.shape25 half = head_dim // 226 27 x_even = x[..., 0::2] # [B,H,T,half]28 x_odd = x[..., 1::2] # [B,H,T,half]29 30 cos_t = cos[pos_ids].unsqueeze(1) # [B,1,T,half]31 sin_t = sin[pos_ids].unsqueeze(1) # [B,1,T,half]32 33 out_even = x_even * cos_t - x_odd * sin_t34 out_odd = x_even * sin_t + x_odd * cos_t35 36 out = torch.empty_like(x)37 out[..., 0::2] = out_even38 out[..., 1::2] = out_odd39 return out

2. KV Cache:把 decode 复杂度从“重复前缀”降下来

2.1 为什么必须用 Cache

自回归解码第 $t$ 步只新增一个 token。 若每步都重算整个前缀,单层复杂度可近似看作:

$$ \sum_{t=1}^{T} O(t^2 d) = O(T^3 d) $$

有了 KV Cache 后,第 $t$ 步只算新 token 的 $Q_t,K_t,V_t$,并与历史 $K_{1:t},V_{1:t}$ 计算:

$$ \sum_{t=1}^{T} O(t d) = O(T^2 d) $$

在长上下文推理中差距极大。

2.2 缓存显存估算

设:

  • batch 为 $B$
  • 层数为 $L$
  • 头数为 $H$
  • 每头维度为 $d_h$
  • 上下文长度为 $T$
  • 元素字节数为 $s$(fp16/bf16 取 2)

KV 总内存约为:

$$ \text{Mem}_{KV} \approx 2 \cdot B \cdot L \cdot H \cdot T \cdot d_h \cdot s $$

前面的 2 表示同时缓存 K 和 V。

2.3 最小缓存实现

Python
1import torch2 3 4class LayerKVCache:5 def __init__(self):6 self.k = None # [B,H,T,Dh]7 self.v = None # [B,H,T,Dh]8 9 def append(self, k_new: torch.Tensor, v_new: torch.Tensor):10 # k_new/v_new: [B,H,T_new,Dh]11 if self.k is None:12 self.k = k_new13 self.v = v_new14 else:15 self.k = torch.cat([self.k, k_new], dim=2)16 self.v = torch.cat([self.v, v_new], dim=2)17 return self.k, self.v18 19 20def decode_step_attn(q_t, k_t, v_t, cache: LayerKVCache):21 """22 q_t/k_t/v_t: [B,H,1,Dh]23 """24 k_all, v_all = cache.append(k_t, v_t) # [B,H,T,Dh]25 26 scale = q_t.size(-1) ** -0.527 score = torch.matmul(q_t, k_all.transpose(-2, -1)) * scale # [B,H,1,T]28 prob = torch.softmax(score, dim=-1)29 out = torch.matmul(prob, v_all) # [B,H,1,Dh]30 return out

3. FlashAttention:同复杂度,不同 IO 命运

3.1 标准注意力的瓶颈

经典实现会显式物化 $QK^\top$(大小约 $T \times T$),再做 softmax 和乘 V。 时间复杂度仍是 $O(T^2 d)$,但显存/带宽开销很大。

FlashAttention 的核心是 分块 + 在线 softmax 归一化,避免存整张分数矩阵。

对每个查询行维护三个量:

  • $m_i$:当前最大值
  • $\ell_i$:归一化分母
  • $o_i$:输出累积

块更新(示意):

$$ m_i' = \max\left(m_i, \max_j s_{ij}\right) $$

$$ \ell_i' = e^{m_i-m_i'}\ell_i + \sum_j e^{s_{ij}-m_i'} $$

$$ o_i' = \frac{e^{m_i-m_i'}\ell_i o_i + \sum_j e^{s_{ij}-m_i'} v_j}{\ell_i'} $$

最终效果:在保持数值稳定的同时显著减少 HBM 读写。

3.2 工程落地:优先用 PyTorch SDP

在 PyTorch 2.x 中,scaled_dot_product_attention 会根据环境自动选用最优内核 (含 FlashAttention / memory-efficient / math fallback)。

Python
1import torch2import torch.nn.functional as F3 4 5def sdp_attention(q, k, v, is_causal: bool, dropout_p: float = 0.0):6 """7 q/k/v: [B,H,T,Dh]8 """9 # 推理时建议 dropout_p=010 with torch.backends.cuda.sdp_kernel(enable_flash=True,11 enable_mem_efficient=True,12 enable_math=True):13 out = F.scaled_dot_product_attention(14 q, k, v,15 attn_mask=None,16 dropout_p=dropout_p,17 is_causal=is_causal,18 )19 return out

4. 把三者接进同一个 Decoder Attention

Python
1import torch2import torch.nn as nn3 4 5class LLMDecoderAttention(nn.Module):6 def __init__(self, d_model: int, n_heads: int):7 super().__init__()8 assert d_model % n_heads == 09 self.n_heads = n_heads10 self.d_head = d_model // n_heads11 self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)12 self.o_proj = nn.Linear(d_model, d_model, bias=False)13 14 def _shape(self, x):15 # [B,T,D] -> [B,H,T,Dh]16 bsz, seq_len, dim = x.shape17 x = x.view(bsz, seq_len, self.n_heads, self.d_head)18 return x.transpose(1, 2)19 20 def _merge(self, x):21 # [B,H,T,Dh] -> [B,T,D]22 bsz, _, seq_len, _ = x.shape23 return x.transpose(1, 2).contiguous().view(bsz, seq_len, self.n_heads * self.d_head)24 25 def forward(self, x, cos, sin, pos_ids, cache: LayerKVCache = None):26 qkv = self.qkv(x)27 q, k, v = qkv.chunk(3, dim=-1)28 29 q = self._shape(q)30 k = self._shape(k)31 v = self._shape(v)32 33 # 1) RoPE34 q = apply_rope(q, cos, sin, pos_ids)35 k = apply_rope(k, cos, sin, pos_ids)36 37 # 2) KV Cache(decode 时传 cache)38 if cache is not None:39 k, v = cache.append(k, v)40 is_causal = False # query 长度通常是 1,不需要再额外 causal mask41 else:42 is_causal = True # prefill / 训练43 44 # 3) FlashAttention / SDP45 out = sdp_attention(q, k, v, is_causal=is_causal, dropout_p=0.0)46 47 out = self._merge(out)48 out = self.o_proj(out)49 return out

5. Prefill vs Decode 的系统视角

阶段 主要特点 关键优化
Prefill(首轮) 一次性处理长 prompt,算力密集 FlashAttention、张量并行、算子融合
Decode(逐 token) 每步计算小但步数多,延迟敏感 KV Cache、调度优化、连续批处理

在实际服务中,常见 KPI:

  • TTFT(Time To First Token)
  • TPOT(Time Per Output Token)
  • 吞吐(tokens/s)

6. 常见坑位与排查

提示

先保证数值一致性(和 reference 实现对齐),再追求极限性能。

注意

大多数线上“突然变慢”问题,根因不是单个算子,而是 batch 调度、cache 碎片和内存带宽竞争。

建议的排查顺序:

  1. 正确性:RoPE 位置索引、cache 追加维度、mask 语义。
  2. 显存:KV Cache 占用是否超预算,是否需要 GQA / Paged KV。
  3. 内核路径:是否真的命中 flash kernel(dtype、shape、硬件条件)。
  4. 调度策略:prefill/decode 是否混跑导致尾延迟放大。
  5. 监控指标:TTFT、TPOT、GPU 利用率、HBM 带宽占用联动看。