自注意力(Self-Attention)的计算过程

在深度学习中,很多 LLM 的训练都使用 Transformer 架构,而在 Transformer 架构中计算的过程涉及到的最关键的就是注意力,它是整个过程中重要的基础。注意力抽象出了 3 个重要的概念,在计算过程中对应着 3 个矩阵,如下所示:

  • Query:在自主提示下,自主提示的内容,对应着矩阵 Q
  • Keys:在非自主提示下,进入视觉系统的线索,对应着矩阵 K
  • Values:使用 Query 从 Keys 中匹配得到的线索,基于这些线索得到的进入视觉系统中焦点内容,对应着矩阵 V

我们要训练的模型,输入的句子有 n 个 token,而通过选择并使用某个 Embedding 模型获取到每个 token 的 Word Embedding,每个 Word Embedding 是一个 d 维向量。本文我们详细说明自注意力(Self-Attention)的计算过程,在进行解释说明之前,先定义一些标识符号以方便后面阐述使用:

  • X:输入训练数据的 Embedding 是一个 n x d 矩阵
  • Q:查询矩阵,矩阵形状 n x dq
  • K:键矩阵,矩阵形状 n x dk,其中 dk=dq
  • V:值矩阵,矩阵形状 n x dv

计算自注意力(Self-Attention)的基本流程,如下图所示:
Computing-Self-Attention
计算过程及其示例演示,描述如下:

1. 计算生成 Embedding 矩阵

X 是基于输入数据得到的 n x d 的 Embedding 矩阵:通过选择并使用某个 Embedding 模型,输入数据经过分词得到 n 个 token,再获取到每个 token 的 Embedding 向量,然后将这 n 个 Embedding 向量拼接成一个 n x d 的 Embedding 矩阵。
假设我们输入的是一个句子:

Life is short, less is more.

比如,要训练模型以便下游执行机器翻译任务,上面句子对应的 Embedding 是一个 6 x 16 的 Tensor,即 n=6,d=16,如下所示:

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465]])
torch.Size([6, 16])

2. 计算生成 Q、K、V 矩阵

定义 3 个矩阵 Wq、Wk、Wv 它们是权重矩阵,输入 X 与这 3 个矩阵的转置分别相乘,得到 Q、K、V。其中,Q 是 n x dq 矩阵,K 是 n x dk 矩阵,V 是 n x dv 矩阵。计算的公式如下所示(().T 表示对矩阵的转置操作):

  • Q = X · (Wq).T
  • K = X · (Wk).T
  • V = X · (Wv).T

Wq、Wk、Wv 是三个参数矩阵,通过在模型进行训练的过程中,经过学习会得到这 3 个矩阵对应的参数值。
在模型训练过程中,参数矩阵可以采用随机策略生成,这里我们随机生成 Wq(8 x 16 矩阵)、Wk(8 x 16 矩阵)、Wv(12 x 16 矩阵),使用 PyTorch 生成:

Wq = torch.rand(8, 16)
print(Wq.T)
print(Wq.T.shape)

Wk = torch.rand(8, 16)
print(Wk.T)
print(Wk.T.shape)

Wv = torch.rand(12, 16)
print(Wv.T)
print(Wv.T.shape)

从而可以分别得到这 3 个矩阵的转置的结果,如下所示:

tensor([[0.9882, 0.7577, 0.6397, 0.5757, 0.6042, 0.0206, 0.6886, 0.4270],
        [0.8363, 0.4536, 0.8954, 0.2785, 0.9836, 0.3247, 0.9024, 0.8210],
        [0.9010, 0.4130, 0.2979, 0.1946, 0.1444, 0.9355, 0.1123, 0.3605],
        [0.3950, 0.5585, 0.6314, 0.5382, 0.9010, 0.5855, 0.2685, 0.4516],
        [0.8809, 0.1170, 0.5028, 0.1291, 0.9221, 0.4695, 0.6591, 0.7056],
        [0.1084, 0.5578, 0.1239, 0.1242, 0.9043, 0.5201, 0.1735, 0.1853],
        [0.5432, 0.6681, 0.3786, 0.1746, 0.5713, 0.8118, 0.9247, 0.6339],
        [0.2185, 0.9275, 0.1661, 0.3302, 0.9546, 0.0585, 0.6166, 0.3894],
        [0.3834, 0.3443, 0.7211, 0.5370, 0.8339, 0.1142, 0.3608, 0.7398],
        [0.3720, 0.6800, 0.5449, 0.8443, 0.8730, 0.3338, 0.5325, 0.2288],
        [0.5374, 0.9998, 0.5490, 0.6937, 0.4675, 0.2122, 0.6559, 0.5185],
        [0.9551, 0.2855, 0.3483, 0.8831, 0.1163, 0.7579, 0.3232, 0.5489],
        [0.7475, 0.9753, 0.5024, 0.1861, 0.4938, 0.8533, 0.1126, 0.0977],
        [0.4979, 0.2518, 0.3445, 0.5422, 0.5938, 0.0149, 0.5034, 0.1364],
        [0.8549, 0.7204, 0.6437, 0.0556, 0.1594, 0.0757, 0.5091, 0.6918],
        [0.2438, 0.6959, 0.9856, 0.7868, 0.2132, 0.0131, 0.5101, 0.3545]])
torch.Size([16, 8])
tensor([[0.7969, 0.7674, 0.5840, 0.6569, 0.0975, 0.3534, 0.1829, 0.1836],
        [0.0061, 0.4058, 0.1227, 0.3704, 0.2956, 0.6638, 0.2956, 0.2010],
        [0.2528, 0.1548, 0.9587, 0.3630, 0.9027, 0.4563, 0.8646, 0.9603],
        [0.0882, 0.5201, 0.9914, 0.0578, 0.3112, 0.1091, 0.8010, 0.6861],
        [0.6997, 0.8773, 0.1547, 0.3629, 0.9167, 0.3069, 0.8044, 0.4209],
        [0.4855, 0.9577, 0.5185, 0.2974, 0.4139, 0.7274, 0.0733, 0.8046],
        [0.4067, 0.1226, 0.2337, 0.2275, 0.4362, 0.5164, 0.7355, 0.2621],
        [0.4168, 0.2742, 0.9794, 0.0484, 0.6996, 0.6845, 0.6248, 0.0638],
        [0.1092, 0.8893, 0.7788, 0.8916, 0.4265, 0.2073, 0.1638, 0.0036],
        [0.6418, 0.7444, 0.7945, 0.0532, 0.4958, 0.9727, 0.5158, 0.7032],
        [0.5125, 0.8095, 0.6613, 0.9964, 0.8463, 0.2913, 0.6000, 0.3051],
        [0.1549, 0.2511, 0.4502, 0.2377, 0.6671, 0.6066, 0.2299, 0.8070],
        [0.6881, 0.9308, 0.7815, 0.4616, 0.4801, 0.2557, 0.2890, 0.9271],
        [0.4900, 0.0890, 0.5085, 0.9079, 0.6904, 0.2588, 0.9078, 0.6647],
        [0.0164, 0.4759, 0.3176, 0.6650, 0.9355, 0.7239, 0.4596, 0.9296],
        [0.7690, 0.5104, 0.7582, 0.3573, 0.6260, 0.3604, 0.4947, 0.3848]])
torch.Size([16, 8])
tensor([[0.9357, 0.8048, 0.0204, 0.9148, 0.2590, 0.4970, 0.1098, 0.9920, 0.4196,
         0.5242, 0.5921, 0.8188],
        [0.2616, 0.0649, 0.8290, 0.1705, 0.7162, 0.3552, 0.6353, 0.4791, 0.0050,
         0.5153, 0.0056, 0.5773],
        [0.4344, 0.8322, 0.1063, 0.0943, 0.5689, 0.2576, 0.3719, 0.7945, 0.1368,
         0.5047, 0.5577, 0.7870],
        [0.8323, 0.3672, 0.2062, 0.8800, 0.8181, 0.7346, 0.0574, 0.9323, 0.8588,
         0.7175, 0.3350, 0.8855],
        [0.2410, 0.9012, 0.5058, 0.2614, 0.8286, 0.4564, 0.6951, 0.1144, 0.0121,
         0.3116, 0.4620, 0.9941],
        [0.8815, 0.8146, 0.6522, 0.5325, 0.5292, 0.4009, 0.6766, 0.8039, 0.2541,
         0.5315, 0.7872, 0.3705],
        [0.6226, 0.2077, 0.7905, 0.9981, 0.7914, 0.8474, 0.5674, 0.0651, 0.0475,
         0.5021, 0.3279, 0.5148],
        [0.4902, 0.4474, 0.4298, 0.3005, 0.1387, 0.1203, 0.8267, 0.3650, 0.7690,
         0.7111, 0.1213, 0.2103],
        [0.9279, 0.5746, 0.2427, 0.9657, 0.0221, 0.8265, 0.2993, 0.2984, 0.8418,
         0.1939, 0.5302, 0.9562],
        [0.8751, 0.6429, 0.4570, 0.8973, 0.0927, 0.9441, 0.9564, 0.0324, 0.5438,
         0.1091, 0.3608, 0.6591],
        [0.2943, 0.0369, 0.6638, 0.8862, 0.7759, 0.1928, 0.1189, 0.0290, 0.2486,
         0.0931, 0.2668, 0.4172],
        [0.5485, 0.5224, 0.2187, 0.6483, 0.9598, 0.0263, 0.9508, 0.0179, 0.3788,
         0.7101, 0.3473, 0.6253],
        [0.5583, 0.7605, 0.0657, 0.2746, 0.3617, 0.5696, 0.8715, 0.1132, 0.5291,
         0.8978, 0.2165, 0.9961],
        [0.9096, 0.7823, 0.7387, 0.8148, 0.7766, 0.1197, 0.0552, 0.2206, 0.7095,
         0.9959, 0.9389, 0.7036],
        [0.7810, 0.7459, 0.1691, 0.1575, 0.1427, 0.7091, 0.4556, 0.3352, 0.5086,
         0.6785, 0.5344, 0.7429],
        [0.9049, 0.5791, 0.2186, 0.2087, 0.4906, 0.1012, 0.2310, 0.7797, 0.2669,
         0.3981, 0.2346, 0.9616]])
torch.Size([16, 12])

接着,分别计算 Q、K、V 这三个矩阵,代码如下所示:

Q = X.matmul(Wq.T)
print(Q)
print(Q.shape)

K = X.matmul(Wk.T)
print(K)
print(K.shape)

V = X.matmul(Wv.T)
print(V)
print(V.shape)

计算得到 Q、K、V 三个矩阵的结果,如下所示:

tensor([[ 0.3018,  0.5771,  0.8287, -0.3224,  0.9979, -1.3807,  0.7953,  0.3018],
        [-2.4223, -1.3309,  0.0371,  0.4128, -0.2362, -1.7722, -0.9576, -0.6908],
        [-1.5899, -2.5706, -3.0113, -3.2927, -4.0568, -0.3453, -3.0388, -2.1831],
        [ 1.4786,  2.4595,  3.2942,  2.1628,  4.1394,  0.7536,  2.8714,  3.5802],
        [-2.4223, -1.3309,  0.0371,  0.4128, -0.2362, -1.7722, -0.9576, -0.6908],
        [-0.3082, -0.5900, -0.9257, -0.7688,  1.8828, -1.6065, -0.8011, -0.4114]])
torch.Size([6, 8])
tensor([[ 1.1813,  1.3559,  0.2174,  3.4764,  1.1683, -0.6475,  0.9448,  0.1573],
        [-1.9191, -0.0077, -1.5536, -0.7241, -1.9709, -1.2558, -1.5606, -1.4531],
        [-2.6464, -2.5228, -3.2055, -1.5096, -2.4377, -2.7335, -1.7701, -0.5160],
        [-0.2554,  2.9326,  1.8757,  1.9398, -0.1509,  1.9660, -0.5092, -0.2307],
        [-1.9191, -0.0077, -1.5536, -0.7241, -1.9709, -1.2558, -1.5606, -1.4531],
        [-0.6116,  1.3902, -0.1460,  0.0244, -0.5577,  1.5972, -2.2190, -0.0214]])
torch.Size([6, 8])
tensor([[ 1.5465,  1.5362,  1.3274,  0.9452,  0.5531,  0.1000, -1.5909,  0.8779,
          0.8645,  1.1643,  2.2148,  1.3088],
        [-1.0912, -2.8470, -0.4005,  0.6766, -1.7351,  1.0082, -1.1248, -3.2161,
          0.5959, -2.3485, -1.7592, -1.2618],
        [-3.7293, -2.7705, -2.1744, -2.7340, -1.0410, -1.8867, -4.0902, -0.3303,
         -3.1343, -2.4864, -1.1285, -3.5427],
        [ 4.1125,  0.5593,  2.1795,  4.3551,  1.7887,  3.0898,  1.5155,  3.2049,
          2.5387,  2.0061,  1.2701,  2.4616],
        [-1.0912, -2.8470, -0.4005,  0.6766, -1.7351,  1.0082, -1.1248, -3.2161,
          0.5959, -2.3485, -1.7592, -1.2618],
        [ 0.2002,  1.3752, -0.0809, -1.2746, -2.3948, -0.3425,  1.5967,  0.5399,
          0.9113,  0.0962,  0.7300, -1.0553]])
torch.Size([6, 12])

3. 计算注意力分数矩阵

上一步已经得到 Q、K,其中 Q 是 n x dq 矩阵,K 是 n x dk 矩阵,这里 dq = dk。
注意力分数矩阵(Attention Score Matrix)的计算,直接使用 Q 与 K 的转置相乘,得到的结果是一个 n x n 矩阵 Q(K).T。
使用 PyTorch 计算 Q·(K).T,代码如下:

QKT = Q.matmul(K.T)
print(QKT)
print(QKT.shape)

计算结果,如下所示:

tensor([[  3.0568,  -3.5501,  -4.6460,  -0.7954,  -3.5501,  -4.0441],
        [ -3.3648,   9.4917,  16.4977,  -5.2158,   9.4917,  -0.9232],
        [-25.1955,  26.4776,  42.6549, -17.1836,  26.4776,   6.2588],
        [ 20.9407, -28.3288, -43.0230,  15.7782, -28.3288,  -5.4665],
        [ -3.3648,   9.4917,  16.4977,  -5.2158,   9.4917,  -0.9232],
        [ -1.6198,   2.7452,   7.8636,  -7.8188,   2.7452,  -2.3448]])
torch.Size([6, 6])

4. 标准化注意力分数矩阵,并与 V 相乘得到自注意力结果

针对矩阵 QKT 进行标准化 Softmax 操作,得到归一化的结果矩阵 A,A 的大小为 n x n。
最后,n x dv 的矩阵 V,乘以 n x n 的矩阵 A,得到了最后的自注意力结果矩阵 Z。
计算 Z = AT(Q, K, V) 的计算公式如下图所示:
Attention-AT-QKV
前面计算已经得到 QKT 矩阵,n=6,dk=8,则 A 的大小也是 6 x 6。

dk = 16
A = F.softmax(QKT/(dk**0.50), dim=0)
print(A)

然后再乘以矩阵 V:

Z = A.matmul(V)
print(Z)
print(Z.shape)

结果示例如下所示:

tensor([[-2.9351e-02,  6.1129e-02,  5.4583e-02,  9.4816e-02,  4.9514e-02,
          1.2487e-01,  2.9251e-02,  6.3045e-03, -2.7838e-02,  2.7134e-02,
          1.2357e-01,  4.0422e-02],
        [-2.4056e-01,  3.3915e-02,  1.9047e-03,  1.1973e-01, -1.5679e-02,
          1.4004e-01, -9.8062e-02, -8.4042e-02, -1.5060e-01, -1.1663e-02,
          1.6010e-01, -2.3983e-02],
        [-7.0469e+00, -5.9786e+00, -6.3531e+00, -3.7648e+00, -1.4788e+00,
         -5.4479e+00, -6.4546e+00, -6.2927e+00, -2.4734e+00, -3.2765e+00,
         -5.1672e+00, -6.0859e+00],
        [ 3.8240e+00,  1.7311e+00,  2.8731e+00,  2.2339e+00,  4.9503e+00,
          2.7620e+00,  3.9712e+00,  1.2559e+00,  3.1700e+00,  1.4015e+00,
          1.8649e+00,  2.7100e+00],
        [-2.4056e-01,  3.3915e-02,  1.9047e-03,  1.1973e-01, -1.5679e-02,
          1.4004e-01, -9.8062e-02, -8.4042e-02, -1.5060e-01, -1.1663e-02,
          1.6010e-01, -2.3983e-02],
        [-1.2892e-01,  4.9015e-02,  2.3402e-02,  9.5212e-02, -1.8994e-02,
          1.2503e-01, -3.3910e-02, -1.9795e-02, -1.0338e-01,  1.0985e-02,
          1.5254e-01,  6.4172e-03]])
torch.Size([6, 12])

或者,我们可以直接使用 PyTorch 的 torch.nn.functional.scaled_dot_product_attention() 函数,从 Q、K、V 一步就可以计算得到,如下所示:

import torch.nn.functional as F

Z = F.scaled_dot_product_attention(Q, K, V)
print(Z)
print(Z.shape)

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>