mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-04-25 08:58:17 +08:00
update ch02-2.7
This commit is contained in:
parent
d8770c9680
commit
421dd1048a
214
Translated_Book/ch02/2.7 构建词符嵌入.ipynb
Normal file
214
Translated_Book/ch02/2.7 构建词符嵌入.ipynb
Normal file
@ -0,0 +1,214 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2cd2fcda-2fda-4aa8-8bc8-de1e496f9db1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2.7 构建词符嵌入"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a301068-6ab2-44ff-a915-1ba11688274f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 我们准备用于大语言模型(LLM)的数据已经差不多就绪了\n",
|
||||
"- 接下来,我们要做的最后一步是使用嵌入层将 token 嵌入到连续的向量表示中。token本身不可计算,需要将其映射到一个连续向量空间,才可以进行后续运算,这个映射的结果就是该token对应的embedding\n",
|
||||
"- 通常,这些用来转换词符的嵌入层是大语言模型(LLM)的一部分,并且在模型训练的过程中会不断调整和优化。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e85089aa-8671-4e5f-a2b3-ef252004ee4c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<img src=\"https://github.com/datawhalechina/llms-from-scratch-cn/blob/main/Translated_Book/img/fig-2-15.jpg?raw=true\" width=\"400px\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "44e014ca-1fc5-4b90-b6fa-c2097bb92c0b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 假设我们在分词后有以下四个输入示例,对应的输入ID分别是5、1、3和2:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"input_ids = torch.tensor([2, 3, 5, 1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "14da6344-2c71-4837-858d-dd120005ba05",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 为了简化问题,假设我们有一个只包含6个单词的小型词汇表,我们想要创建大小为3的嵌入。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vocab_size = 6\n",
|
||||
"output_dim = 3\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4ff241f6-78eb-4e4a-a55f-5b2b6196d5b0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 这将会生成一个6x3的权重矩阵:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a686eb61-e737-4351-8f1c-222913d47468",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Parameter containing:\n",
|
||||
"tensor([[ 0.3374, -0.1778, -0.1690],\n",
|
||||
" [ 0.9178, 1.5810, 1.3010],\n",
|
||||
" [ 1.2753, -0.2010, -0.1606],\n",
|
||||
" [-0.4015, 0.9666, -1.1481],\n",
|
||||
" [-1.1589, 0.3255, -0.6315],\n",
|
||||
" [-2.8400, -0.7849, -1.4096]], requires_grad=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(embedding_layer.weight)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5e54d5f1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 由于嵌入层只是独热编码和矩阵乘法方法的一种更高效的实现,它可以被视为一个可以通过反向传播进行优化的神经网络层。\n",
|
||||
"- 对于那些熟悉独热编码的人来说,上述嵌入层的方法本质上只是实现独热编码后进行矩阵乘法的一种更高效的手段,这种方法在全连接层中使用,其详细说明可以在补充代码[./embedding_vs_matmul](https://github.com/datawhalechina/llms-from-scratch-cn/tree/main/ch02/03_bonus_embedding-vs-matmul)中找到。\n",
|
||||
"- 因为嵌入层只是独热编码和矩阵乘法方法的一种更高效的实现,所以它可以被视为一个可以通过反向传播算法进行优化的神经网络层。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4b0d58c3-83c0-4205-aca2-9c48b19fd4a7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 要将ID为3的词符转换为一个3维向量,我们执行以下步骤:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[-0.4015, 0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(embedding_layer(torch.tensor([3])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a7bbf625-4f36-491d-87b4-3969efb784b0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 注意,上述内容是`embedding_layer`权重矩阵中的第4行。\n",
|
||||
"- 为了嵌入上面所有的四个`input_ids`值,我们执行以下操作:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[ 1.2753, -0.2010, -0.1606],\n",
|
||||
" [-0.4015, 0.9666, -1.1481],\n",
|
||||
" [-2.8400, -0.7849, -1.4096],\n",
|
||||
" [ 0.9178, 1.5810, 1.3010]], grad_fn=<EmbeddingBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(embedding_layer(input_ids))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "be97ced4-bd13-42b7-866a-4d699a17e155",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 嵌入层本质上是一种查找操作:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f33c2741-bf1b-4c60-b7fd-61409d556646",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<img src=\"https://github.com/datawhalechina/llms-from-scratch-cn/blob/main/Translated_Book/img/fig-2-16.jpg?raw=true\" width=\"500px\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08218d9f-aa1a-4afb-a105-72ff96a54e73",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- **您可能对比较嵌入层与常规线性层的附加内容感兴趣:[../03_bonus_embedding-vs-matmul](https://github.com/datawhalechina/llms-from-scratch-cn/tree/main/ch02/03_bonus_embedding-vs-matmul)**"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user