Transformer Explained: From the Attention Formula to Minimal Runnable Code
A systematic walkthrough of the Transformer's core formulas: scaled dot-product attention, the multi-head mechanism, positional encoding, residuals and normalization, together with a minimal runnable PyTorch implementation.
The Transformer is the foundation of modern large models. Its key idea: replace recurrence with attention, and use multiple heads to model relationships in different subspaces in parallel.
1. Input Representation and Positional Encoding
For a sequence of length $T$, let the token-embedding matrix be:
$$ X \in \mathbb{R}^{T \times d_{\text{model}}} $$
Since attention itself is order-invariant, positional information has to be injected. The classic sinusoidal positional encoding is:
$$ PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \quad PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
The final input is:
$$ H^{(0)} = X + PE $$
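As a quick sketch, the whole encoding table can be precomputed; the helper name `sinusoidal_pe` is ours, not a library function, and $d_{\text{model}}$ is assumed even:

```python
import math
import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    """Precompute the sinusoidal table PE, shape [max_len, d_model] (d_model assumed even)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # [max_len, 1]
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)        # 2i = 0, 2, 4, ...
    div = torch.exp(-math.log(10000.0) * two_i / d_model)           # 1 / 10000^(2i/d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)                              # PE(pos, 2i)
    pe[:, 1::2] = torch.cos(pos * div)                              # PE(pos, 2i+1)
    return pe

pe = sinusoidal_pe(512, 256)
# H0 = X + PE: add pe[:T] to the embeddings; it broadcasts over the batch dimension
```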
2. Scaled Dot-Product Attention
First, linearly project the input into queries, keys, and values:
$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V $$
where:
$$ Q,K \in \mathbb{R}^{T\times d_k}, \quad V \in \mathbb{R}^{T\times d_v} $$
The attention weights are:
$$ A = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) $$
and the output is:
$$ \operatorname{Attention}(Q,K,V) = AV $$
Here $M$ is a mask (for example, the causal mask in a decoder). Dividing by $\sqrt{d_k}$ keeps the dot products from growing with dimension and saturating the softmax.
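The formula translates almost line for line into code. A minimal sketch (the `attention` helper is our own; the additive mask $M$ is expressed as a boolean `masked_fill`, which is equivalent, and PyTorch 2.x also ships an official `F.scaled_dot_product_attention`):

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    """q, k: [..., T, d_k]; v: [..., T, d_v]; mask: bool, True = blocked, broadcasts to [..., T, T]."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # QK^T / sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))       # M: -inf on blocked positions
    return F.softmax(scores, dim=-1) @ v                       # A V

# causal mask: True above the diagonal, i.e. future positions are blocked
T = 5
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
out = attention(torch.randn(2, T, 64), torch.randn(2, T, 64), torch.randn(2, T, 64), causal)
```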
3. Multi-Head Attention
A single head can only model relationships within one subspace; the multi-head mechanism lets the model attend to different patterns in parallel:
$$ \text{head}_i = \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
Concatenate the heads and project:
$$ \operatorname{MHA}(Q,K,V) = \operatorname{Concat}(\text{head}_1,\ldots,\text{head}_h)W^O $$
The common choice of dimensions is:
$$ d_k = d_v = \frac{d_{\text{model}}}{h} $$
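Concretely, with $d_{\text{model}} = 512$ and $h = 8$, each head works in a 64-dimensional subspace. A short shape walkthrough of the split/merge reshapes (values are illustrative):

```python
import torch

B, T, d_model, h = 2, 16, 512, 8
d_head = d_model // h                              # d_k = d_v = 512 / 8 = 64

x = torch.randn(B, T, d_model)
# split: one [B, T, 512] tensor becomes h parallel [B, T, 64] subspaces
heads = x.view(B, T, h, d_head).transpose(1, 2)    # [B, h, T, d_head]
# merge: the inverse reshape, applied just before the final W^O projection
merged = heads.transpose(1, 2).contiguous().view(B, T, d_model)
assert torch.equal(x, merged)                      # split/merge is lossless
```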
4. Position-wise Feed-Forward Network (FFN)
Each token passes independently through the same two-layer MLP:
$$ \operatorname{FFN}(x) = W_2\,\sigma(W_1x + b_1) + b_2 $$
Common choices for the activation $\sigma$ are ReLU / GELU.
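A minimal sketch of why this is called position-wise: the same weights apply at every position, so permuting tokens commutes with the FFN (shapes here are illustrative):

```python
import torch
import torch.nn as nn

d_model, d_ff = 256, 1024
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 10, d_model)   # [B, T, d_model]
y = ffn(x)                        # applied along the last dim: same weights at every position
assert y.shape == x.shape

# "position-wise": permuting tokens before the FFN equals permuting after it
perm = torch.randperm(10)
assert torch.allclose(ffn(x[:, perm]), y[:, perm], atol=1e-6)
```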
5. Residual Connections and Layer Normalization
A typical encoder sublayer (Pre-LN variant) is written as:
$$ \tilde{H} = H + \operatorname{MHA}(\operatorname{LN}(H)) $$
$$ H' = \tilde{H} + \operatorname{FFN}(\operatorname{LN}(\tilde{H})) $$
This structure helps keep training stable as the stack gets deep.
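For contrast, a minimal sketch of the two LayerNorm placements; the Post-LN form is the one in the original paper, and the helper names are ours:

```python
import torch.nn as nn

ln = nn.LayerNorm(256)

def pre_ln_sublayer(h, sublayer):
    # Pre-LN: H' = H + Sublayer(LN(H)) -- the residual path stays a pure identity
    return h + sublayer(ln(h))

def post_ln_sublayer(h, sublayer):
    # Post-LN (original paper): H' = LN(H + Sublayer(H))
    return ln(h + sublayer(h))
```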
6. Decoder 与自回归约束
Decoder 比 Encoder 多了两点:
- Masked Self-Attention:只能看见当前位置及之前 token;
- Cross-Attention:Query 来自 decoder,Key/Value 来自 encoder 输出。
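The code in section 8 covers self-attention only; a minimal single-head cross-attention sketch (the function and weight names are ours) differs only in where $K$ and $V$ come from:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_attention(dec_h, enc_out, w_q, w_k, w_v):
    """Single-head sketch: Q from the decoder, K/V from the encoder output."""
    q = w_q(dec_h)                                             # [B, T_dec, d_k]
    k, v = w_k(enc_out), w_v(enc_out)                          # [B, T_enc, d_k]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # [B, T_dec, T_enc]
    return F.softmax(scores, dim=-1) @ v                       # no causal mask needed here

d = 64
w_q, w_k, w_v = (nn.Linear(d, d) for _ in range(3))
out = cross_attention(torch.randn(2, 7, d), torch.randn(2, 12, d), w_q, w_k, w_v)
assert out.shape == (2, 7, d)   # one output row per decoder position
```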
The training objective is usually next-token cross-entropy:
$$ \mathcal{L}_{\mathrm{CE}} = -\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{1:t-1}, x) $$
Label smoothing is often added on top to improve generalization.
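In PyTorch this is a one-argument change: `F.cross_entropy` has accepted a `label_smoothing` parameter since version 1.10. A sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8 * 64, 5000)           # [B*T, V], as in train_step below
targets = torch.randint(0, 5000, (8 * 64,))  # [B*T]

# epsilon = 0.1 spreads 10% of the target mass uniformly over the vocabulary
loss = F.cross_entropy(logits, targets, label_smoothing=0.1)
```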
7. Complexity Intuition
With sequence length $T$ and model dimension $d$:
- the core matrix multiplications in self-attention cost roughly $O(T^2 d)$;
- an RNN depends on the previous step at every step, so it is hard to parallelize;
- a Transformer trains with high parallelism, but the $T^2$ cost becomes significant for long sequences.
| Module | Main computation | Complexity |
|---|---|---|
| $QK^{\top}$ | attention scores | $O(T^2 d)$ |
| $\operatorname{softmax}(QK^{\top})V$ | weighted aggregation | $O(T^2 d)$ |
| $\operatorname{FFN}(x)$ | position-wise MLP | $O(T d d_{ff})$ |
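A back-of-the-envelope sketch with illustrative numbers (multiply-add counts, ignoring constant factors):

```python
T, d, d_ff = 2048, 512, 2048

attn_flops = T * T * d     # QK^T score matrix: ~2.1e9 multiply-adds
ffn_flops = T * d * d_ff   # one FFN matmul: also ~2.1e9 at these sizes
print(f"attention ~{attn_flops:.1e}, ffn ~{ffn_flops:.1e}")

# quadrupling T (2048 -> 8192) multiplies the attention term by 16, the FFN term by 4
print((4 * T) ** 2 * d / attn_flops, (4 * T) * d * d_ff / ffn_flops)  # 16.0 4.0
```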
8. Minimal PyTorch Implementation (Runnable)
Below is a minimal, teaching-oriented implementation (one file, runs as-is) containing:
- multi-head attention
- an encoder block
- a simplified language-model training loop
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        # x: [B, T, d_model] -> [B, h, T, d_head]
        bsz, seq_len, _ = x.size()
        x = x.view(bsz, seq_len, self.n_heads, self.d_head)
        return x.transpose(1, 2)

    def _merge_heads(self, x):
        # x: [B, h, T, d_head] -> [B, T, d_model]
        bsz, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(bsz, seq_len, self.d_model)

    def forward(self, x, causal_mask: bool = False):
        # x: [B, T, d_model]
        q = self._split_heads(self.w_q(x))
        k = self._split_heads(self.w_k(x))
        v = self._split_heads(self.w_v(x))

        # scores: [B, h, T, T]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if causal_mask:
            t = scores.size(-1)
            mask = torch.triu(torch.ones(t, t, device=scores.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)   # [B, h, T, d_head]
        out = self._merge_heads(out)  # [B, T, d_model]
        out = self.w_o(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class EncoderBlock(nn.Module):
    # Pre-LN
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x), causal_mask=False)
        x = x + self.ffn(self.ln2(x))
        return x


class TinyTransformerLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
                 d_ff: int = 1024, n_layers: int = 4, max_len: int = 512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        # idx: [B, T]
        bsz, seq_len = idx.size()
        pos = torch.arange(seq_len, device=idx.device).unsqueeze(0)
        x = self.token_emb(idx) + self.pos_emb(pos)

        # Language-model setting: causal self-attention is the right choice here,
        # so we call the sublayers directly instead of block(x)
        for block in self.blocks:
            x = x + block.attn(block.ln1(x), causal_mask=True)
            x = x + block.ffn(block.ln2(x))

        x = self.ln_f(x)
        logits = self.head(x)  # [B, T, V]
        return logits


def train_step(model, optimizer, batch_x, batch_y):
    """
    batch_x: [B, T] input tokens
    batch_y: [B, T] target tokens (the ground-truth sequence shifted by one position)
    """
    model.train()
    logits = model(batch_x)  # [B, T, V]

    bsz, seq_len, vocab_size = logits.shape
    loss = F.cross_entropy(
        logits.view(bsz * seq_len, vocab_size),
        batch_y.view(bsz * seq_len)
    )

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()


if __name__ == '__main__':
    torch.manual_seed(42)

    vocab_size = 5000
    model = TinyTransformerLM(vocab_size=vocab_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    # toy batch
    batch_x = torch.randint(0, vocab_size, (8, 64))
    batch_y = torch.randint(0, vocab_size, (8, 64))

    for step in range(1, 6):
        loss = train_step(model, optimizer, batch_x, batch_y)
        print(f'step={step}, loss={loss:.4f}')
```
9. Training and Tuning Tips
Tip: make sure the loss can decrease steadily before reaching for fancier techniques (RoPE, FlashAttention, MoE, and so on).
Note: if the loss fluctuates heavily early in training, check first whether the learning rate is too high, warmup is missing, or gradient clipping is actually taking effect.
Five things to check first:
- learning rate + warmup (the most common source of problems; see the scheduler sketch after this list)
- whether the attention mask is correct
- LayerNorm placement (Pre-LN is more stable)
- numerical stability under mixed precision
- tokenization and data-cleaning quality
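For the first item, a minimal sketch of linear warmup followed by cosine decay via `torch.optim.lr_scheduler.LambdaLR` (step counts and the tiny stand-in model are illustrative):

```python
import math
import torch

def warmup_cosine(step, warmup_steps=1000, total_steps=100_000):
    """Multiplier on the base lr: linear ramp-up, then cosine decay toward 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

model = torch.nn.Linear(10, 10)  # stand-in for TinyTransformerLM
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)

# in the training loop, call scheduler.step() once after each optimizer.step()
```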