Kirito's Blog · 文章

Transformer 详解:从注意力公式到最小可运行代码

2026-02-06 · 4 分钟 ·标签:Transformer深度学习NLPPyTorch

系统讲清 Transformer 的核心公式:缩放点积注意力、多头机制、位置编码、残差与归一化,并给出可运行的 PyTorch 最小实现。

Transformer 是现代大模型的基础。它的关键思想是: 用注意力替代循环结构,并通过多头机制在不同子空间并行建模关系。

Transformer 结构总览
Encoder-Decoder Transformer:输入嵌入 + 位置编码,经过多层注意力与前馈网络,再生成输出概率。

1. 输入表示与位置编码

对长度为 $T$ 的序列,令词嵌入矩阵为:

$$ X \in \mathbb{R}^{T \times d_{\text{model}}} $$

由于注意力本身与顺序无关,需要注入位置信息。经典正弦位置编码为:

$$ PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \quad PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$

最终输入:

$$ H^{(0)} = X + PE $$

2. 缩放点积注意力(Scaled Dot-Product Attention)

先由输入线性映射出查询、键、值:

$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V $$

其中:

$$ Q,K \in \mathbb{R}^{T\times d_k}, \quad V \in \mathbb{R}^{T\times d_v} $$

注意力权重:

$$ A = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) $$

输出:

$$ \operatorname{Attention}(Q,K,V) = AV $$

这里 $M$ 是 mask(例如 decoder 的因果 mask)。 除以 $\sqrt{d_k}$ 的原因是抑制内积随维度增大而导致的 softmax 饱和。

3. 多头注意力(Multi-Head Attention)

单头只能在一个子空间里建模关系,多头机制让模型并行关注不同模式:

$$ \text{head}_i = \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

拼接并投影:

$$ \operatorname{MHA}(Q,K,V) = \operatorname{Concat}(\text{head}_1,\ldots,\text{head}_h)W^O $$

常见维度关系:

$$ d_k = d_v = \frac{d_{\text{model}}}{h} $$

4. 前馈网络(Position-wise FFN)

每个 token 独立通过同一组两层 MLP:

$$ \operatorname{FFN}(x) = W_2\,\sigma(W_1x + b_1) + b_2 $$

常用激活是 ReLU / GELU。

5. 残差连接与层归一化

一个典型的 encoder 子层写作(Pre-LN 变体):

$$ \tilde{H} = H + \operatorname{MHA}(\operatorname{LN}(H)) $$

$$ H' = \tilde{H} + \operatorname{FFN}(\operatorname{LN}(\tilde{H})) $$

这种结构有助于深层训练稳定。

6. Decoder 与自回归约束

Decoder 比 Encoder 多了两点:

  1. Masked Self-Attention:只能看见当前位置及之前 token;
  2. Cross-Attention:Query 来自 decoder,Key/Value 来自 encoder 输出。

训练目标通常是下一个 token 的交叉熵:

$$ \mathcal{L}_{\mathrm{CE}} = -\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{1:t-1}, x) $$

可配合 label smoothing 改善泛化。

7. 复杂度直觉

设序列长度为 $T$、模型维度 $d$:

  • 自注意力核心矩阵乘法复杂度约为 $O(T^2 d)$;
  • RNN 每步依赖前一步,难并行;
  • Transformer 训练时并行度高,但长序列时 $T^2$ 成本明显。
模块 主要计算 复杂度量级
$QK^{\top}$ 注意力打分 $O(T^2 d)$
$\operatorname{softmax}(QK^{\top})V$ 加权聚合 $O(T^2 d)$
$\operatorname{FFN}(x)$ 逐位置 MLP $O(T d d_{ff})$

8. PyTorch 最小实现(可运行)

下面是一份教学向最小代码(单文件可跑),包含:

  • 多头注意力
  • Encoder Block
  • 简化语言建模训练循环
Python
1import math2import torch3import torch.nn as nn4import torch.nn.functional as F5 6 7class MultiHeadSelfAttention(nn.Module):8 def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):9 super().__init__()10 assert d_model % n_heads == 011 self.d_model = d_model12 self.n_heads = n_heads13 self.d_head = d_model // n_heads14 15 self.w_q = nn.Linear(d_model, d_model)16 self.w_k = nn.Linear(d_model, d_model)17 self.w_v = nn.Linear(d_model, d_model)18 self.w_o = nn.Linear(d_model, d_model)19 self.dropout = nn.Dropout(dropout)20 21 def _split_heads(self, x):22 # x: [B, T, d_model] -> [B, h, T, d_head]23 bsz, seq_len, _ = x.size()24 x = x.view(bsz, seq_len, self.n_heads, self.d_head)25 return x.transpose(1, 2)26 27 def _merge_heads(self, x):28 # x: [B, h, T, d_head] -> [B, T, d_model]29 bsz, _, seq_len, _ = x.size()30 x = x.transpose(1, 2).contiguous()31 return x.view(bsz, seq_len, self.d_model)32 33 def forward(self, x, causal_mask: bool = False):34 # x: [B, T, d_model]35 q = self._split_heads(self.w_q(x))36 k = self._split_heads(self.w_k(x))37 v = self._split_heads(self.w_v(x))38 39 # scores: [B, h, T, T]40 scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)41 42 if causal_mask:43 t = scores.size(-1)44 mask = torch.triu(torch.ones(t, t, device=scores.device), diagonal=1).bool()45 scores = scores.masked_fill(mask, float('-inf'))46 47 attn = F.softmax(scores, dim=-1)48 attn = self.dropout(attn)49 50 out = torch.matmul(attn, v) # [B, h, T, d_head]51 out = self._merge_heads(out) # [B, T, d_model]52 out = self.w_o(out)53 return out54 55 56class FeedForward(nn.Module):57 def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):58 super().__init__()59 self.net = nn.Sequential(60 nn.Linear(d_model, d_ff),61 nn.GELU(),62 nn.Dropout(dropout),63 nn.Linear(d_ff, d_model),64 nn.Dropout(dropout),65 )66 67 def forward(self, x):68 return self.net(x)69 70 71class EncoderBlock(nn.Module):72 # Pre-LN73 def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):74 super().__init__()75 self.ln1 = nn.LayerNorm(d_model)76 self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)77 self.ln2 = nn.LayerNorm(d_model)78 self.ffn = FeedForward(d_model, d_ff, dropout)79 80 def forward(self, x):81 x = x + self.attn(self.ln1(x), causal_mask=False)82 x = x + self.ffn(self.ln2(x))83 return x84 85 86class TinyTransformerLM(nn.Module):87 def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,88 d_ff: int = 1024, n_layers: int = 4, max_len: int = 512):89 super().__init__()90 self.token_emb = nn.Embedding(vocab_size, d_model)91 self.pos_emb = nn.Embedding(max_len, d_model)92 self.blocks = nn.ModuleList([93 EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)94 ])95 self.ln_f = nn.LayerNorm(d_model)96 self.head = nn.Linear(d_model, vocab_size, bias=False)97 98 def forward(self, idx):99 # idx: [B, T]100 bsz, seq_len = idx.size()101 pos = torch.arange(seq_len, device=idx.device).unsqueeze(0)102 x = self.token_emb(idx) + self.pos_emb(pos)103 104 # 语言模型场景:这里用 causal mask 的自注意力更合理105 for block in self.blocks:106 x = x + block.attn(block.ln1(x), causal_mask=True)107 x = x + block.ffn(block.ln2(x))108 109 x = self.ln_f(x)110 logits = self.head(x) # [B, T, V]111 return logits112 113 114def train_step(model, optimizer, batch_x, batch_y):115 """116 batch_x: [B, T] 输入 token117 batch_y: [B, T] 目标 token(右移后的真值)118 """119 model.train()120 logits = model(batch_x) # [B, T, V]121 122 bsz, seq_len, vocab_size = logits.shape123 loss = F.cross_entropy(124 logits.view(bsz * seq_len, vocab_size),125 batch_y.view(bsz * seq_len)126 )127 128 optimizer.zero_grad()129 loss.backward()130 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)131 optimizer.step()132 133 return loss.item()134 135 136if __name__ == '__main__':137 torch.manual_seed(42)138 139 vocab_size = 5000140 model = TinyTransformerLM(vocab_size=vocab_size)141 optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)142 143 # toy batch144 batch_x = torch.randint(0, vocab_size, (8, 64))145 batch_y = torch.randint(0, vocab_size, (8, 64))146 147 for step in range(1, 6):148 loss = train_step(model, optimizer, batch_x, batch_y)149 print(f'step={step}, loss={loss:.4f}')

9. 训练与调参建议

提示

先确保 loss 能稳定下降,再谈更复杂技巧(RoPE、FlashAttention、MoE 等)。

注意

如果训练初期 loss 波动很大,优先检查:学习率是否过高、warmup 是否缺失、梯度裁剪是否生效。

建议优先排查这 5 项:

  1. 学习率 + warmup(最常见问题源)
  2. attention mask 是否正确
  3. LayerNorm 放置位置(Pre-LN 更稳)
  4. mixed precision 下的数值稳定性
  5. tokenization 与数据清洗质量