llms-from-scratch-cn/Translated_Book/ch03/3.6.ipynb
2024-05-16 22:22:49 +08:00

445 lines
24 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.6 将单头注意力扩展到多头注意力\n",
"\n",
"在本章的最后一节中,我们将在多个头的基础上扩展之前实现的 CausalAttention 类。这被称为多头注意力机制Multi-head Attention。\n",
"\n",
"“多头”这一术语指的是将注意力机制分为多个“头”,每个头独立运作。在这种情况下,单个因果注意力模块可以被视为单头注意力,其中只有一组注意力权重顺序处理输入。\n",
"\n",
"在以下小节中,我们将从因果注意力扩展到多头注意力。第一小节将直观地通过堆叠多个 CausalAttention 模块来构建 Multi-head Attention 模块,用于示例说明。第二小节将以更复杂但计算上更高效的方式实现相同的多头注意力模块。\n",
"\n",
"## 3.6.1 堆叠多个 Single-head Attention 层\n",
"\n",
"在实际操作中,实现多头注意力机制需要创建多个自注意力机制的实例(如第 3.4.1 节图 3.18 所示),其中每个实例都有自己的权重,然后合并这些示例的输出。尽管使用多个自注意力机制实例计算量很大,但这对于像 Transformer 基础的大语言模型所需的复杂模式识别至关重要。\n",
"\n",
"图 3.24 展示了 Multi-head Attention 模块的结构,它由(如图 3.18 所示的)多个 Single-head Attention 模块堆叠而成。\n",
"\n",
"**图 3.24 这张图中的多头注意力模块由两个单头注意力模块堆叠在一起。因此,在一个具有两个头的多头注意力模块中,我们不再使用单个矩阵 Wv 来计算值矩阵而是使用两个值权重矩阵Wv1 和 Wv2 。同样地Wq 和 Wk 也各自有两组权重矩阵。我们得到两组上下文向量 Z1 和 Z2 ,然后将它们组合成一个上下文向量矩阵 Z 。**\n",
"\n",
"![3.24](../img/fig-3-24.jpg)\n",
"\n",
"如前所述,多头注意力的主要思想是通过不同的、学习到的线性投影,多次(并行地)运行注意力机制————即将输入数据(如注意力机制中的查询、键和值向量)与权重矩阵相乘。\n",
"\n",
"在代码中,我们可以通过实现一个简单的 MultiHeadAttentionWrapper 类来实现这一点,该类堆叠了我们之前实现的多个 CausalAttention 模块实例:\n",
"\n",
"### 清单 3.4 实现 MultiHeadAttentionWrapper 类\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"from torch import nn\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, context_length,\n",
" dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)\n",
" for _ in range(num_heads)]\n",
" )\n",
" \n",
" def forward(self, x):\n",
" return torch.cat([head(x) for head in self.heads], dim=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"例如,如果我们使用这个 MultiHeadAttentionWrapper 类并通过 num_heads=2 设置两个注意力头,且将 CausalAttention 的输出维度设置为2d_out=2这将导致一个四维的上下文向量 (d_out*num_heads=4),如图 3.25 所示。\n",
"\n",
"**图 3.25 使用 MultiHeadAttentionWrapper 我们指定了注意力头的数量num_heads。如果我们设置 num_heads=2如图所示我们将得到一个包含两组上下文向量矩阵的张量。在每个上下文向量矩阵中行表示对应于 Token 的上下文向量,列对应于通过 d_out=4 指定的嵌入维度。我们沿列维度连接这些上下文向量矩阵。由于我们有 2 个注意力头和嵌入维度为 2最终的嵌入维度为 2 × 2 = 4。**\n",
"\n",
"![3.25](../img/fig-3-25.jpg)\n",
"\n",
"为了进一步说明图 3.25,我们可以像之前使用 CausalAttention 类那样使用 MultiHeadAttentionWrapper 类:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"trusted": true
},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'batch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [6]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m123\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m context_length \u001b[38;5;241m=\u001b[39m \u001b[43mbatch\u001b[49m\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;66;03m# This is the number of tokens\u001b[39;00m\n\u001b[1;32m 3\u001b[0m d_in, d_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 4\u001b[0m mha \u001b[38;5;241m=\u001b[39m MultiHeadAttentionWrapper(d_in, d_out, context_length, \u001b[38;5;241m0.0\u001b[39m, num_heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'batch' is not defined"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"context_length = batch.shape[1] # This is the number of tokens\n",
"d_in, d_out = 3, 2\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"context_vecs = mha(batch)\n",
" \n",
"print(context_vecs)\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"得到以下张量表示上下文向量:\n",
"```python\n",
"tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],\n",
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
" [-0.5299, -0.1081, 0.5077, 0.3493]],\n",
" \n",
" [[-0.4519, 0.2216, 0.4772, 0.1063],\n",
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
" [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)\n",
"context_vecs.shape: torch.Size([2, 6, 4])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"第一个维度的上下文向量张量为 2因为我们有两个输入文本输入文本是复制的这就是为什么这些上下文向量对于它们来说完全相同。第二维度指的是每个输入中的 6 个 Token。第三维度指的是每个 Token 的四维嵌入。\n",
"\n",
"### 练习 3.2 返回二维嵌入向量\n",
"\n",
"更改 MultiHeadAttentionWrapper (..., num_heads=2) 调用的输入参数,使输出上下文向量为二维而不是四维,同时保持 num_heads=2 的设置。提示:你不需要修改类的实现、你只需要更改其他一个输入参数。\n",
"\n",
"在本节中,我们实现了 MultiHeadAttentionWrapper ,它结合了多个 Single-head Attention 模块。请注意,这些要在 forward 方法中 [head(x) for head in self.heads] 顺序处理。我们可以通过并行处理头来改进这个实现。一种实现这点的方法是通过矩阵乘法同时计算所有注意力头的输出,我们将在下一节中探讨这一点。\n",
"\n",
"\n",
"## 3.6.2 通过权重分割实现多头注意力\n",
"\n",
"在前一节中,我们创建了一个 MultiHeadAttentionWrapper 来通过堆叠多个 Single-head Attention 模块实现多头注意力。这是通过实例化并组合几个 CausalAttention 对象完成的。\n",
"\n",
"我们可以将 MultiHeadAttentionWrapper 和 CausalAttention 这两个概念合并成一个单一的 MultiHeadAttentionWrapper 类,而不是同时保有两个单独的类。此外,除了将 MultiHeadAttentionWrapper 与 CausalAttention 的代码合并之外,我们还将进行一些其他修改以更有效地实现多头注意力机制。\n",
"\n",
"在 MultiHeadAttentionWrapper 中,通过创建一系列 CausalAttention 对象self.heads来实现多个头每个头代表一个单独的注意力头。 CausalAttention 类独立执行注意力机制,每个头的结果被连接起来。相比之下,下面的 MultiHeadAttention 类将多头功能集成在一个类中。它通过重塑投影的查询、键和值张量将输入分割成多个头,然后在计算注意力后组合这些头的结果。\n",
"\n",
"让我们在进一步讨论之前先看一下 MultiHeadAttention 类:\n",
"\n",
"### 清单 3.5 一个高效的 MultiHeadAttention 类\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, \n",
" context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads #A\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) #B\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
" \n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
" keys = self.W_key(x) #C\n",
" queries = self.W_query(x) #C\n",
" values = self.W_value(x) #C\n",
" \n",
" keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) #D\n",
" values = values.view(b, num_tokens, self.num_heads, self.head_dim) #\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" \n",
" keys = keys.transpose(1, 2) #E\n",
" queries = queries.transpose(1, 2) #E\n",
" values = values.transpose(1, 2) #E\n",
" \n",
" attn_scores = queries @ keys.transpose(2, 3) #F \n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens] #G\n",
" \n",
" attn_scores.masked_fill_(mask_bool, -torch.inf) #H\n",
" \n",
" attn_weights = torch.softmax(\n",
" attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
" \n",
" context_vec = (attn_weights @ values).transpose(1, 2) #I\n",
" #J\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) #K\n",
" return context_vec"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"尽管 MultiHeadAttention 类中的张量重塑 (.view) 和转置 (.transpose) 看起来非常复杂但从数学上讲MultiHeadAttention 类实现的概念与之前的 MultiHeadAttentionWrapper 相同。\n",
"\n",
"从宏观层面上看,在之前的 MultiHeadAttentionWrapper 中,我们堆叠了多个 Single-head Attention 层,然后将它们组合成一个 MultiHeadAttention 层。MultiHeadAttention 类采取了一种集成的方法。它从一个 Multi-head Attention 层开始,然后在内部将这个层分割成单独的注意力头,如图 3.26 所示。\n",
"\n",
"**图 3.26 在带有两个注意力头的 MultiHeadAttentionWrapper 类中,我们初始化了两个权重矩阵 Wq1 和 Wq2并计算了两个查询矩阵 Q1 和 Q2如图顶部所示。在 MultiHeadAttention 类中,我们初始化一个更大的权重矩阵 Wq只执行一次与输入的矩阵乘法以获得查询矩阵 Q然后将查询矩阵分割成 Q1 和 Q2如图底部所示。我们对键和值做同样的处理为了减少视觉混乱这部分处理没有显示出来。**\n",
"\n",
"![3.26](../img/fig-3-26.jpg)\n",
"\n",
"如图 3.26 所示,查询、键和值张量的分割是通过使用 PyTorch 的 .view 和 .transpose 方法进行张量重塑和转置操作来实现的。输入首先通过线性层转换(针对查询、键和值),然后被重塑来表示多个头。\n",
"\n",
"关键操作是将 d_out 维度分割为 num_heads 和 head_dim其中 head_dim = d_out / num_heads。这种分割随后通过 .view 方法实现:将维度为 (b, num_tokens, d_out) 的张量重塑为维度 (b, num_tokens, num_heads, head_dim)。\n",
"\n",
"随后张量被转置使得多头维度num_heads排在序列长度维度num_tokens之前形成 (b, num_heads, num_tokens, head_dim) 的结构。这种转置对于正确匹配不同头的查询、键和值,以及高效进行批量矩阵乘法至关重要。\n",
"\n",
"为了说明这种批量矩阵乘法,假设我们有以下示例张量:\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573], #A\n",
" [0.8993, 0.0390, 0.9268, 0.7388],\n",
" [0.7179, 0.7058, 0.9156, 0.4340]],\n",
" \n",
" [[0.0772, 0.3565, 0.1479, 0.5331],\n",
" [0.4066, 0.2318, 0.4545, 0.9737],\n",
" [0.4606, 0.5159, 0.4220, 0.5786]]]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在,我们执行一个批量矩阵乘法,将张量本身和一个转置了最后两个维度的张量视图之间进行矩阵乘法:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[1.3208, 1.1631, 1.2879],\n",
" [1.1631, 2.2150, 1.8424],\n",
" [1.2879, 1.8424, 2.0402]],\n",
"\n",
" [[0.4391, 0.7003, 0.5903],\n",
" [0.7003, 1.3737, 1.0620],\n",
" [0.5903, 1.0620, 0.9912]]]])\n"
]
}
],
"source": [
"print(a @ a.transpose(2, 3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"结果如下:\n",
"```python\n",
"tensor([[[[1.3208, 1.1631, 1.2879],\n",
" [1.1631, 2.2150, 1.8424],\n",
" [1.2879, 1.8424, 2.0402]],\n",
" \n",
" [[0.4391, 0.7003, 0.5903],\n",
" [0.7003, 1.3737, 1.0620],\n",
" [0.5903, 1.0620, 0.9912]]]])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在该例中PyTorch 中的矩阵乘法实现可以处理四维输入张量使得矩阵乘法在最后两个维度num_tokens, head_dim之间进行然后针对各个头重复执行。\n",
"\n",
"例如,上述方法可以更简洁地分别计算每个头的矩阵乘法:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First head:\n",
" tensor([[1.3208, 1.1631, 1.2879],\n",
" [1.1631, 2.2150, 1.8424],\n",
" [1.2879, 1.8424, 2.0402]])\n",
"\n",
"Second head:\n",
" tensor([[0.4391, 0.7003, 0.5903],\n",
" [0.7003, 1.3737, 1.0620],\n",
" [0.5903, 1.0620, 0.9912]])\n"
]
}
],
"source": [
"first_head = a[0, 0, :, :]\n",
"first_res = first_head @ first_head.T\n",
"print(\"First head:\\n\", first_res)\n",
"\n",
"second_head = a[0, 1, :, :]\n",
"second_res = second_head @ second_head.T\n",
"print(\"\\nSecond head:\\n\", second_res)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"结果与我们之前使用批量矩阵乘法 print(a @ a.transpose(2, 3)) 获得的结果完全相同:\n",
"```python\n",
"First head:\n",
" tensor([[1.3208, 1.1631, 1.2879],\n",
" [1.1631, 2.2150, 1.8424],\n",
" [1.2879, 1.8424, 2.0402]])\n",
"\n",
"Second head:\n",
" tensor([[0.4391, 0.7003, 0.5903],\n",
" [0.7003, 1.3737, 1.0620],\n",
" [0.5903, 1.0620, 0.9912]])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"继续进行 MultiHeadAttention 的操作,计算完注意力权重和上下文向量后,所有头部的上下文向量被重新转置为 (b, num_tokens, num_heads, head_dim) 的形状。然后,这些向量被重塑(扁平化)成 (b, num_tokens, d_out) 的形状,有效地合并了所有头部的输出。\n",
"\n",
"此外,在将头部合并后,我们在 MultiHeadAttention 中添加了一个所谓的输出投影层self.out_proj这在因果注意力类中并不存在。这个输出投影层虽然不是绝对必要的更多细节请参见附录 B 的参考文献部分),但它在许多大语言模型架构中常见,这也是为什么我们在这里添加它以保持完整性。\n",
"\n",
"尽管由于额外的重塑和张量转置, MultiHeadAttention 类看起来比 MultiHeadAttentionWrapper 更复杂,但它更高效。原因是我们只需要一次矩阵乘法就可以计算出键,例如 keys = self.W_key(x)(对查询和值同样适用)。在 MultiHeadAttentionWrapper 中,我们需要重复这一矩阵乘法,这在计算上是最昂贵的步骤之一,需要为每一个注意力头重复。\n",
"\n",
" MultiHeadAttention 类可以像我们之前实现的 SelfAttention 和 CausalAttention 类一样使用:\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"trusted": true
},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'batch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m123\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m batch_size, context_length, d_in \u001b[38;5;241m=\u001b[39m \u001b[43mbatch\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 3\u001b[0m d_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 4\u001b[0m mha \u001b[38;5;241m=\u001b[39m MultiHeadAttention(d_in, d_out, context_length, \u001b[38;5;241m0.0\u001b[39m, num_heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'batch' is not defined"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"batch_size, context_length, d_in = batch.shape\n",
"d_out = 2\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"context_vecs = mha(batch)\n",
"print(context_vecs)\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"从结果来看,输出维度直接由 d_out 参数控制:\n",
"```python\n",
"tensor([[[0.3190, 0.4858],\n",
" [0.2943, 0.3897],\n",
" [0.2856, 0.3593],\n",
" [0.2693, 0.3873],\n",
" [0.2639, 0.3928],\n",
" [0.2575, 0.4028]],\n",
" \n",
" [[0.3190, 0.4858],\n",
" [0.2943, 0.3897],\n",
" [0.2856, 0.3593],\n",
" [0.2693, 0.3873],\n",
" [0.2639, 0.3928],\n",
" [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)\n",
"context_vecs.shape: torch.Size([2, 6, 2])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在本节中,我们实现了在即将到来的章节中用于实现和训练大语言模型的 MultiHeadAttention 类。请注意,虽然代码完全可用,但我们使用了相对较小的嵌入规模和注意力头数量,以保持输出的可读性。\n",
"\n",
"作为比较,最小的 GPT-2 模型1.17亿参数)有 12 个注意力头和 768 的上下文向量嵌入规模。最大的 GPT 2 模型15 亿参数)有 25 个注意力头和 1600 的上下文向量嵌入规模。注意,在 GPT 模型中Token 输入和上下文嵌入的嵌入规模是相同的d_in = d_out。\n",
"\n",
"## 练习 3.3 初始化具有 GPT-2 规模的注意力模块\n",
"使用 MultiHeadAttention 类,初始化一个具有与最小 GPT-2 模型相同数量的注意力头12 个注意力头)的 MultiHeadAttention 模块。同时确保你使用与 GPT-2 相似的输入和输出嵌入规模768 维)。请注意,最小的 GPT-2 模型支持 1024 个 Token 的上下文长度。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "minitorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}