05. Multi-Head Attention 与 FFN:多头并行与知识存储
05. Multi-Head Attention 与 FFN:多头并行与知识存储
上一章我们学了单个注意力的完整计算。但一个注意力只能学一种”关注模式”——Multi-Head Attention 让模型同时学多种模式。FFN 则是模型存储知识的地方。
一、为什么需要多个”头”
1.1 单头的局限
上一章的注意力只有一组 W_Q、W_K、W_V,只能学到一种关注模式。
但语言中有多种不同类型的关系需要同时捕捉:
"小明昨天在北京的餐厅吃了一顿很好吃的火锅"
需要同时关注:
1. 指代关系: "小明" ← 谁在吃?
2. 时间关系: "昨天" ← 什么时候?
3. 地点关系: "北京的餐厅" ← 在哪里?
4. 修饰关系: "很好吃的" → "火锅"
5. 语法关系: "吃了" → 动宾搭配
一个注意力头只能学其中一种模式。要同时捕捉所有关系,需要多个头。
1.2 Multi-Head Attention 的做法
把嵌入维度 d 拆成 h 份,每份交给一个”头”独立计算:
d = 768, h = 12 头
每个头的维度 d_k = 768 / 12 = 64
每个头有自己独立的 W_Q、W_K、W_V:
对于头 i (i = 1, 2, ..., 12):
W_Q_i 形状: [768, 64]
W_K_i 形状: [768, 64]
W_V_i 形状: [768, 64]
Q_i = X @ W_Q_i → 形状 [seq_len, 64]
K_i = X @ W_K_i → 形状 [seq_len, 64]
V_i = X @ W_V_i → 形状 [seq_len, 64]
head_i = Attention(Q_i, K_i, V_i) → 形状 [seq_len, 64]
1.3 拼接与投影
所有头的输出拼接起来,再做一次线性投影:
concat = [head_1 ; head_2 ; ... ; head_12]
→ 形状 [seq_len, 64×12] = [seq_len, 768]
output = concat @ W_O
→ 形状 [seq_len, 768] (W_O 形状: [768, 768])
W_O 是输出投影矩阵,把 12 个头的信息融合在一起。
1.4 实际中 12 个头学到了什么
研究发现,不同的头确实学到了不同的关注模式。以 GPT-2 为例:
头 1: 主要关注前一个 token(bigram 模式)
头 3: 关注句子开头的 token
头 5: 关注动词(捕捉主谓关系)
头 8: 关注逗号和句号(句法结构)
头 11: 长距离依赖(指代消解)
...
这不是设计出来的——是训练过程中自然分化出来的。
1.5 参数量
实际实现中,12 个头的 W_Q_i 合并成一个大矩阵:
W_Q: [768, 768] = 589,824 参数
W_K: [768, 768] = 589,824 参数
W_V: [768, 768] = 589,824 参数
W_O: [768, 768] = 589,824 参数
一个注意力层总计: 768 × 768 × 4 = 2,359,296 ≈ 2.36M 参数
注意总参数量和单头一样多——Multi-Head 没有增加参数量,只是把同样数量的参数分配给不同的头去学习不同的模式。
二、GQA:分组查询注意力
2.1 标准 MHA 的显存问题
在推理时,每个头都需要缓存 K 和 V(KV Cache)。12 个头就是 12 份。模型越大、头数越多,KV Cache 占用的显存越大。
标准 MHA (Multi-Head Attention):
12 个 Q 头, 12 个 K 头, 12 个 V 头
KV Cache: 12 份 K + 12 份 V = 24 份缓存
2.2 GQA 的思路
GQA(Grouped Query Attention) 让多个 Q 头共享同一组 K 和 V:
GQA (4 个 KV 组, 12 个 Q 头):
Q 头 1,2,3 → 共享 KV 组 1
Q 头 4,5,6 → 共享 KV 组 2
Q 头 7,8,9 → 共享 KV 组 3
Q 头 10,11,12 → 共享 KV 组 4
KV Cache: 4 份 K + 4 份 V = 8 份缓存 (减少 2/3!)
LLaMA 3、GLM-4 等现代模型都使用 GQA。
2.3 极端情况:MQA
MQA(Multi-Query Attention) 是 GQA 的极端——所有 Q 头共享一组 KV:
MQA:
Q 头 1-12 → 全部共享 1 组 KV
KV Cache: 1 份 K + 1 份 V = 2 份缓存
更省显存,但效果略差。GQA 是折中方案。
三、FFN:前馈网络与知识存储
3.1 结构
每个 Transformer 层除了注意力,还有一个前馈网络(Feed-Forward Network)。
结构很简单——两个线性层夹一个激活函数:
FFN(x) = W_2 · activation(W_1 · x + b_1) + b_2
维度变化:
输入: x ∈ R^d (768 维)
第一层: W_1 ∈ R^(d × 4d) (768 × 3072)
→ 升维到 3072 维("展开")
激活函数: GELU
第二层: W_2 ∈ R^(4d × d) (3072 × 768)
→ 降回 768 维("压缩")
3.2 数值走一遍
输入 x = [0.47, 0.50, 0.59, 0.29] (简化为 d=4, 4d=16)
第一层:
z = x @ W_1 + b_1
→ 16 维向量, 例如: [1.2, -0.5, 0.8, 0.0, -1.1, 0.3, ...]
激活:
h = GELU(z)
→ [0.98, -0.15, 0.63, 0.0, -0.13, 0.22, ...]
(正数基本不变,负数被大幅压缩)
第二层:
output = h @ W_2 + b_2
→ 4 维向量, 例如: [0.35, 0.62, 0.41, 0.28]
3.3 “先展开再压缩”的直觉
为什么要先升维到 4d 再降回 d?
768 维 → 3072 维 → 768 维
把它想象成一个检索过程:
-
展开到 3072 维:在更高维度的空间中,输入可以匹配到更多的”知识模式”。W_1 的每一行就是一个”模式检测器”。
-
激活函数:GELU 决定哪些模式被”激活”(正值保留,负值压缩)。
-
压缩回 768 维:W_2 的每一列是被激活模式对应的”输出模板”。把激活的模式组合成最终输出。
3.4 FFN 是”知识库”
2021 年的重要论文(Geva et al., “Transformer Feed-Forward Layers Are Key-Value Memories”)发现:
W_1 的每一行 = 一个 "key"(检测特定输入模式)
W_2 的对应列 = 一个 "value"(对应的输出模式)
输入匹配某个 key → 激活对应的 value → 注入到隐藏状态
举例:
W_1 的第 1729 行可能检测"法国的首都"这个模式
W_2 的第 1729 列可能编码了"巴黎"相关的信息
当输入是"法国的首都是___"时:
→ x @ W_1[1729] 得到高激活值
→ GELU 后保留这个激活
→ 乘以 W_2[:, 1729] 把"巴黎"的信息注入输出
所以模型的知识(事实、关系、规律)主要存储在 FFN 中,注意力更像是路由器(决定关注哪些信息)。
3.5 参数量
一个 FFN 层:
W_1: [768, 3072] = 2,359,296
b_1: [3072] = 3,072
W_2: [3072, 768] = 2,359,296
b_2: [768] = 768
总计: ≈ 4.72M
FFN 的参数量(4.72M)是注意力(2.36M)的两倍! 模型的大部分参数都在 FFN 中。
一个 Transformer 层的参数分布:
Attention: 2.36M (33%)
FFN: 4.72M (67%)
直觉: 1/3 的参数用来"路由"信息,2/3 的参数用来"存储"知识
四、SwiGLU:现代 FFN 的改进
4.1 门控机制
LLaMA、GLM 等现代模型不用标准的 “W_1 + GELU + W_2”,而是用 SwiGLU:
标准 FFN:
output = W_2 · GELU(W_1 · x)
SwiGLU FFN:
output = W_2 · (SiLU(W_gate · x) ⊙ (W_up · x))
其中 ⊙ 是逐元素乘法,SiLU(x) = x · sigmoid(x)。
多了一个”门控”矩阵 W_gate,让模型能更精细地控制哪些信息通过。
4.2 维度调整
因为多了一个矩阵,为了保持总参数量不变,中间维度从 4d 变为 (8/3)d:
标准 FFN: W_1 [d, 4d] + W_2 [4d, d] = 8d² 参数
SwiGLU FFN: W_gate [d, 8d/3] + W_up [d, 8d/3] + W_down [8d/3, d] = 8d² 参数
实验表明 SwiGLU 在相同参数量下效果更好。
五、Attention + FFN 的分工
一个 Transformer 层的信息流:
输入 x
│
├──→ Attention ──→ "从上下文中提取相关信息"
│ (决定关注哪些 token)
│
├──→ FFN ──→ "基于提取的信息,检索/注入知识"
│ (从参数中读取事实和规律)
│
└──→ 输出 x'
类比:
- Attention = 在会议中决定”听谁说话”
- FFN = 根据听到的内容,结合自己的知识库,给出判断
这两个组件交替堆叠(12 层、96 层…),逐步构建出对输入的深层理解。
本章总结
- Multi-Head Attention 把嵌入维度拆成多份,每个头独立学习不同的关注模式
- 头数越多不增加参数量——只是把参数分给不同的头
- GQA 让多个 Q 头共享 KV,减少推理时的显存占用
- FFN 是两层线性变换(升维 → 激活 → 降维),本质是知识存储
- FFN 占模型参数量的 2/3,是模型存储事实知识的主要地方
- 现代模型用 SwiGLU 替代标准 FFN,效果更好
- Attention 做”路由”,FFN 做”知识检索”,两者互补
下一篇:06. 完整 Transformer 架构:残差、LayerNorm 与层堆叠 — 把所有组件组装在一起