AI Research

05. Multi-Head Attention 与 FFN:多头并行与知识存储

05. Multi-Head Attention 与 FFN:多头并行与知识存储

上一章我们学了单个注意力的完整计算。但一个注意力只能学一种”关注模式”——Multi-Head Attention 让模型同时学多种模式。FFN 则是模型存储知识的地方。


一、为什么需要多个”头”

1.1 单头的局限

上一章的注意力只有一组 W_Q、W_K、W_V,只能学到一种关注模式。

但语言中有多种不同类型的关系需要同时捕捉:

"小明昨天在北京的餐厅吃了一顿很好吃的火锅"

需要同时关注:
  1. 指代关系: "小明" ← 谁在吃?
  2. 时间关系: "昨天" ← 什么时候?
  3. 地点关系: "北京的餐厅" ← 在哪里?
  4. 修饰关系: "很好吃的" → "火锅"
  5. 语法关系: "吃了" → 动宾搭配

一个注意力头只能学其中一种模式。要同时捕捉所有关系,需要多个头

1.2 Multi-Head Attention 的做法

把嵌入维度 d 拆成 h 份,每份交给一个”头”独立计算:

d = 768,  h = 12 头
每个头的维度 d_k = 768 / 12 = 64

每个头有自己独立的 W_Q、W_K、W_V:

对于头 i (i = 1, 2, ..., 12):
  W_Q_i 形状: [768, 64]
  W_K_i 形状: [768, 64]
  W_V_i 形状: [768, 64]

  Q_i = X @ W_Q_i    → 形状 [seq_len, 64]
  K_i = X @ W_K_i    → 形状 [seq_len, 64]
  V_i = X @ W_V_i    → 形状 [seq_len, 64]

  head_i = Attention(Q_i, K_i, V_i)    → 形状 [seq_len, 64]

1.3 拼接与投影

所有头的输出拼接起来,再做一次线性投影:

concat = [head_1 ; head_2 ; ... ; head_12]
       → 形状 [seq_len, 64×12] = [seq_len, 768]

output = concat @ W_O
       → 形状 [seq_len, 768]    (W_O 形状: [768, 768])

W_O 是输出投影矩阵,把 12 个头的信息融合在一起。

1.4 实际中 12 个头学到了什么

研究发现,不同的头确实学到了不同的关注模式。以 GPT-2 为例:

头 1: 主要关注前一个 token(bigram 模式)
头 3: 关注句子开头的 token
头 5: 关注动词(捕捉主谓关系)
头 8: 关注逗号和句号(句法结构)
头 11: 长距离依赖(指代消解)
...

这不是设计出来的——是训练过程中自然分化出来的。

1.5 参数量

实际实现中,12 个头的 W_Q_i 合并成一个大矩阵:
  W_Q: [768, 768] = 589,824 参数
  W_K: [768, 768] = 589,824 参数
  W_V: [768, 768] = 589,824 参数
  W_O: [768, 768] = 589,824 参数

一个注意力层总计: 768 × 768 × 4 = 2,359,296 ≈ 2.36M 参数

注意总参数量和单头一样多——Multi-Head 没有增加参数量,只是把同样数量的参数分配给不同的头去学习不同的模式。


二、GQA:分组查询注意力

2.1 标准 MHA 的显存问题

在推理时,每个头都需要缓存 K 和 V(KV Cache)。12 个头就是 12 份。模型越大、头数越多,KV Cache 占用的显存越大。

标准 MHA (Multi-Head Attention):
  12 个 Q 头, 12 个 K 头, 12 个 V 头
  KV Cache: 12 份 K + 12 份 V = 24 份缓存

2.2 GQA 的思路

GQA(Grouped Query Attention) 让多个 Q 头共享同一组 K 和 V:

GQA (4 个 KV 组, 12 个 Q 头):
  Q 头 1,2,3   → 共享 KV 组 1
  Q 头 4,5,6   → 共享 KV 组 2
  Q 头 7,8,9   → 共享 KV 组 3
  Q 头 10,11,12 → 共享 KV 组 4

  KV Cache: 4 份 K + 4 份 V = 8 份缓存 (减少 2/3!)

LLaMA 3、GLM-4 等现代模型都使用 GQA。

2.3 极端情况:MQA

MQA(Multi-Query Attention) 是 GQA 的极端——所有 Q 头共享一组 KV:

MQA:
  Q 头 1-12 → 全部共享 1 组 KV
  KV Cache: 1 份 K + 1 份 V = 2 份缓存

更省显存,但效果略差。GQA 是折中方案。


三、FFN:前馈网络与知识存储

3.1 结构

每个 Transformer 层除了注意力,还有一个前馈网络(Feed-Forward Network)

结构很简单——两个线性层夹一个激活函数:

FFN(x) = W_2 · activation(W_1 · x + b_1) + b_2

维度变化:

输入:    x ∈ R^d           (768 维)
第一层:  W_1 ∈ R^(d × 4d)  (768 × 3072)
         → 升维到 3072 维("展开")

激活函数: GELU

第二层:  W_2 ∈ R^(4d × d)  (3072 × 768)
         → 降回 768 维("压缩")

3.2 数值走一遍

输入 x = [0.47, 0.50, 0.59, 0.29]   (简化为 d=4, 4d=16)

第一层:
  z = x @ W_1 + b_1
  → 16 维向量, 例如: [1.2, -0.5, 0.8, 0.0, -1.1, 0.3, ...]

激活:
  h = GELU(z)
  → [0.98, -0.15, 0.63, 0.0, -0.13, 0.22, ...]
  (正数基本不变,负数被大幅压缩)

第二层:
  output = h @ W_2 + b_2
  → 4 维向量, 例如: [0.35, 0.62, 0.41, 0.28]

3.3 “先展开再压缩”的直觉

为什么要先升维到 4d 再降回 d?

768 维 → 3072 维 → 768 维

把它想象成一个检索过程

  1. 展开到 3072 维:在更高维度的空间中,输入可以匹配到更多的”知识模式”。W_1 的每一行就是一个”模式检测器”。

  2. 激活函数:GELU 决定哪些模式被”激活”(正值保留,负值压缩)。

  3. 压缩回 768 维:W_2 的每一列是被激活模式对应的”输出模板”。把激活的模式组合成最终输出。

3.4 FFN 是”知识库”

2021 年的重要论文(Geva et al., “Transformer Feed-Forward Layers Are Key-Value Memories”)发现:

W_1 的每一行 = 一个 "key"(检测特定输入模式)
W_2 的对应列 = 一个 "value"(对应的输出模式)

输入匹配某个 key → 激活对应的 value → 注入到隐藏状态

举例:

W_1 的第 1729 行可能检测"法国的首都"这个模式
W_2 的第 1729 列可能编码了"巴黎"相关的信息

当输入是"法国的首都是___"时:
  → x @ W_1[1729] 得到高激活值
  → GELU 后保留这个激活
  → 乘以 W_2[:, 1729] 把"巴黎"的信息注入输出

所以模型的知识(事实、关系、规律)主要存储在 FFN 中,注意力更像是路由器(决定关注哪些信息)。

3.5 参数量

一个 FFN 层:
  W_1: [768, 3072]  = 2,359,296
  b_1: [3072]       = 3,072
  W_2: [3072, 768]  = 2,359,296
  b_2: [768]        = 768
  总计: ≈ 4.72M

FFN 的参数量(4.72M)是注意力(2.36M)的两倍! 模型的大部分参数都在 FFN 中。

一个 Transformer 层的参数分布:
  Attention: 2.36M (33%)
  FFN:       4.72M (67%)

直觉: 1/3 的参数用来"路由"信息,2/3 的参数用来"存储"知识

四、SwiGLU:现代 FFN 的改进

4.1 门控机制

LLaMA、GLM 等现代模型不用标准的 “W_1 + GELU + W_2”,而是用 SwiGLU

标准 FFN:
  output = W_2 · GELU(W_1 · x)

SwiGLU FFN:
  output = W_2 · (SiLU(W_gate · x) ⊙ (W_up · x))

其中 是逐元素乘法,SiLU(x) = x · sigmoid(x)。

多了一个”门控”矩阵 W_gate,让模型能更精细地控制哪些信息通过。

4.2 维度调整

因为多了一个矩阵,为了保持总参数量不变,中间维度从 4d 变为 (8/3)d:

标准 FFN:   W_1 [d, 4d] + W_2 [4d, d]       = 8d² 参数
SwiGLU FFN: W_gate [d, 8d/3] + W_up [d, 8d/3] + W_down [8d/3, d] = 8d² 参数

实验表明 SwiGLU 在相同参数量下效果更好。


五、Attention + FFN 的分工

一个 Transformer 层的信息流:

输入 x

  ├──→ Attention ──→ "从上下文中提取相关信息"
  │                   (决定关注哪些 token)

  ├──→ FFN ──→ "基于提取的信息,检索/注入知识"
  │             (从参数中读取事实和规律)

  └──→ 输出 x'

类比:

  • Attention = 在会议中决定”听谁说话”
  • FFN = 根据听到的内容,结合自己的知识库,给出判断

这两个组件交替堆叠(12 层、96 层…),逐步构建出对输入的深层理解。


本章总结

  1. Multi-Head Attention 把嵌入维度拆成多份,每个头独立学习不同的关注模式
  2. 头数越多不增加参数量——只是把参数分给不同的头
  3. GQA 让多个 Q 头共享 KV,减少推理时的显存占用
  4. FFN 是两层线性变换(升维 → 激活 → 降维),本质是知识存储
  5. FFN 占模型参数量的 2/3,是模型存储事实知识的主要地方
  6. 现代模型用 SwiGLU 替代标准 FFN,效果更好
  7. Attention 做”路由”,FFN 做”知识检索”,两者互补

下一篇:06. 完整 Transformer 架构:残差、LayerNorm 与层堆叠 — 把所有组件组装在一起