场景

源句(英文): "The cat sat"
目标(德文): "Die Katze saß"   ← 模型不知道,要自己生成

编码器已经跑完,Z 已经算好:

Z    形状 (3, 512)    ← 3 个源词,固定不变,整个推理过程不再更新

第 0 步:准备开始

解码器输入只有一个开始符:

解码器输入: [<BOS>]

<BOS> = Begin of Sentence,告诉解码器"开始生成"。


第 1 步:生成第 1 个词

解码器拿 [<BOS>] 作输入,走完三个子层:

Masked Self-attention: [<BOS>] 只有一个词,只看自己
Cross-attention:       Q 来自 <BOS> 的向量,K/V 来自 Z,查询源句
FFN:                   非线性加工

最后进入 Linear + Softmax,得到词表上的概率分布:

Die      0.41   ← 最高
Der      0.28
Das      0.18
...

取概率最高的词:Die


第 2 步:生成第 2 个词

把 "Die" 加入输入序列:

解码器输入: [<BOS>, Die]

走完三个子层:

Masked Self-attention: 
  <BOS> 只看 <BOS>
  Die   只看 <BOS>, Die

Cross-attention:
  Q 来自 d[Die](代表"生成了 Die,下一个是什么")
  K/V 来自 Z

Linear + Softmax:

Katze    0.52   ← 最高
Hund     0.21
Maus     0.09
...

取最高:Katze


第 3 步:生成第 3 个词

解码器输入: [<BOS>, Die, Katze]
Cross-attention:
  Q 来自 d[Katze](代表"生成了 Die Katze,下一个是什么")
  K/V 来自 Z

Linear + Softmax:

saß      0.61   ← 最高
sitzt    0.18
lag      0.08
...

取最高:saß


第 4 步:生成结束符

解码器输入: [<BOS>, Die, Katze, saß]

Linear + Softmax:

<EOS>    0.74   ← 最高
auf      0.12
...

取最高: (End of Sentence)

看到 <EOS>,停止生成。


完整结果

生成序列: Die  Katze  saß

整个推理过程的规律

步骤 1: 输入 [<BOS>]              → 生成 Die
步骤 2: 输入 [<BOS>, Die]         → 生成 Katze
步骤 3: 输入 [<BOS>, Die, Katze]  → 生成 saß
步骤 4: 输入 [<BOS>, Die, Katze, saß] → 生成 <EOS> → 停止

每步输入序列比上一步多一个词,这就是自回归(auto-regressive) ——用自己已经生成的结果作为下一步的输入。


三个关键点

1. 编码器只跑一次

Z 在推理开始前算好,整个推理过程固定不变。解码器每一步都查询同一个 Z。

2. 解码器每步重新跑完整的三个子层

不是只算新加的那个词,而是把整个输入序列重新过一遍。但因为有 Masked Self-attention,每个位置只看前面的词,结果和上一步一致,不会重复计算错误。

3. 取最高概率不是唯一的策略

这里用的是 Greedy Search(贪心搜索)——每步取概率最高的词。实际中还有 Beam Search(束搜索)——每步保留概率最高的 k 个候选,最后取整体概率最高的序列,效果更好但更慢。


<BOS> 经过 Embedding + 位置编码之后就是 (1, 512)

我们把第 1 步的每个矩阵维度完整走一遍。


起点

解码器输入: [<BOS>]

Embedding + 位置编码:
<BOS> → (1, 512)

子层 1:Masked Self-attention

只有一个词,掩码没有实际效果(只有自己,看自己)。

输入 X_dec    形状 (1, 512)

投影:
Q = X_dec · W_Q    (1,512) · (512,64) = (1, 64)
K = X_dec · W_K    (1,512) · (512,64) = (1, 64)
V = X_dec · W_V    (1,512) · (512,64) = (1, 64)

点积打分:
Q · Kᵀ    (1,64) · (64,1) = (1, 1)    ← 1个词对自己的分数,就一个数

÷ √64,softmax:
权重      (1, 1)    ← 值为 1.0(只有自己,100%注意自己)

加权求和:
权重 · V    (1,1) · (1,64) = (1, 64)    ← 单头输出

8个头拼接:
(1, 64) × 8 → (1, 512)

· W_O:
(1,512) · (512,512) = (1, 512)

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出 d(1, 512)


子层 2:Cross-attention

Q 来自解码器,K 和 V 来自编码器的 Z。

d(解码器)    形状 (1, 512)
Z(编码器)    形状 (3, 512)    ← 源句 3 个词

投影:
Q = d · W_Q    (1,512) · (512,64) = (1, 64)    ← 解码器提问
K = Z · W_K    (3,512) · (512,64) = (3, 64)    ← 源句 3 个 Key
V = Z · W_V    (3,512) · (512,64) = (3, 64)    ← 源句 3 个 Value

点积打分:
Q · Kᵀ    (1,64) · (64,3) = (1, 3)
               ↑
               1个解码器位置 对 3个源词 各打一个分

÷ √64,softmax:
权重    (1, 3)    ← 3个数,加起来=1,代表对源句3个词的注意力分配

例:
[0.15, 0.65, 0.20]
  The   cat   sat

加权求和:
权重 · V    (1,3) · (3,64) = (1, 64)    ← 单头输出

8个头拼接:
(1, 64) × 8 → (1, 512)

· W_O:
(1,512) · (512,512) = (1, 512)

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出:(1, 512)


子层 3:FFN

输入    (1, 512)

· W₁:
(1,512) · (512,2048) = (1, 2048)    ← 升维

ReLU:
(1, 2048)                            ← 负值截断

· W₂:
(1,2048) · (2048,512) = (1, 512)    ← 降回

残差 + LayerNorm:
(1,512) + (1,512) → LayerNorm → (1, 512)

输出:(1, 512)


第 1 层解码器结束,传入第 2 层,重复 6 次。

第 6 层结束后进入输出层:

(1, 512) · Linear(512, 37000) = (1, 37000)

Softmax → (1, 37000)    ← 词表每个词的概率

取最大 → "Die"

整条维度变化总结

<BOS> embedding        (1, 512)
↓ Masked Self-attn
  Q,K,V 各            (1, 64)   × 8头
  打分矩阵            (1, 1)
  单头输出            (1, 64)
  拼接+W_O            (1, 512)
  残差+LN             (1, 512)
↓ Cross-attention
  Q                   (1, 64)   ← 来自解码器
  K,V                 (3, 64)   ← 来自编码器Z
  打分矩阵            (1, 3)    ← 1个位置对3个源词打分
  单头输出            (1, 64)
  拼接+W_O            (1, 512)
  残差+LN             (1, 512)
↓ FFN
  升维                (1, 2048)
  降维                (1, 512)
  残差+LN             (1, 512)
↓ × 6层
↓ Linear              (1, 37000)
↓ Softmax             (1, 37000)
↓ 取最大              "Die"

Cross-attention 里 (1, 3) 这个打分矩阵是最关键的地方——1 行代表解码器当前位置,3 列代表源句 3 个词。这一行 softmax 之后的 3 个数,就是"生成这个词时,应该参考源句哪里"的权重分配。