mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-06-06 00:04:42 +00:00
[Book] correct code and its output
This commit is contained in:
@@ -36,7 +36,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
@@ -58,7 +58,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
@@ -81,11 +81,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([0.4306, 1.4551])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query_2 = x_2 @ W_query \n",
|
||||
"key_2 = x_2 @ W_key \n",
|
||||
@@ -120,11 +128,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"keys.shape: torch.Size([6, 2])\n",
|
||||
"values.shape: torch.Size([6, 2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"keys = inputs @ W_key \n",
|
||||
"values = inputs @ W_value\n",
|
||||
@@ -158,11 +175,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(1.8524)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"keys_2 = keys[1] #A\n",
|
||||
"attn_score_22 = query_2.dot(keys_2)\n",
|
||||
@@ -188,11 +213,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"attn_scores_2 = query_2 @ keys.T # All attention scores for given query\n",
|
||||
"print(attn_scores_2)"
|
||||
@@ -265,11 +298,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'attn_weights_2' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0m context_vec_2 \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights_2\u001b[49m \u001b[38;5;241m@\u001b[39m values\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(context_vec_2)\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'attn_weights_2' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"context_vec_2 = attn_weights_2 @ values\n",
|
||||
"print(context_vec_2)"
|
||||
@@ -320,23 +365,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "<class 'ModuleNotFoundError'>",
|
||||
"evalue": "No module named 'torch'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnn\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mSelfAttention_v1\u001b[39;00m(nn\u001b[38;5;241m.\u001b[39mModule):\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, d_in, d_out):\n",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"class SelfAttention_v1(nn.Module):\n",
|
||||
@@ -373,11 +406,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[0.2996, 0.8053],\n",
|
||||
" [0.3061, 0.8210],\n",
|
||||
" [0.3058, 0.8203],\n",
|
||||
" [0.2948, 0.7939],\n",
|
||||
" [0.2927, 0.7891],\n",
|
||||
" [0.2990, 0.8040]], grad_fn=<MmBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"sa_v1 = SelfAttention_v1(d_in, d_out)\n",
|
||||
@@ -420,7 +466,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
@@ -439,7 +485,7 @@
|
||||
" queries = self.W_query(x)\n",
|
||||
" values = self.W_value(x)\n",
|
||||
" attn_scores = queries @ keys.T\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
" return context_vec"
|
||||
]
|
||||
@@ -453,11 +499,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[-0.0739, 0.0713],\n",
|
||||
" [-0.0748, 0.0703],\n",
|
||||
" [-0.0749, 0.0702],\n",
|
||||
" [-0.0760, 0.0685],\n",
|
||||
" [-0.0763, 0.0679],\n",
|
||||
" [-0.0754, 0.0693]], grad_fn=<MmBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(789)\n",
|
||||
"sa_v2 = SelfAttention_v2(d_in, d_out)\n",
|
||||
@@ -524,13 +583,13 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (Pyodide)",
|
||||
"display_name": "minitorch",
|
||||
"language": "python",
|
||||
"name": "python"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "python",
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
@@ -538,7 +597,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8"
|
||||
"version": "3.8.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
+163
-41
@@ -35,11 +35,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
|
||||
" [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
|
||||
" [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
|
||||
" [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
|
||||
" [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
|
||||
" [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
|
||||
" grad_fn=<SoftmaxBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"queries = sa_v2.W_query(inputs) #A\n",
|
||||
"keys = sa_v2.W_key(inputs) \n",
|
||||
@@ -73,11 +87,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[1., 0., 0., 0., 0., 0.],\n",
|
||||
" [1., 1., 0., 0., 0., 0.],\n",
|
||||
" [1., 1., 1., 0., 0., 0.],\n",
|
||||
" [1., 1., 1., 1., 0., 0.],\n",
|
||||
" [1., 1., 1., 1., 1., 0.],\n",
|
||||
" [1., 1., 1., 1., 1., 1.]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"context_length = attn_scores.shape[0]\n",
|
||||
"mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
|
||||
@@ -108,11 +135,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
|
||||
" [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
|
||||
" [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
|
||||
" grad_fn=<MulBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"masked_simple = attn_weights*mask_simple\n",
|
||||
"print(masked_simple)"
|
||||
@@ -143,11 +184,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
|
||||
" [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
|
||||
" [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
|
||||
" grad_fn=<DivBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
|
||||
"masked_simple_norm = masked_simple / row_sums\n",
|
||||
@@ -196,11 +251,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[0.2899, -inf, -inf, -inf, -inf, -inf],\n",
|
||||
" [0.4656, 0.1723, -inf, -inf, -inf, -inf],\n",
|
||||
" [0.4594, 0.1703, 0.1731, -inf, -inf, -inf],\n",
|
||||
" [0.2642, 0.1024, 0.1036, 0.0186, -inf, -inf],\n",
|
||||
" [0.2183, 0.0874, 0.0882, 0.0177, 0.0786, -inf],\n",
|
||||
" [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
|
||||
" grad_fn=<MaskedFillBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
||||
"masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
|
||||
@@ -232,11 +301,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
|
||||
" [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
|
||||
" [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
|
||||
" grad_fn=<SoftmaxBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n",
|
||||
"print(attn_weights)"
|
||||
@@ -285,11 +368,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[2., 2., 2., 2., 2., 2.],\n",
|
||||
" [0., 2., 0., 0., 0., 0.],\n",
|
||||
" [0., 0., 2., 0., 2., 0.],\n",
|
||||
" [2., 2., 0., 0., 0., 2.],\n",
|
||||
" [2., 0., 0., 0., 0., 2.],\n",
|
||||
" [0., 2., 0., 0., 0., 0.]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"dropout = torch.nn.Dropout(0.5) #A\n",
|
||||
@@ -323,11 +419,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
|
||||
" grad_fn=<MulBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"print(dropout(attn_weights))"
|
||||
@@ -369,11 +479,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([2, 6, 3])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"batch = torch.stack((inputs, inputs), dim=0)\n",
|
||||
"print(batch.shape) #A "
|
||||
@@ -400,39 +518,35 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CausalAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False)\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.dropout = nn.Dropout(dropout) #A\n",
|
||||
" self.register_buffer(\n",
|
||||
" 'mask',\n",
|
||||
" torch.triu(torch.ones(context_length, context_length),\n",
|
||||
" diagonal=1)\n",
|
||||
" ) #B\n",
|
||||
" \n",
|
||||
" self.dropout = nn.Dropout(dropout) # New\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" b, num_tokens, d_in = x.shape #C \n",
|
||||
"New batch dimension b\n",
|
||||
" b, num_tokens, d_in = x.shape # New batch dimension b\n",
|
||||
" keys = self.W_key(x)\n",
|
||||
" queries = self.W_query(x)\n",
|
||||
" values = self.W_value(x)\n",
|
||||
" \n",
|
||||
" attn_scores = queries @ keys.transpose(1, 2) #C\n",
|
||||
" attn_scores.masked_fill_( #D\n",
|
||||
"\n",
|
||||
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
||||
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
||||
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=\n",
|
||||
" attn_weights = self.dropout(attn_weights)\n",
|
||||
" \n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" attn_weights = self.dropout(attn_weights) # New\n",
|
||||
"\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
" return context_vec"
|
||||
]
|
||||
@@ -448,11 +562,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"context_vecs.shape: torch.Size([2, 6, 2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"context_length = batch.shape[1]\n",
|
||||
@@ -532,13 +654,13 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (Pyodide)",
|
||||
"display_name": "minitorch",
|
||||
"language": "python",
|
||||
"name": "python"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "python",
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
@@ -546,7 +668,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8"
|
||||
"version": "3.8.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -31,12 +31,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"class MultiHeadAttentionWrapper(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, context_length,\n",
|
||||
" dropout, num_heads, qkv_bias=False):\n",
|
||||
@@ -47,7 +48,7 @@
|
||||
" )\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" return torch.cat([head(x) for head in self.heads], dim=-1"
|
||||
" return torch.cat([head(x) for head in self.heads], dim=-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,16 +66,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'batch' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[0;32mIn [6]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m123\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m context_length \u001b[38;5;241m=\u001b[39m \u001b[43mbatch\u001b[49m\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;66;03m# This is the number of tokens\u001b[39;00m\n\u001b[1;32m 3\u001b[0m d_in, d_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 4\u001b[0m mha \u001b[38;5;241m=\u001b[39m MultiHeadAttentionWrapper(d_in, d_out, context_length, \u001b[38;5;241m0.0\u001b[39m, num_heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'batch' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"context_length = batch.shape[1] # This is the number of tokens\n",
|
||||
"d_in, d_out = 3, 2\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
|
||||
"context_vecs = mha(batch)\n",
|
||||
" \n",
|
||||
"print(context_vecs)\n",
|
||||
@@ -132,7 +145,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
@@ -142,7 +155,7 @@
|
||||
" def __init__(self, d_in, d_out, \n",
|
||||
" context_length, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.head_dim = d_out // num_heads #A\n",
|
||||
@@ -151,10 +164,7 @@
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out) #B\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
" self.register_buffer(\n",
|
||||
" 'mask',\n",
|
||||
" torch.triu(torch.ones(context_length, context_length), diagonal\n",
|
||||
" )\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
@@ -181,7 +191,7 @@
|
||||
" \n",
|
||||
" context_vec = (attn_weights @ values).transpose(1, 2) #I\n",
|
||||
" #J\n",
|
||||
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_ou\n",
|
||||
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
|
||||
" context_vec = self.out_proj(context_vec) #K\n",
|
||||
" return context_vec"
|
||||
]
|
||||
@@ -209,7 +219,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
@@ -233,11 +243,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[[[1.3208, 1.1631, 1.2879],\n",
|
||||
" [1.1631, 2.2150, 1.8424],\n",
|
||||
" [1.2879, 1.8424, 2.0402]],\n",
|
||||
"\n",
|
||||
" [[0.4391, 0.7003, 0.5903],\n",
|
||||
" [0.7003, 1.3737, 1.0620],\n",
|
||||
" [0.5903, 1.0620, 0.9912]]]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(a @ a.transpose(2, 3))"
|
||||
]
|
||||
@@ -269,11 +293,27 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"First head:\n",
|
||||
" tensor([[1.3208, 1.1631, 1.2879],\n",
|
||||
" [1.1631, 2.2150, 1.8424],\n",
|
||||
" [1.2879, 1.8424, 2.0402]])\n",
|
||||
"\n",
|
||||
"Second head:\n",
|
||||
" tensor([[0.4391, 0.7003, 0.5903],\n",
|
||||
" [0.7003, 1.3737, 1.0620],\n",
|
||||
" [0.5903, 1.0620, 0.9912]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"first_head = a[0, 0, :, :]\n",
|
||||
"first_res = first_head @ first_head.T\n",
|
||||
@@ -317,11 +357,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'batch' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m123\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m batch_size, context_length, d_in \u001b[38;5;241m=\u001b[39m \u001b[43mbatch\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 3\u001b[0m d_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 4\u001b[0m mha \u001b[38;5;241m=\u001b[39m MultiHeadAttention(d_in, d_out, context_length, \u001b[38;5;241m0.0\u001b[39m, num_heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'batch' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"batch_size, context_length, d_in = batch.shape\n",
|
||||
@@ -366,53 +418,17 @@
|
||||
"## 练习 3.3 初始化具有 GPT-2 规模的注意力模块\n",
|
||||
"使用 MultiHeadAttention 类,初始化一个具有与最小 GPT-2 模型相同数量的注意力头(12 个注意力头)的 MultiHeadAttention 模块。同时确保你使用与 GPT-2 相似的输入和输出嵌入规模(768 维)。请注意,最小的 GPT-2 模型支持 1024 个 Token 的上下文长度。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (Pyodide)",
|
||||
"display_name": "minitorch",
|
||||
"language": "python",
|
||||
"name": "python"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "python",
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
@@ -420,7 +436,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8"
|
||||
"version": "3.8.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,9 +1,49 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 3.7 总结\n",
|
||||
"\n",
|
||||
"- 注意力(Attention)机制将输入元素转换为增强的上下文向量表示,这些表示融合了所有输入的信息。\n",
|
||||
"\n",
|
||||
"- 自注意力(Self Attention)机制通过对输入的加权求和来计算上下文向量表示。\n",
|
||||
"\n",
|
||||
"- 在简化的注意力机制中,注意力权重通过点积计算得出。\n",
|
||||
"\n",
|
||||
"- 点积是将两个向量的相应元素相乘然后求和的简洁方式。\n",
|
||||
"\n",
|
||||
"- 虽然不是绝对必要,但矩阵乘法通过替代嵌套的 for 循环,帮助我们更高效、紧凑地实施计算。\n",
|
||||
"\n",
|
||||
"- 用于大语言模型的自注意力机制,也称为缩放点积注意力,其中包含了可训练的权重矩阵来计算输入的中间转换向量:查询、值和键。\n",
|
||||
"\n",
|
||||
"- 在处理从左到右阅读和生成文本的大语言模型时,我们添加因果注意力遮蔽(CausalAttention Mask)以防止大语言模型访问后续的 Token 。\n",
|
||||
"\n",
|
||||
"- 除了使用因果注意力遮蔽将注意力权重归零外,我们还可以添加 Dropout 遮蔽来减少大语言模型中的过拟合问题。\n",
|
||||
"\n",
|
||||
"- 基于 Transformer 的大语言模型中的注意力模块涉及多个因果注意力(CausalAttention)实例,这称为多头注意力(MultiHeadAttention)。\n",
|
||||
"\n",
|
||||
"- 我们可以通过堆叠多个 CausalAttention 模块来创建一个 MultiHeadAttention 模块。\n",
|
||||
"\n",
|
||||
"- 创建 MultiHeadAttention 模块的更有效方式涉及到批量矩阵乘法。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"name": "python",
|
||||
"display_name": "Python (Pyodide)",
|
||||
"language": "python"
|
||||
"language": "python",
|
||||
"name": "python"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@@ -18,22 +58,6 @@
|
||||
"version": "3.8"
|
||||
}
|
||||
},
|
||||
"nbformat_minor": 4,
|
||||
"nbformat": 4,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": "# 3.7 总结\n\n- 注意力(Attention)机制将输入元素转换为增强的上下文向量表示,这些表示融合了所有输入的信息。\n\n- 自注意力(Self Attention)机制通过对输入的加权求和来计算上下文向量表示。\n\n- 在简化的注意力机制中,注意力权重通过点积计算得出。\n\n- 点积是将两个向量的相应元素相乘然后求和的简洁方式。\n\n- 虽然不是绝对必要,但矩阵乘法通过替代嵌套的 for 循环,帮助我们更高效、紧凑地实施计算。\n\n- 用于大语言模型的自注意力机制,也称为缩放点积注意力,其中包含了可训练的权重矩阵来计算输入的中间转换向量:查询、值和键。\n\n- 在处理从左到右阅读和生成文本的大语言模型时,我们添加因果注意力遮蔽(CausalAttention Mask)以防止大语言模型访问后续的 Token 。\n\n- 除了使用因果注意力遮蔽将注意力权重归零外,我们还可以添加 Dropout 遮蔽来减少大语言模型中的过拟合问题。\n\n- 基于 Transformer 的大语言模型中的注意力模块涉及多个因果注意力(CausalAttention)实例,这称为多头注意力(MultiHeadAttention)。\n\n- 我们可以通过堆叠多个 CausalAttention 模块来创建一个 MultiHeadAttention 模块。\n\n- 创建 MultiHeadAttention 模块的更有效方式涉及到批量矩阵乘法。",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": "",
|
||||
"metadata": {
|
||||
"trusted": true
|
||||
},
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
}
|
||||
]
|
||||
}
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user