In deep learning, most LLMs are trained on the Transformer architecture, and the most critical computation inside a Transformer is attention, which underpins the whole process. Attention abstracts three key concepts, which correspond to three matrices in the computation:
- Query: the volitional cue, i.e., the content of the active query; corresponds to matrix Q
- Keys: the nonvolitional cues entering the sensory (visual) system; correspond to matrix K
- Values: the focal content brought into the sensory system based on the cues matched from the Keys by the Query; corresponds to matrix V
For the model we want to train, an input sentence consists of n tokens, and a chosen embedding model provides the word embedding of each token, where each word embedding is a d-dimensional vector. This article walks through the computation of self-attention (Self-Attention) in detail. Before the explanation, we first define some notation to be used throughout:
- X: the embedding of the input data, an n x d matrix
- Q: the query matrix, of shape n x dq
- K: the key matrix, of shape n x dk, where dk = dq
- V: the value matrix, of shape n x dv
The basic flow of computing self-attention (Self-Attention) is illustrated in the figure below.
The computation steps, together with a worked example, are described as follows:
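All code snippets in the steps below share a common setup. A minimal sketch is shown here; the seed value 123 is an assumption made so that repeated runs are reproducible, and the exact random numbers in the printed outputs depend on it:
import torch
import torch.nn.functional as F

# Fix the random seed so that the randomly generated tensors below are reproducible
torch.manual_seed(123)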
1. Build the embedding matrix
X is the n x d embedding matrix derived from the input data: with a chosen embedding model, the input is tokenized into n tokens, the embedding vector of each token is looked up, and the n embedding vectors are stacked into an n x d embedding matrix.
Suppose the input is the sentence:
Life is short, less is more.
For example, when training a model for a downstream machine-translation task, the embedding of the sentence above is a 6 x 16 tensor, i.e., n = 6 and d = 16, as shown below:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,  0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,  0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472, -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,  2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,  0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,  0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465]])
torch.Size([6, 16])
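As a concrete sketch of how such an embedding matrix can be produced, the snippet below builds a toy vocabulary and looks the tokens up in a randomly initialized torch.nn.Embedding table. The tokenization rule and the embedding initialization are assumptions of this sketch, so the exact values will generally differ from the tensor above; note, however, that the two occurrences of "is" map to the same index, which is why two rows of the 6 x 16 tensor are identical.
sentence = "Life is short, less is more."

# Toy tokenization: drop punctuation and split on whitespace
tokens = sentence.replace(",", "").replace(".", "").split()
vocab = {tok: idx for idx, tok in enumerate(sorted(set(tokens)))}
token_ids = torch.tensor([vocab[tok] for tok in tokens])             # shape: (6,)

# Randomly initialized embedding table: one 16-dimensional vector per vocabulary entry
embedding = torch.nn.Embedding(num_embeddings=len(vocab), embedding_dim=16)
X = embedding(token_ids).detach()                                    # shape: (6, 16), i.e. n x d
print(X)
print(X.shape)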
2. Compute the Q, K, and V matrices
Define three weight matrices Wq, Wk, and Wv. Multiplying the input X by the transpose of each of them yields Q, K, and V, where Q is an n x dq matrix, K is an n x dk matrix, and V is an n x dv matrix. The formulas are as follows ((·).T denotes the transpose of a matrix):
- Q = X · (Wq).T
- K = X · (Wk).T
- V = X · (Wv).T
Wq, Wk, and Wv are parameter matrices; their values are learned during model training.
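In an actual model these projections are usually implemented as bias-free linear layers whose weights play exactly the role of Wq, Wk, and Wv. The following is a minimal sketch (the layer names and the output names Q_lin, K_lin, V_lin are illustrative assumptions):
d, dq, dk, dv = 16, 8, 8, 12

# nn.Linear stores its weight with shape (out_features, in_features), so W_query.weight is dq x d,
# and W_query(X) computes X @ W_query.weight.T, which is exactly Q = X · (Wq).T from above.
W_query = torch.nn.Linear(d, dq, bias=False)
W_key   = torch.nn.Linear(d, dk, bias=False)
W_value = torch.nn.Linear(d, dv, bias=False)

Q_lin, K_lin, V_lin = W_query(X), W_key(X), W_value(X)   # shapes: (n, dq), (n, dk), (n, dv)
For the walkthrough below we keep working with explicit weight matrices instead.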
At the start of training, the parameter matrices can be initialized randomly. Here we randomly generate Wq (an 8 x 16 matrix), Wk (an 8 x 16 matrix), and Wv (a 12 x 16 matrix) with PyTorch:
Wq = torch.rand(8, 16)
print(Wq.T)
print(Wq.T.shape)
Wk = torch.rand(8, 16)
print(Wk.T)
print(Wk.T.shape)
Wv = torch.rand(12, 16)
print(Wv.T)
print(Wv.T.shape)
This gives the transposes of the three matrices, as shown below:
tensor([[0.9882, 0.7577, 0.6397, 0.5757, 0.6042, 0.0206, 0.6886, 0.4270], [0.8363, 0.4536, 0.8954, 0.2785, 0.9836, 0.3247, 0.9024, 0.8210], [0.9010, 0.4130, 0.2979, 0.1946, 0.1444, 0.9355, 0.1123, 0.3605], [0.3950, 0.5585, 0.6314, 0.5382, 0.9010, 0.5855, 0.2685, 0.4516], [0.8809, 0.1170, 0.5028, 0.1291, 0.9221, 0.4695, 0.6591, 0.7056], [0.1084, 0.5578, 0.1239, 0.1242, 0.9043, 0.5201, 0.1735, 0.1853], [0.5432, 0.6681, 0.3786, 0.1746, 0.5713, 0.8118, 0.9247, 0.6339], [0.2185, 0.9275, 0.1661, 0.3302, 0.9546, 0.0585, 0.6166, 0.3894], [0.3834, 0.3443, 0.7211, 0.5370, 0.8339, 0.1142, 0.3608, 0.7398], [0.3720, 0.6800, 0.5449, 0.8443, 0.8730, 0.3338, 0.5325, 0.2288], [0.5374, 0.9998, 0.5490, 0.6937, 0.4675, 0.2122, 0.6559, 0.5185], [0.9551, 0.2855, 0.3483, 0.8831, 0.1163, 0.7579, 0.3232, 0.5489], [0.7475, 0.9753, 0.5024, 0.1861, 0.4938, 0.8533, 0.1126, 0.0977], [0.4979, 0.2518, 0.3445, 0.5422, 0.5938, 0.0149, 0.5034, 0.1364], [0.8549, 0.7204, 0.6437, 0.0556, 0.1594, 0.0757, 0.5091, 0.6918], [0.2438, 0.6959, 0.9856, 0.7868, 0.2132, 0.0131, 0.5101, 0.3545]]) torch.Size([16, 8]) tensor([[0.7969, 0.7674, 0.5840, 0.6569, 0.0975, 0.3534, 0.1829, 0.1836], [0.0061, 0.4058, 0.1227, 0.3704, 0.2956, 0.6638, 0.2956, 0.2010], [0.2528, 0.1548, 0.9587, 0.3630, 0.9027, 0.4563, 0.8646, 0.9603], [0.0882, 0.5201, 0.9914, 0.0578, 0.3112, 0.1091, 0.8010, 0.6861], [0.6997, 0.8773, 0.1547, 0.3629, 0.9167, 0.3069, 0.8044, 0.4209], [0.4855, 0.9577, 0.5185, 0.2974, 0.4139, 0.7274, 0.0733, 0.8046], [0.4067, 0.1226, 0.2337, 0.2275, 0.4362, 0.5164, 0.7355, 0.2621], [0.4168, 0.2742, 0.9794, 0.0484, 0.6996, 0.6845, 0.6248, 0.0638], [0.1092, 0.8893, 0.7788, 0.8916, 0.4265, 0.2073, 0.1638, 0.0036], [0.6418, 0.7444, 0.7945, 0.0532, 0.4958, 0.9727, 0.5158, 0.7032], [0.5125, 0.8095, 0.6613, 0.9964, 0.8463, 0.2913, 0.6000, 0.3051], [0.1549, 0.2511, 0.4502, 0.2377, 0.6671, 0.6066, 0.2299, 0.8070], [0.6881, 0.9308, 0.7815, 0.4616, 0.4801, 0.2557, 0.2890, 0.9271], [0.4900, 0.0890, 0.5085, 0.9079, 0.6904, 0.2588, 0.9078, 0.6647], [0.0164, 0.4759, 0.3176, 0.6650, 0.9355, 0.7239, 0.4596, 0.9296], [0.7690, 0.5104, 0.7582, 0.3573, 0.6260, 0.3604, 0.4947, 0.3848]]) torch.Size([16, 8]) tensor([[0.9357, 0.8048, 0.0204, 0.9148, 0.2590, 0.4970, 0.1098, 0.9920, 0.4196, 0.5242, 0.5921, 0.8188], [0.2616, 0.0649, 0.8290, 0.1705, 0.7162, 0.3552, 0.6353, 0.4791, 0.0050, 0.5153, 0.0056, 0.5773], [0.4344, 0.8322, 0.1063, 0.0943, 0.5689, 0.2576, 0.3719, 0.7945, 0.1368, 0.5047, 0.5577, 0.7870], [0.8323, 0.3672, 0.2062, 0.8800, 0.8181, 0.7346, 0.0574, 0.9323, 0.8588, 0.7175, 0.3350, 0.8855], [0.2410, 0.9012, 0.5058, 0.2614, 0.8286, 0.4564, 0.6951, 0.1144, 0.0121, 0.3116, 0.4620, 0.9941], [0.8815, 0.8146, 0.6522, 0.5325, 0.5292, 0.4009, 0.6766, 0.8039, 0.2541, 0.5315, 0.7872, 0.3705], [0.6226, 0.2077, 0.7905, 0.9981, 0.7914, 0.8474, 0.5674, 0.0651, 0.0475, 0.5021, 0.3279, 0.5148], [0.4902, 0.4474, 0.4298, 0.3005, 0.1387, 0.1203, 0.8267, 0.3650, 0.7690, 0.7111, 0.1213, 0.2103], [0.9279, 0.5746, 0.2427, 0.9657, 0.0221, 0.8265, 0.2993, 0.2984, 0.8418, 0.1939, 0.5302, 0.9562], [0.8751, 0.6429, 0.4570, 0.8973, 0.0927, 0.9441, 0.9564, 0.0324, 0.5438, 0.1091, 0.3608, 0.6591], [0.2943, 0.0369, 0.6638, 0.8862, 0.7759, 0.1928, 0.1189, 0.0290, 0.2486, 0.0931, 0.2668, 0.4172], [0.5485, 0.5224, 0.2187, 0.6483, 0.9598, 0.0263, 0.9508, 0.0179, 0.3788, 0.7101, 0.3473, 0.6253], [0.5583, 0.7605, 0.0657, 0.2746, 0.3617, 0.5696, 0.8715, 0.1132, 0.5291, 0.8978, 0.2165, 0.9961], [0.9096, 0.7823, 0.7387, 0.8148, 0.7766, 0.1197, 0.0552, 0.2206, 0.7095, 0.9959, 0.9389, 0.7036], 
[0.7810, 0.7459, 0.1691, 0.1575, 0.1427, 0.7091, 0.4556, 0.3352, 0.5086, 0.6785, 0.5344, 0.7429], [0.9049, 0.5791, 0.2186, 0.2087, 0.4906, 0.1012, 0.2310, 0.7797, 0.2669, 0.3981, 0.2346, 0.9616]]) torch.Size([16, 12])
Next, compute the Q, K, and V matrices with the following code:
Q = X.matmul(Wq.T)
print(Q)
print(Q.shape)
K = X.matmul(Wk.T)
print(K)
print(K.shape)
V = X.matmul(Wv.T)
print(V)
print(V.shape)
The resulting Q, K, and V matrices are shown below:
tensor([[ 0.3018,  0.5771,  0.8287, -0.3224,  0.9979, -1.3807,  0.7953,  0.3018],
        [-2.4223, -1.3309,  0.0371,  0.4128, -0.2362, -1.7722, -0.9576, -0.6908],
        [-1.5899, -2.5706, -3.0113, -3.2927, -4.0568, -0.3453, -3.0388, -2.1831],
        [ 1.4786,  2.4595,  3.2942,  2.1628,  4.1394,  0.7536,  2.8714,  3.5802],
        [-2.4223, -1.3309,  0.0371,  0.4128, -0.2362, -1.7722, -0.9576, -0.6908],
        [-0.3082, -0.5900, -0.9257, -0.7688,  1.8828, -1.6065, -0.8011, -0.4114]])
torch.Size([6, 8])
tensor([[ 1.1813,  1.3559,  0.2174,  3.4764,  1.1683, -0.6475,  0.9448,  0.1573],
        [-1.9191, -0.0077, -1.5536, -0.7241, -1.9709, -1.2558, -1.5606, -1.4531],
        [-2.6464, -2.5228, -3.2055, -1.5096, -2.4377, -2.7335, -1.7701, -0.5160],
        [-0.2554,  2.9326,  1.8757,  1.9398, -0.1509,  1.9660, -0.5092, -0.2307],
        [-1.9191, -0.0077, -1.5536, -0.7241, -1.9709, -1.2558, -1.5606, -1.4531],
        [-0.6116,  1.3902, -0.1460,  0.0244, -0.5577,  1.5972, -2.2190, -0.0214]])
torch.Size([6, 8])
tensor([[ 1.5465,  1.5362,  1.3274,  0.9452,  0.5531,  0.1000, -1.5909,  0.8779,  0.8645,  1.1643,  2.2148,  1.3088],
        [-1.0912, -2.8470, -0.4005,  0.6766, -1.7351,  1.0082, -1.1248, -3.2161,  0.5959, -2.3485, -1.7592, -1.2618],
        [-3.7293, -2.7705, -2.1744, -2.7340, -1.0410, -1.8867, -4.0902, -0.3303, -3.1343, -2.4864, -1.1285, -3.5427],
        [ 4.1125,  0.5593,  2.1795,  4.3551,  1.7887,  3.0898,  1.5155,  3.2049,  2.5387,  2.0061,  1.2701,  2.4616],
        [-1.0912, -2.8470, -0.4005,  0.6766, -1.7351,  1.0082, -1.1248, -3.2161,  0.5959, -2.3485, -1.7592, -1.2618],
        [ 0.2002,  1.3752, -0.0809, -1.2746, -2.3948, -0.3425,  1.5967,  0.5399,  0.9113,  0.0962,  0.7300, -1.0553]])
torch.Size([6, 12])
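The matrix form above is just the per-token projection applied to every row of X at once: the i-th row of Q is the projection of the i-th token embedding. A quick check (a sketch using the tensors computed above):
# Row i of Q equals the projection of token embedding i: q_i = Wq · x_i
print(torch.allclose(Q[0], Wq.matmul(X[0])))   # expected: True
print(Q.shape, K.shape, V.shape)               # torch.Size([6, 8]) torch.Size([6, 8]) torch.Size([6, 12])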
3. Compute the attention score matrix
The previous step produced Q and K, where Q is an n x dq matrix and K is an n x dk matrix, with dq = dk.
The attention score matrix (Attention Score Matrix) is obtained by multiplying Q with the transpose of K; the result, Q·(K).T, is an n x n matrix.
Computing Q·(K).T with PyTorch:
QKT = Q.matmul(K.T)
print(QKT)
print(QKT.shape)
The result is as follows:
tensor([[  3.0568,  -3.5501,  -4.6460,  -0.7954,  -3.5501,  -4.0441],
        [ -3.3648,   9.4917,  16.4977,  -5.2158,   9.4917,  -0.9232],
        [-25.1955,  26.4776,  42.6549, -17.1836,  26.4776,   6.2588],
        [ 20.9407, -28.3288, -43.0230,  15.7782, -28.3288,  -5.4665],
        [ -3.3648,   9.4917,  16.4977,  -5.2158,   9.4917,  -0.9232],
        [ -1.6198,   2.7452,   7.8636,  -7.8188,   2.7452,  -2.3448]])
torch.Size([6, 6])
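Each entry of this score matrix is the dot product of one query vector with one key vector: QKT[i][j] = q_i · k_j measures how strongly token i attends to token j before normalization. A quick check of a single entry (a sketch):
i, j = 1, 2
print(QKT[i, j])               # 16.4977, taken from the matrix above
print(torch.dot(Q[i], K[j]))   # the same value, computed directly as a dot product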
4. Normalize the attention score matrix and multiply by V to obtain the self-attention output
Apply a softmax to the QKT matrix (scaled by sqrt(dk), and applied row by row) to obtain the normalized attention weight matrix A, of size n x n.
Finally, multiply the n x n matrix A by the n x dv matrix V to obtain the final self-attention result matrix Z, of shape n x dv.
The formula for computing Z = Attention(Q, K, V) is:
Z = Attention(Q, K, V) = softmax(Q·(K).T / sqrt(dk)) · V
The QKT matrix was computed above; with n = 6 and dk = 8, A is likewise a 6 x 6 matrix.
import torch.nn.functional as F

dk = K.shape[-1]  # dk = 8, the dimension of the key vectors
A = F.softmax(QKT / dk**0.5, dim=-1)  # softmax over the last dimension: each row of A sums to 1
print(A)
Then multiply by the matrix V:
Z = A.matmul(V)
print(Z)
print(Z.shape)
The result looks like the following:
tensor([[-2.9351e-02,  6.1129e-02,  5.4583e-02,  9.4816e-02,  4.9514e-02,  1.2487e-01,  2.9251e-02,  6.3045e-03, -2.7838e-02,  2.7134e-02,  1.2357e-01,  4.0422e-02],
        [-2.4056e-01,  3.3915e-02,  1.9047e-03,  1.1973e-01, -1.5679e-02,  1.4004e-01, -9.8062e-02, -8.4042e-02, -1.5060e-01, -1.1663e-02,  1.6010e-01, -2.3983e-02],
        [-7.0469e+00, -5.9786e+00, -6.3531e+00, -3.7648e+00, -1.4788e+00, -5.4479e+00, -6.4546e+00, -6.2927e+00, -2.4734e+00, -3.2765e+00, -5.1672e+00, -6.0859e+00],
        [ 3.8240e+00,  1.7311e+00,  2.8731e+00,  2.2339e+00,  4.9503e+00,  2.7620e+00,  3.9712e+00,  1.2559e+00,  3.1700e+00,  1.4015e+00,  1.8649e+00,  2.7100e+00],
        [-2.4056e-01,  3.3915e-02,  1.9047e-03,  1.1973e-01, -1.5679e-02,  1.4004e-01, -9.8062e-02, -8.4042e-02, -1.5060e-01, -1.1663e-02,  1.6010e-01, -2.3983e-02],
        [-1.2892e-01,  4.9015e-02,  2.3402e-02,  9.5212e-02, -1.8994e-02,  1.2503e-01, -3.3910e-02, -1.9795e-02, -1.0338e-01,  1.0985e-02,  1.5254e-01,  6.4172e-03]])
torch.Size([6, 12])
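Putting the four steps together, the whole computation can be wrapped in a single function. The following is a minimal sketch (the function name self_attention is an assumption of this sketch; it takes dk from the last dimension of K):
def self_attention(X, Wq, Wk, Wv):
    """Single-head self-attention over the n rows of X (an n x d matrix)."""
    Q = X.matmul(Wq.T)                               # n x dq
    K = X.matmul(Wk.T)                               # n x dk
    V = X.matmul(Wv.T)                               # n x dv
    dk = K.shape[-1]
    A = F.softmax(Q.matmul(K.T) / dk**0.5, dim=-1)   # n x n attention weights, rows sum to 1
    return A.matmul(V)                               # n x dv

print(self_attention(X, Wq, Wk, Wv).shape)           # torch.Size([6, 12])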
Alternatively, we can use PyTorch's torch.nn.functional.scaled_dot_product_attention() function to compute the result from Q, K, and V in a single step, as shown below:
import torch.nn.functional as F

Z = F.scaled_dot_product_attention(Q, K, V)
print(Z)
print(Z.shape)
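As a sanity check (a sketch, assuming the manual computation uses the row-wise softmax and dk = K.shape[-1] as above), the step-by-step result and the built-in function should agree up to floating-point tolerance:
A = F.softmax(Q.matmul(K.T) / K.shape[-1]**0.5, dim=-1)
Z_manual = A.matmul(V)
Z_builtin = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(Z_manual, Z_builtin, atol=1e-6))   # expected: True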
References
- Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch
- https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html