{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
   "metadata": {},
   "source": [
    "# Chapter 3 Exercise Solutions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33dfa199-9aee-41d4-a64b-7e3811b9a616",
   "metadata": {},
   "source": [
    "# 3.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5fee2cf5-61c3-4167-81b5-44ea155bbaf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "inputs = torch.tensor(\n",
    "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
    "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
    "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
    "   [0.22, 0.58, 0.33], # with     (x^4)\n",
    "   [0.77, 0.25, 0.10], # one      (x^5)\n",
    "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
    ")\n",
    "\n",
    "d_in, d_out = 3, 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62ea289c-41cd-4416-89dd-dde6383a6f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class SelfAttention_v1(nn.Module):\n",
    "\n",
    "    def __init__(self, d_in, d_out):\n",
    "        super().__init__()\n",
    "        self.d_out = d_out\n",
    "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
    "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
    "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
    "\n",
    "    def forward(self, x):\n",
    "        keys = x @ self.W_key\n",
    "        queries = x @ self.W_query\n",
    "        values = x @ self.W_value\n",
    "\n",
    "        attn_scores = queries @ keys.T  # omega\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec\n",
    "\n",
    "torch.manual_seed(123)\n",
    "sa_v1 = SelfAttention_v1(d_in, d_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7b035143-f4e8-45fb-b398-dec1bd5153d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttention_v2(nn.Module):\n",
    "\n",
    "    def __init__(self, d_in, d_out):\n",
    "        super().__init__()\n",
    "        self.d_out = d_out\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
    "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        keys = self.W_key(x)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "\n",
    "        attn_scores = queries @ keys.T\n",
    "        # dim=-1 for consistency with SelfAttention_v1 (identical for 2D input)\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec\n",
    "\n",
    "torch.manual_seed(123)\n",
    "sa_v2 = SelfAttention_v2(d_in, d_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7591d79c-c30e-406d-adfd-20c12eb448f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn.Linear stores its weight as (d_out, d_in) and computes x @ W.T,\n",
    "# so the transpose aligns it with SelfAttention_v1's (d_in, d_out) parameters\n",
    "sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)\n",
    "sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)\n",
    "sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ddd0f54f-6bce-46cc-a428-17c2a56557d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5337, -0.1051],\n",
       "        [-0.5323, -0.1080],\n",
       "        [-0.5323, -0.1079],\n",
       "        [-0.5297, -0.1076],\n",
       "        [-0.5311, -0.1066],\n",
       "        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sa_v1(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "340908f8-1144-4ddd-a9e1-a1c5c3d592f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5337, -0.1051],\n",
       "        [-0.5323, -0.1080],\n",
       "        [-0.5323, -0.1079],\n",
       "        [-0.5297, -0.1076],\n",
       "        [-0.5311, -0.1066],\n",
       "        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sa_v2(inputs)"
   ]
  },
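  {
   "cell_type": "markdown",
   "id": "c4a8e2f0-5b1d-4e3a-9f62-0d7b8a91c3e5",
   "metadata": {},
   "source": [
    "The two outputs match. As a quick numerical sanity check (a minimal sketch using the `sa_v1` and `sa_v2` objects defined above):\n",
    "\n",
    "```python\n",
    "# True: after copying the transposed nn.Linear weights,\n",
    "# both modules compute identical context vectors\n",
    "print(torch.allclose(sa_v1(inputs), sa_v2(inputs)))\n",
    "```"
   ]
  },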
  {
   "cell_type": "markdown",
   "id": "33543edb-46b5-4b01-8704-f7f101230544",
   "metadata": {},
   "source": [
    "# 3.2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fc1a301",
   "metadata": {},
   "source": [
    "If we want the multi-head attention output dimension to be 2, matching the single-head attention above, we can set the per-head output dimension `d_out` to 1: the wrapper concatenates the outputs of its `num_heads=2` heads, so the final dimension is `num_heads × d_out = 2`:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e748ef-3106-4e11-a781-b230b74a0cef",
   "metadata": {},
   "source": [
    "```python\n",
    "torch.manual_seed(123)\n",
    "\n",
    "d_out = 1\n",
    "mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
    "\n",
    "context_vecs = mha(batch)\n",
    "\n",
    "print(context_vecs)\n",
    "print(\"context_vecs.shape:\", context_vecs.shape)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78234544-d989-4f71-ac28-85a7ec1e6b7b",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[[-9.1476e-02,  3.4164e-02],\n",
    "         [-2.6796e-01, -1.3427e-03],\n",
    "         [-4.8421e-01, -4.8909e-02],\n",
    "         [-6.4808e-01, -1.0625e-01],\n",
    "         [-8.8380e-01, -1.7140e-01],\n",
    "         [-1.4744e+00, -3.4327e-01]],\n",
    "\n",
    "        [[-9.1476e-02,  3.4164e-02],\n",
    "         [-2.6796e-01, -1.3427e-03],\n",
    "         [-4.8421e-01, -4.8909e-02],\n",
    "         [-6.4808e-01, -1.0625e-01],\n",
    "         [-8.8380e-01, -1.7140e-01],\n",
    "         [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
    "context_vecs.shape: torch.Size([2, 6, 2])\n",
    "```"
   ]
  },
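  {
   "cell_type": "markdown",
   "id": "e7d2b9a4-1f6c-4a8e-b3d5-92c0f4e6a7b1",
   "metadata": {},
   "source": [
    "Note that `MultiHeadAttentionWrapper` and `batch` come from the chapter's main notebook. For reference, a minimal self-contained sketch of the idea, assuming a single-head `CausalAttention` along the lines of the chapter (the wrapper simply concatenates the per-head outputs):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class CausalAttention(nn.Module):\n",
    "    # one attention head with a causal mask and dropout on the weights\n",
    "    def __init__(self, d_in, d_out, block_size, dropout):\n",
    "        super().__init__()\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
    "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.register_buffer(\n",
    "            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        b, num_tokens, d_in = x.shape\n",
    "        keys = self.W_key(x)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "\n",
    "        # mask out future positions before the softmax\n",
    "        attn_scores = queries @ keys.transpose(1, 2)\n",
    "        attn_scores.masked_fill_(\n",
    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
    "        )\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "        attn_weights = self.dropout(attn_weights)\n",
    "        return attn_weights @ values\n",
    "\n",
    "class MultiHeadAttentionWrapper(nn.Module):\n",
    "    # runs num_heads independent heads and concatenates their outputs,\n",
    "    # so the final dimension is num_heads * d_out\n",
    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
    "        super().__init__()\n",
    "        self.heads = nn.ModuleList(\n",
    "            [CausalAttention(d_in, d_out, block_size, dropout)\n",
    "             for _ in range(num_heads)]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return torch.cat([head(x) for head in self.heads], dim=-1)\n",
    "\n",
    "# the batch used above is two copies of `inputs`, shape (2, 6, 3),\n",
    "# which is why both items in the printed output are identical\n",
    "batch = torch.stack((inputs, inputs), dim=0)\n",
    "block_size = batch.shape[1]\n",
    "```"
   ]
  },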
  {
   "cell_type": "markdown",
   "id": "92bdabcb-06cf-4576-b810-d883bbd313ba",
   "metadata": {},
   "source": [
    "# 3.3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84c9b963-d01f-46e6-96bf-8eb2a54c5e42",
   "metadata": {},
   "source": [
    "```python\n",
    "block_size = 1024\n",
    "d_in, d_out = 768, 768\n",
    "num_heads = 12\n",
    "\n",
    "mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "375d5290-8e8b-4149-958e-1efb58a69191",
   "metadata": {},
   "source": [
    "The parameter count of the implementation above is:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d7e603c-1658-4da9-9c0b-ef4bc72832b4",
   "metadata": {},
   "source": [
    "```python\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "count_parameters(mha)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51ba00bd-feb0-4424-84cb-7c2b1f908779",
   "metadata": {},
   "source": [
    "```\n",
    "2360064  # (2.36 M)\n",
    "```"
   ]
  },
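  {
   "cell_type": "markdown",
   "id": "f3b6c8d2-7a4e-4c1b-8e9f-5d0a2b3c4e6f",
   "metadata": {},
   "source": [
    "This matches a by-hand count, assuming bias-free Q/K/V projections plus an output projection with bias (as the total suggests):\n",
    "\n",
    "```python\n",
    "d = 768\n",
    "qkv = 3 * d * d        # W_query, W_key, W_value (bias=False): 1,769,472\n",
    "out_proj = d * d + d   # output projection Linear(d, d) with bias: 590,592\n",
    "print(qkv + out_proj)  # 2,360,064\n",
    "```"
   ]
  },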
  {
   "cell_type": "markdown",
   "id": "a56c1d47-9b95-4bd1-a517-580a6f779c52",
   "metadata": {},
   "source": [
    "The GPT-2 model has 117M parameters, but as we can see, the vast majority of them do not come from the multi-head attention modules but from the linear layers."
   ]
  },
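  {
   "cell_type": "markdown",
   "id": "9c5e7f1a-2d8b-4a6c-b0e3-4f1d6a8c2b7e",
   "metadata": {},
   "source": [
    "As a rough back-of-the-envelope check, assuming GPT-2's 12 transformer blocks: multi-head attention contributes only about 12 × 2.36M ≈ 28.3M parameters, with the remainder coming from the feed-forward layers and the embeddings."
   ]
  }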
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}