llms-from-scratch-cn/ch02/01_main-chapter-code/ch02.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "25aa40e3-5109-433f-9153-f5770531fe94",
"metadata": {},
"source": [
"# 第2章处理文本"
]
},
{
"cell_type": "markdown",
"id": "76d5d2c0-cba8-404e-9bf3-71a218cae3cf",
"metadata": {},
"source": [
"本笔记本使用的软件包:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4d1305cf-12d5-46fe-a2c9-36fb71c5b3d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.2.1\n",
"tiktoken version: 0.6.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"import tiktoken\n",
"import torch\n",
"\n",
"print(\"torch version:\", version(\"torch\"))\n",
"print(\"tiktoken version:\", version(\"tiktoken\"))"
]
},
{
"cell_type": "markdown",
"id": "2417139b-2357-44d2-bd67-23f5d7f52ae7",
"metadata": {},
"source": [
"## 2.1 理解词嵌入"
]
},
{
"cell_type": "markdown",
"id": "0b6816ae-e927-43a9-b4dd-e47a9b0e1cf6",
"metadata": {},
"source": [
"- 本节无代码"
]
},
{
"cell_type": "markdown",
"id": "eddbb984-8d23-40c5-bbfa-c3c379e7eec3",
"metadata": {},
"source": [
"## 2.2 文本分词"
]
},
{
"cell_type": "markdown",
"id": "8cceaa18-833d-46b6-b211-b20c53902805",
"metadata": {},
"source": [
"- 加载我们要处理的原始文本\n",
"- [伊迪丝·华顿的《判决书》](https://en.wikisource.org/wiki/The_Verdict) 是一个公共领域的短篇小说"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8a769e87-470a-48b9-8bdb-12841b416198",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of character: 20479\n",
"I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no \n"
]
}
],
"source": [
"with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()\n",
" \n",
"print(\"Total number of character:\", len(raw_text))\n",
"print(raw_text[:99])"
]
},
{
"cell_type": "markdown",
"id": "9b971a46-ac03-4368-88ae-3f20279e8f4e",
"metadata": {},
"source": [
"- 目标是将这段文本进行分词和嵌入以便用于语言模型LLM\n",
"- 让我们基于一些简单的样本文本开发一个简单的分词器,然后我们可以将这个分词器应用到上面的文本中\n",
"- 下面的正则表达式将会根据空格进行分割"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "737dd5b0-9dbb-4a97-9ae4-3482c8c04be7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hello,', ' ', 'world.', ' ', 'This,', ' ', 'is', ' ', 'a', ' ', 'test.']\n"
]
}
],
"source": [
"import re\n",
"\n",
"text = \"Hello, world. This, is a test.\"\n",
"result = re.split(r'(\\s)', text)\n",
"\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "a8c40c18-a9d5-4703-bf71-8261dbcc5ee3",
"metadata": {},
"source": [
"- 我们不仅要分割空格,还要分割逗号和句号,所以我们也要修改正则表达式来做到这一点"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ea02489d-01f9-4247-b7dd-a0d63f62ef07",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hello', ',', '', ' ', 'world', '.', '', ' ', 'This', ',', '', ' ', 'is', ' ', 'a', ' ', 'test', '.', '']\n"
]
}
],
"source": [
"result = re.split(r'([,.]|\\s)', text)\n",
"\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "461d0c86-e3af-4f87-8fae-594a9ca9b6ad",
"metadata": {},
"source": [
"- 我们可以看到,这会创建空字符串,让我们删除它们"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4d8a6fb7-2e62-4a12-ad06-ccb04f25fed7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hello', ',', 'world', '.', 'This', ',', 'is', 'a', 'test', '.']\n"
]
}
],
"source": [
"# 从每个元素中删除空白,然后过滤掉所有空字符串。\n",
"result = [item.strip() for item in result if item.strip()]\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "250e8694-181e-496f-895d-7cb7d92b5562",
"metadata": {},
"source": [
"- 这看起来很不错,但我们还要处理其他类型的标点符号,如句号、问号等"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "902f0d9c-9828-4c46-ba32-8fe810c3840a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hello', ',', 'world', '.', 'Is', 'this', '--', 'a', 'test', '?']\n"
]
}
],
"source": [
"text = \"Hello, world. Is this-- a test?\"\n",
"\n",
"result = re.split(r'([,.?_!\"()\\']|--|\\s)', text)\n",
"result = [item.strip() for item in result if item.strip()]\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "5bbea70b-c030-45d9-b09d-4318164c0bb4",
"metadata": {},
"source": [
"- 这很好,现在我们可以将这个分词器用于原始文本了"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8c567caa-8ff5-49a8-a5cc-d365b0a78a99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['I', 'HAD', 'always', 'thought', 'Jack', 'Gisburn', 'rather', 'a', 'cheap', 'genius', '--', 'though', 'a', 'good', 'fellow', 'enough', '--', 'so', 'it', 'was', 'no', 'great', 'surprise', 'to', 'me', 'to', 'hear', 'that', ',', 'in']\n"
]
}
],
"source": [
"preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', raw_text)\n",
"preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
"print(preprocessed[:30])"
]
},
{
"cell_type": "markdown",
"id": "e2a19e1a-5105-4ddb-812a-b7d3117eab95",
"metadata": {},
"source": [
"- 让我们来统计词元token的数量"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "35db7b5e-510b-4c45-995f-f5ad64a8e19c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4649\n"
]
}
],
"source": [
"print(len(preprocessed))"
]
},
{
"cell_type": "markdown",
"id": "0b5ce8fe-3a07-4f2a-90f1-a0321ce3a231",
"metadata": {},
"source": [
"## 2.3 将词元转换为词元IDs"
]
},
{
"cell_type": "markdown",
"id": "b5973794-7002-4202-8b12-0900cd779720",
"metadata": {},
"source": [
"- 从这些词元中,我们现在可以构建一个词汇表,它包含了所有独特的词元。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7fdf0533-5ab6-42a5-83fa-a3b045de6396",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1159\n"
]
}
],
"source": [
"all_words = sorted(list(set(preprocessed)))\n",
"vocab_size = len(all_words)\n",
"\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "77d00d96-881f-4691-bb03-84fec2a75a26",
"metadata": {},
"outputs": [],
"source": [
"vocab = {token:integer for integer,token in enumerate(all_words)}"
]
},
{
"cell_type": "markdown",
"id": "75bd1f81-3a8f-4dd9-9dd6-e75f32dacbe3",
"metadata": {},
"source": [
"- 以下是这个词汇表中的前50个条目"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e1c5de4a-aa4e-4aec-b532-10bb364039d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('!', 0)\n",
"('\"', 1)\n",
"(\"'\", 2)\n",
"('(', 3)\n",
"(')', 4)\n",
"(',', 5)\n",
"('--', 6)\n",
"('.', 7)\n",
"(':', 8)\n",
"(';', 9)\n",
"('?', 10)\n",
"('A', 11)\n",
"('Ah', 12)\n",
"('Among', 13)\n",
"('And', 14)\n",
"('Are', 15)\n",
"('Arrt', 16)\n",
"('As', 17)\n",
"('At', 18)\n",
"('Be', 19)\n",
"('Begin', 20)\n",
"('Burlington', 21)\n",
"('But', 22)\n",
"('By', 23)\n",
"('Carlo', 24)\n",
"('Carlo;', 25)\n",
"('Chicago', 26)\n",
"('Claude', 27)\n",
"('Come', 28)\n",
"('Croft', 29)\n",
"('Destroyed', 30)\n",
"('Devonshire', 31)\n",
"('Don', 32)\n",
"('Dubarry', 33)\n",
"('Emperors', 34)\n",
"('Florence', 35)\n",
"('For', 36)\n",
"('Gallery', 37)\n",
"('Gideon', 38)\n",
"('Gisburn', 39)\n",
"('Gisburns', 40)\n",
"('Grafton', 41)\n",
"('Greek', 42)\n",
"('Grindle', 43)\n",
"('Grindle:', 44)\n",
"('Grindles', 45)\n",
"('HAD', 46)\n",
"('Had', 47)\n",
"('Hang', 48)\n",
"('Has', 49)\n",
"('He', 50)\n"
]
}
],
"source": [
"for i, item in enumerate(vocab.items()):\n",
" print(item)\n",
" if i >= 50:\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "4e569647-2589-4c9d-9a5c-aef1c88a0a9a",
"metadata": {},
"source": [
"- 将所有这些整合到一个分词器类tokenizer class中"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f531bf46-7c25-4ef8-bff8-0d27518676d5",
"metadata": {},
"outputs": [],
"source": [
"class SimpleTokenizerV1:\n",
" def __init__(self, vocab):\n",
" self.str_to_int = vocab\n",
" self.int_to_str = {i:s for s,i in vocab.items()}\n",
" \n",
" def encode(self, text):\n",
" preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', text)\n",
" preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
" ids = [self.str_to_int[s] for s in preprocessed]\n",
" return ids\n",
" \n",
" def decode(self, ids):\n",
" text = \" \".join([self.int_to_str[i] for i in ids])\n",
" # Replace spaces before the specified punctuations\n",
" text = re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text)\n",
" return text"
]
},
{
"cell_type": "markdown",
"id": "c2950a94-6b0d-474e-8ed0-66d0c3c1a95c",
"metadata": {},
"source": [
"- 我们可以使用分词器将本文编码成整数\n",
"- 这些整数随后可以作为语言模型LLM的输入进行嵌入稍后。"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "647364ec-7995-4654-9b4a-7607ccf5f1e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 58, 2, 872, 1013, 615, 541, 763, 5, 1155, 608, 5, 1, 69, 7, 39, 873, 1136, 773, 812, 7]\n"
]
}
],
"source": [
"tokenizer = SimpleTokenizerV1(vocab)\n",
"\n",
"text = \"\"\"\"It's the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.\"\"\"\n",
"ids = tokenizer.encode(text)\n",
"print(ids)"
]
},
{
"cell_type": "markdown",
"id": "3201706e-a487-4b60-b99d-5765865f29a0",
"metadata": {},
"source": [
"- 我们可以将数字解码成文本"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "01d8c8fb-432d-4a49-b332-99f23b233746",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(ids)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "54f6aa8b-9827-412e-9035-e827296ab0fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(tokenizer.encode(text))"
]
},
{
"cell_type": "markdown",
"id": "4b821ef8-4d53-43b6-a2b2-aef808c343c7",
"metadata": {},
"source": [
"## 2.4 添加特殊上下文词元"
]
},
{
"cell_type": "markdown",
"id": "9d709d57-2486-4152-b7f9-d3e4bd8634cd",
"metadata": {},
"source": [
"- 一些分词器使用特殊词元来帮助语言模型LLM获取额外的上下文信息。\n",
"- 一些特殊的词元是\n",
" - `[BOS]` (beginning of sequence) 标志着文本的开始\n",
" - `[EOS]` (end of sequence) 标记文本结束的位置(通常用于连接多个不相关的文本,如两篇不同的维基百科文章或两本不同的书等)\n",
" - `[PAD]` (padding) 如果我们训练 LLM 的批量大于 1我们可能会包含多个长度不同的文本使用填充标记我们会将较短的文本填充为最长的文本这样所有文本的长度就相等了。\n",
"- `[UNK]` 代表未列入词汇表的单词\n",
"\n",
"- 请注意GPT-2 不需要上述任何标记,而只使用 `<|endoftext|>`来降低复杂性\n",
"- `<|endoftext|>` 类似于上文提到的`[EOS]`\n",
"- GPT 也使用 `<|endoftext|>` 作为填充词元(因为在批量输入训练时我们通常会使用掩码,反正我们不会关注填充的词元,所以这些词元是什么并不重要)。\n",
"- GPT-2 不使用 `<UNK>` 词元来处理词汇表之外的单词相反GPT-2 使用字节对编码BPE分词器它将单词分解为子词单元我们将在后面的部分讨论这一点。"
]
},
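{
"cell_type": "markdown",
"id": "pad-sketch-md-0001",
"metadata": {},
"source": [
"- Below is a minimal sketch of the padding idea described above (illustrative only: the token IDs are just example values, and `<|endoftext|>` with ID 50256 is reused as the padding token, as in GPT-2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pad-sketch-code-0001",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: pad shorter token-ID sequences in a batch to the longest one\n",
"pad_id = 50256  # ID of <|endoftext|> in the GPT-2 vocabulary, reused for padding\n",
"\n",
"batch = [\n",
"    [15496, 11, 995],          # example token IDs for a short text\n",
"    [40, 588, 8887, 30, 220],  # example token IDs for a longer text\n",
"]\n",
"\n",
"max_len = max(len(seq) for seq in batch)\n",
"padded = [seq + [pad_id] * (max_len - len(seq)) for seq in batch]\n",
"print(padded)"
]
},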
{
"cell_type": "markdown",
"id": "c661a397-da06-4a86-ac27-072dbe7cb172",
"metadata": {},
"source": [
"- 让我们来看看将下面的文本进行分词后会发生什么:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d5767eff-440c-4de1-9289-f789349d6b85",
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'Hello'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstr_to_int[s] \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
"\u001b[0;31mKeyError\u001b[0m: 'Hello'"
]
}
],
"source": [
"tokenizer = SimpleTokenizerV1(vocab)\n",
"\n",
"text = \"Hello, do you like tea. Is this-- a test?\"\n",
"\n",
"tokenizer.encode(text)"
]
},
{
"cell_type": "markdown",
"id": "dc53ee0c-fe2b-4cd8-a946-5471f7651acf",
"metadata": {},
"source": [
"- 由于词汇表中不包含 \"Hello \"一词,因此上面的示例产生了错误\n",
"- 为了处理这种情况,我们可以在词汇表中添加特殊词元,如`\"<|unk|>\"`,以表示未知词\n",
"- 既然我们已经在扩展词汇表,那就再添加一个名为`\"<|endoftext|>\"`的词元,它在 GPT-2 训练中用于表示文本的结束(它也用于连接文本,比如我们的训练数据集由多篇文章、书籍等组成)。"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ce9df29c-6c5b-43f1-8c1a-c7f7b79db78f",
"metadata": {},
"outputs": [],
"source": [
"preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', raw_text)\n",
"preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
"\n",
"all_tokens = sorted(list(set(preprocessed)))\n",
"all_tokens.extend([\"<|endoftext|>\", \"<|unk|>\"])\n",
"\n",
"vocab = {token:integer for integer,token in enumerate(all_tokens)}"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "57c3143b-e860-4d3b-a22a-de22b547a6a9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1161"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab.items())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "50e51bb1-ae05-4aa8-a9ff-455b65ed1959",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('younger', 1156)\n",
"('your', 1157)\n",
"('yourself', 1158)\n",
"('<|endoftext|>', 1159)\n",
"('<|unk|>', 1160)\n"
]
}
],
"source": [
"for i, item in enumerate(list(vocab.items())[-5:]):\n",
" print(item)"
]
},
{
"cell_type": "markdown",
"id": "a1daa2b0-6e75-412b-ab53-1f6fb7b4d453",
"metadata": {},
"source": [
"- 我们还需要相应调整分词器,以便它知道何时以及如何使用新的 `<unk>` 词元"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "948861c5-3f30-4712-a234-725f20d26f68",
"metadata": {},
"outputs": [],
"source": [
"class SimpleTokenizerV2:\n",
" def __init__(self, vocab):\n",
" self.str_to_int = vocab\n",
" self.int_to_str = { i:s for s,i in vocab.items()}\n",
" \n",
" def encode(self, text):\n",
" preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', text)\n",
" preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
" preprocessed = [item if item in self.str_to_int \n",
" else \"<|unk|>\" for item in preprocessed]\n",
"\n",
" ids = [self.str_to_int[s] for s in preprocessed]\n",
" return ids\n",
" \n",
" def decode(self, ids):\n",
" text = \" \".join([self.int_to_str[i] for i in ids])\n",
" # 替换指定标点符号前的空格\n",
" text = re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text)\n",
" return text"
]
},
{
"cell_type": "markdown",
"id": "aa728dd1-9d35-4ac7-938f-d411d73083f6",
"metadata": {},
"source": [
"让我们尝试使用修改后的分词器对文本进行分词:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "effcef79-e0a5-4f4a-a43a-31dd94b9250a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.\n"
]
}
],
"source": [
"tokenizer = SimpleTokenizerV2(vocab)\n",
"\n",
"text1 = \"Hello, do you like tea?\"\n",
"text2 = \"In the sunlit terraces of the palace.\"\n",
"\n",
"text = \" <|endoftext|> \".join((text1, text2))\n",
"\n",
"print(text)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ddfe7346-398d-4bf8-99f1-5b071244ce95",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1160,\n",
" 5,\n",
" 362,\n",
" 1155,\n",
" 642,\n",
" 1000,\n",
" 10,\n",
" 1159,\n",
" 57,\n",
" 1013,\n",
" 981,\n",
" 1009,\n",
" 738,\n",
" 1013,\n",
" 1160,\n",
" 7]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode(text)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0c350ff6-2734-4e84-9ec7-d578baa4ae1b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(tokenizer.encode(text))"
]
},
{
"cell_type": "markdown",
"id": "5c4ba34b-170f-4e71-939b-77aabb776f14",
"metadata": {},
"source": [
"## 2.5 BytePair encoding字节对编码"
]
},
{
"cell_type": "markdown",
"id": "2309494c-79cf-4a2d-bc28-a94d602f050e",
"metadata": {},
"source": [
"- GPT-2 使用字节对编码BPE作为分词器\n",
"- 它允许模型将不在其预定义词汇表中的单词分解为更小的子单词单元甚至单个字符,从而使其能够处理词汇表之外的单词\n",
"- 例如,如果 GPT-2 的词汇表中没有 \"unfamiliarword \"这个词,它可能会根据训练有素的 BPE 合并结果,将其标记为[\"unfam\"、\"iliar\"、\"word\"]或其他子词\n",
"- 原始 BPE 分词器可在此处找到:[https://github.com/openai/gpt-2/blob/master/src/encoder.py](https://github.com/openai/gpt-2/blob/master/src/encoder.py)\n",
"- 在本章中,我们使用 OpenAI 的开源 [tiktoken](https://github.com/openai/tiktoken) 库中的 BPE 标记符号生成器,该库用 Rust 实现了其核心算法,以提高计算性能\n",
"- 我在 [./bytepair_encoder](./bytepair_encoder)中创建了一个笔记本并列比较了这两种实现方法tiktoken 在样本文本上的速度大约快 5 倍)。"
]
},
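{
"cell_type": "markdown",
"id": "bpe-toy-sketch-md-0001",
"metadata": {},
"source": [
"- The following cell is a toy sketch of the core merge idea behind BPE, run on a made-up mini-corpus; it is illustrative only, since the real GPT-2 tokenizer operates on bytes and applies a fixed, pre-trained merge table instead of learning merges on the fly"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bpe-toy-sketch-code-0001",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def most_frequent_pair(symbols):\n",
"    # Count adjacent symbol pairs in the current symbol sequence\n",
"    pairs = Counter(zip(symbols, symbols[1:]))\n",
"    return max(pairs, key=pairs.get) if pairs else None\n",
"\n",
"\n",
"def merge_pair(symbols, pair):\n",
"    # Replace every occurrence of `pair` with the merged symbol\n",
"    merged, i = [], 0\n",
"    while i < len(symbols):\n",
"        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:\n",
"            merged.append(symbols[i] + symbols[i + 1])\n",
"            i += 2\n",
"        else:\n",
"            merged.append(symbols[i])\n",
"            i += 1\n",
"    return merged\n",
"\n",
"\n",
"symbols = list(\"low lower lowest\")  # start from individual characters\n",
"for _ in range(4):                  # apply a few merges\n",
"    pair = most_frequent_pair(symbols)\n",
"    symbols = merge_pair(symbols, pair)\n",
"    print(pair, \"->\", symbols)"
]
},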
{
"cell_type": "code",
"execution_count": 25,
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
"metadata": {},
"outputs": [],
"source": [
"# pip install tiktoken"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tiktoken version: 0.6.0\n"
]
}
],
"source": [
"import importlib\n",
"import tiktoken\n",
"\n",
"print(\"tiktoken version:\", importlib.metadata.version(\"tiktoken\"))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = tiktoken.get_encoding(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[15496, 11, 466, 345, 588, 8887, 30, 220, 50256, 554, 262, 4252, 18250, 8812, 2114, 286, 617, 34680, 27271, 13]\n"
]
}
],
"source": [
"text = \"Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace.\"\n",
"\n",
"integers = tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n",
"\n",
"print(integers)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace.\n"
]
}
],
"source": [
"strings = tokenizer.decode(integers)\n",
"\n",
"print(strings)"
]
},
{
"cell_type": "markdown",
"id": "f63d62ab-4b80-489c-8041-e4052fe29969",
"metadata": {},
"source": [
"- 未知单词实验:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ce25cf25-a2bb-44d2-bac1-cb566f433f98",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[33901, 86, 343, 86, 220, 959]\n"
]
}
],
"source": [
"integers = tokenizer.encode(\"Akwirw ier\")\n",
"print(integers)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "3e224f96-41d0-4074-ac6e-f7db2490f806",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"33901 -> Ak\n",
"86 -> w\n",
"343 -> ir\n",
"86 -> w\n",
"220 -> \n",
"959 -> ier\n"
]
}
],
"source": [
"for i in integers:\n",
" print(f\"{i} -> {tokenizer.decode([i])}\")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "766bcf29-64bf-47ca-9b65-4ae8e607d580",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Akwirw ier\n"
]
}
],
"source": [
"strings = tokenizer.decode(integers)\n",
"print(strings)"
]
},
{
"cell_type": "markdown",
"id": "abbd7c0d-70f8-4386-a114-907e96c950b0",
"metadata": {},
"source": [
"## 2.6 使用滑动窗口进行数据采样"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5145\n"
]
}
],
"source": [
"with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()\n",
"\n",
"enc_text = tokenizer.encode(raw_text)\n",
"print(len(enc_text))"
]
},
{
"cell_type": "markdown",
"id": "cebd0657-5543-43ca-8011-2ae6bd0a5810",
"metadata": {},
"source": [
"- 对于每个文本块,我们需要输入和目标\n",
"- 由于我们希望模型预测下一个单词,因此目标是向右移动一个位置的输入值"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
"metadata": {},
"outputs": [],
"source": [
"enc_sample = enc_text[50:]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x: [290, 4920, 2241, 287]\n",
"y: [4920, 2241, 287, 257]\n"
]
}
],
"source": [
"context_size = 4\n",
"\n",
"x = enc_sample[:context_size]\n",
"y = enc_sample[1:context_size+1]\n",
"\n",
"print(f\"x: {x}\")\n",
"print(f\"y: {y}\")"
]
},
{
"cell_type": "markdown",
"id": "815014ef-62f7-4476-a6ad-66e20e42b7c3",
"metadata": {},
"source": [
"- 逐一预测的结果如下:"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[290] ----> 4920\n",
"[290, 4920] ----> 2241\n",
"[290, 4920, 2241] ----> 287\n",
"[290, 4920, 2241, 287] ----> 257\n"
]
}
],
"source": [
"for i in range(1, context_size+1):\n",
" context = enc_sample[:i]\n",
" desired = enc_sample[i]\n",
"\n",
" print(context, \"---->\", desired)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" and ----> established\n",
" and established ----> himself\n",
" and established himself ----> in\n",
" and established himself in ----> a\n"
]
}
],
"source": [
"for i in range(1, context_size+1):\n",
" context = enc_sample[:i]\n",
" desired = enc_sample[i]\n",
"\n",
" print(tokenizer.decode(context), \"---->\", tokenizer.decode([desired]))"
]
},
{
"cell_type": "markdown",
"id": "210d2dd9-fc20-4927-8d3d-1466cf41aae1",
"metadata": {},
"source": [
"- 我们将在下一章介绍注意力机制后再处理下一个单词的预测\n",
"- 现在,我们实现了一个简单的数据加载器,它可以遍历输入数据集,并返回输入和目标值相移一的结果"
]
},
{
"cell_type": "markdown",
"id": "a1a1b47a-f646-49d1-bc70-fddf2c840796",
"metadata": {},
"source": [
"- 安装并导入 PyTorch安装提示见附录 A"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.2.1+cu121\n"
]
}
],
"source": [
"import torch\n",
"print(\"PyTorch version:\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"id": "92ac652d-7b38-4843-9fbd-494cdc8ec12c",
"metadata": {},
"source": [
"- 创建数据集和数据加载器,从输入的文本数据集中提取数据块"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"\n",
"class GPTDatasetV1(Dataset):\n",
" def __init__(self, txt, tokenizer, max_length, stride):\n",
" self.tokenizer = tokenizer\n",
" self.input_ids = []\n",
" self.target_ids = []\n",
"\n",
" # 对全部文本进行分词\n",
" token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
"\n",
" # 使用滑动窗口将图书分块为最大长度的重叠序列\n",
" for i in range(0, len(token_ids) - max_length, stride):\n",
" input_chunk = token_ids[i:i + max_length]\n",
" target_chunk = token_ids[i + 1: i + max_length + 1]\n",
" self.input_ids.append(torch.tensor(input_chunk))\n",
" self.target_ids.append(torch.tensor(target_chunk))\n",
"\n",
" def __len__(self):\n",
" return len(self.input_ids)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.input_ids[idx], self.target_ids[idx]"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):\n",
"\n",
" # 分词器初始化\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
" # 创建数据集\n",
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
"\n",
" # 创建加载器\n",
" dataloader = DataLoader(\n",
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
"\n",
" return dataloader"
]
},
{
"cell_type": "markdown",
"id": "42dd68ef-59f7-45ff-ba44-e311c899ddcd",
"metadata": {},
"source": [
"- 让我们用批量大小为 1 的数据加载器来测试上下文大小为 4 的 LLM"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
"metadata": {},
"outputs": [],
"source": [
"with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tensor([[ 40, 367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]\n"
]
}
],
"source": [
"dataloader = create_dataloader_v1(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
"\n",
"data_iter = iter(dataloader)\n",
"first_batch = next(data_iter)\n",
"print(first_batch)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tensor([[ 367, 2885, 1464, 1807]]), tensor([[2885, 1464, 1807, 3619]])]\n"
]
}
],
"source": [
"second_batch = next(data_iter)\n",
"print(second_batch)"
]
},
{
"cell_type": "markdown",
"id": "b1ae6d45-f26e-4b83-9c7b-cff55ffa7d16",
"metadata": {},
"source": [
"- 我们也可以创造批量的输出\n",
"- 请注意,我们在这里增加了步长,以避免批次之间的重叠,因为更多的重叠可能会导致过拟合增加。"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inputs:\n",
" tensor([[ 40, 367, 2885, 1464],\n",
" [ 1807, 3619, 402, 271],\n",
" [10899, 2138, 257, 7026],\n",
" [15632, 438, 2016, 257],\n",
" [ 922, 5891, 1576, 438],\n",
" [ 568, 340, 373, 645],\n",
" [ 1049, 5975, 284, 502],\n",
" [ 284, 3285, 326, 11]])\n",
"\n",
"Targets:\n",
" tensor([[ 367, 2885, 1464, 1807],\n",
" [ 3619, 402, 271, 10899],\n",
" [ 2138, 257, 7026, 15632],\n",
" [ 438, 2016, 257, 922],\n",
" [ 5891, 1576, 438, 568],\n",
" [ 340, 373, 645, 1049],\n",
" [ 5975, 284, 502, 284],\n",
" [ 3285, 326, 11, 287]])\n"
]
}
],
"source": [
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=4, stride=4, shuffle=False)\n",
"\n",
"data_iter = iter(dataloader)\n",
"inputs, targets = next(data_iter)\n",
"print(\"Inputs:\\n\", inputs)\n",
"print(\"\\nTargets:\\n\", targets)"
]
},
{
"cell_type": "markdown",
"id": "2cd2fcda-2fda-4aa8-8bc8-de1e496f9db1",
"metadata": {},
"source": [
"## 2.7 创建词元嵌入"
]
},
{
"cell_type": "markdown",
"id": "1a301068-6ab2-44ff-a915-1ba11688274f",
"metadata": {},
"source": [
" - 数据已经几乎准备好用于大型语言模型LLM了\n",
"- 但最后,让我们使用嵌入层将这些词元嵌入到连续的向量表示中\n",
"- 通常,这些嵌入层是大型语言模型本身的一部分,在模型训练期间会进行更新(训练)"
]
},
{
"cell_type": "markdown",
"id": "44e014ca-1fc5-4b90-b6fa-c2097bb92c0b",
"metadata": {},
"source": [
" - 假设我们有以下三个输入示例输入ID分别为5、1、3和2在进行分词之后"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
"metadata": {},
"outputs": [],
"source": [
"input_ids = torch.tensor([5, 1, 3, 2])"
]
},
{
"cell_type": "markdown",
"id": "14da6344-2c71-4837-858d-dd120005ba05",
"metadata": {},
"source": [
" - 为了简化问题假设我们有一个很小的词汇表只包含6个单词我们想要创建大小为3的嵌入向量。"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 6\n",
"output_dim = 3\n",
"\n",
"torch.manual_seed(123)\n",
"embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
]
},
{
"cell_type": "markdown",
"id": "4ff241f6-78eb-4e4a-a55f-5b2b6196d5b0",
"metadata": {},
"source": [
" - 这将导致一个6行3列的权重矩阵"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "a686eb61-e737-4351-8f1c-222913d47468",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[ 0.3374, -0.1778, -0.1690],\n",
" [ 0.9178, 1.5810, 1.3010],\n",
" [ 1.2753, -0.2010, -0.1606],\n",
" [-0.4015, 0.9666, -1.1481],\n",
" [-1.1589, 0.3255, -0.6315],\n",
" [-2.8400, -0.7849, -1.4096]], requires_grad=True)\n"
]
}
],
"source": [
"print(embedding_layer.weight)"
]
},
{
"cell_type": "markdown",
"id": "26fcf4f5-0801-4eb4-bb90-acce87935ac7",
"metadata": {},
"source": [
" - 对于那些熟悉独热编码one-hot encoding的人来说上述的嵌入层方法本质上只是一种更高效的实现方式它相当于在全连接层中先进行独热编码然后进行矩阵乘法这在补充代码[./embedding_vs_matmul](./embedding_vs_matmul)中有描述。\n",
"- 因为嵌入层只是一种更高效的实现方式,它等同于独热编码和矩阵乘法的方法,所以它可以被视为一个可以通过反向传播进行优化的神经网络层。"
]
},
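{
"cell_type": "markdown",
"id": "embedding-equivalence-md-0001",
"metadata": {},
"source": [
"- Below is a minimal sketch of that equivalence, assuming the `vocab_size`, `input_ids`, and `embedding_layer` defined in the previous cells: looking up the embeddings should give the same result as one-hot encoding the IDs and multiplying with the same weight matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "embedding-equivalence-code-0001",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: embedding lookup == one-hot encoding followed by a matmul\n",
"onehot = torch.nn.functional.one_hot(input_ids, num_classes=vocab_size).float()\n",
"\n",
"via_matmul = onehot @ embedding_layer.weight\n",
"via_lookup = embedding_layer(input_ids)\n",
"\n",
"print(torch.allclose(via_matmul, via_lookup))  # expected: True"
]
},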
{
"cell_type": "markdown",
"id": "4b0d58c3-83c0-4205-aca2-9c48b19fd4a7",
"metadata": {},
"source": [
" - 要将ID为3的标记转换为三维向量我们执行以下操作"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-0.4015, 0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)\n"
]
}
],
"source": [
"print(embedding_layer(torch.tensor([3])))"
]
},
{
"cell_type": "markdown",
"id": "a7bbf625-4f36-491d-87b4-3969efb784b0",
"metadata": {},
"source": [
" - 请注意,上述内容是`embedding_layer`权重矩阵的第四行。\n",
"- 要嵌入上述所有三个`input_ids`的值,我们执行以下操作:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-2.8400, -0.7849, -1.4096],\n",
" [ 0.9178, 1.5810, 1.3010],\n",
" [-0.4015, 0.9666, -1.1481],\n",
" [ 1.2753, -0.2010, -0.1606]], grad_fn=<EmbeddingBackward0>)\n"
]
}
],
"source": [
"print(embedding_layer(input_ids))"
]
},
{
"cell_type": "markdown",
"id": "c393d270-b950-4bc8-99ea-97d74f2ea0f6",
"metadata": {},
"source": [
"## 2.8 单词位置编码"
]
},
{
"cell_type": "markdown",
"id": "7f187f87-c1f8-4c2e-8050-350bbb972f55",
"metadata": {},
"source": [
"- BytePair编码器有一个词汇量大小为50,257的词典\n",
"- 假设我们想要将输入的词元编码成256维的向量表示"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 50257\n",
"output_dim = 256\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
]
},
{
"cell_type": "markdown",
"id": "a2654722-24e4-4b0d-a43c-436a461eb70b",
"metadata": {},
"source": [
" - 如果我们从数据加载器中采样数据我们将每个批次中的词元嵌入到一个256维的向量中。\n",
"- 如果我们的批次大小为8每个批次有4个token这将导致一个8 x 4 x 256的张量。"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "ad56a263-3d2e-4d91-98bf-d0b68d3c7fc3",
"metadata": {},
"outputs": [],
"source": [
"max_length = 4\n",
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
"data_iter = iter(dataloader)\n",
"inputs, targets = next(data_iter)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "84416b60-3707-4370-bcbc-da0b62f2b64d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token IDs:\n",
" tensor([[ 40, 367, 2885, 1464],\n",
" [ 3619, 402, 271, 10899],\n",
" [ 257, 7026, 15632, 438],\n",
" [ 257, 922, 5891, 1576],\n",
" [ 568, 340, 373, 645],\n",
" [ 5975, 284, 502, 284],\n",
" [ 326, 11, 287, 262],\n",
" [ 286, 465, 13476, 11]])\n",
"\n",
"Inputs shape:\n",
" torch.Size([8, 4])\n"
]
}
],
"source": [
"print(\"Token IDs:\\n\", inputs)\n",
"print(\"\\nInputs shape:\\n\", inputs.shape)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "7766ec38-30d0-4128-8c31-f49f063c43d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"token_embeddings = token_embedding_layer(inputs)\n",
"print(token_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"id": "fe2ae164-6f19-4e32-b9e5-76950fcf1c9f",
"metadata": {},
"source": [
" - GPT-2 使用绝对位置嵌入,所以我们只需创建另一个嵌入层:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
"metadata": {},
"outputs": [],
"source": [
"block_size = max_length\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "c369a1e7-d566-4b53-b398-d6adafb44105",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 256])\n"
]
}
],
"source": [
"pos_embeddings = pos_embedding_layer(torch.arange(max_length))\n",
"print(pos_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"id": "870e9d9f-2935-461a-9518-6d1386b976d6",
"metadata": {},
"source": [
" - 为了创建在大型语言模型LLM中使用的输入嵌入我们只需将词元嵌入和位置嵌入相加"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "b22fab89-526e-43c8-9035-5b7018e34288",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"input_embeddings = token_embeddings + pos_embeddings\n",
"print(input_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"id": "63230f2e-258f-4497-9e2e-8deee4530364",
"metadata": {},
"source": [
"# 总结和要点"
]
},
{
"cell_type": "markdown",
"id": "8b3293a6-45a5-47cd-aa00-b23e3ca0a73f",
"metadata": {},
"source": [
"**请查看[./dataloader.ipynb](./dataloader.ipynb)代码笔记本这是我们在本章实现的数据加载器的简洁版本我们将在接下来的章节中训练GPT模型时需要用到它。**\n",
"\n",
"**请查看[./exercise-solutions.ipynb](./exercise-solutions.ipynb)以获取练习题的解答。**"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}