mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-05-03 13:02:35 +00:00
@@ -1,174 +0,0 @@
|
|||||||
"""
|
|
||||||
Byte pair encoding utilities
|
|
||||||
|
|
||||||
Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
|
||||||
|
|
||||||
And modified code (download_vocab) from
|
|
||||||
https://github.com/openai/gpt-2/blob/master/download_model.py
|
|
||||||
|
|
||||||
Modified MIT License
|
|
||||||
|
|
||||||
Software Copyright (c) 2019 OpenAI
|
|
||||||
|
|
||||||
We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
|
|
||||||
We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
|
||||||
associated documentation files (the "Software"), to deal in the Software without restriction,
|
|
||||||
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
||||||
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
|
|
||||||
subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included
|
|
||||||
in all copies or substantial portions of the Software.
|
|
||||||
The above copyright notice and this permission notice need not be included
|
|
||||||
with content created by the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
|
||||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
|
||||||
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
||||||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
|
||||||
OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import regex as re
|
|
||||||
import requests
|
|
||||||
from tqdm import tqdm
|
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def bytes_to_unicode():
|
|
||||||
"""
|
|
||||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
|
||||||
The reversible bpe codes work on unicode strings.
|
|
||||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
|
||||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
|
||||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
|
||||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
|
||||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
|
||||||
"""
|
|
||||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
|
||||||
cs = bs[:]
|
|
||||||
n = 0
|
|
||||||
for b in range(2**8):
|
|
||||||
if b not in bs:
|
|
||||||
bs.append(b)
|
|
||||||
cs.append(2**8+n)
|
|
||||||
n += 1
|
|
||||||
cs = [chr(n) for n in cs]
|
|
||||||
return dict(zip(bs, cs))
|
|
||||||
|
|
||||||
def get_pairs(word):
|
|
||||||
"""Return set of symbol pairs in a word.
|
|
||||||
|
|
||||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
|
||||||
"""
|
|
||||||
pairs = set()
|
|
||||||
prev_char = word[0]
|
|
||||||
for char in word[1:]:
|
|
||||||
pairs.add((prev_char, char))
|
|
||||||
prev_char = char
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
class Encoder:
|
|
||||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
|
||||||
self.encoder = encoder
|
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
|
||||||
self.errors = errors # how to handle errors in decoding
|
|
||||||
self.byte_encoder = bytes_to_unicode()
|
|
||||||
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
|
||||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
|
||||||
self.cache = {}
|
|
||||||
|
|
||||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
|
||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
|
||||||
|
|
||||||
def bpe(self, token):
|
|
||||||
if token in self.cache:
|
|
||||||
return self.cache[token]
|
|
||||||
word = tuple(token)
|
|
||||||
pairs = get_pairs(word)
|
|
||||||
|
|
||||||
if not pairs:
|
|
||||||
return token
|
|
||||||
|
|
||||||
while True:
|
|
||||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
|
||||||
if bigram not in self.bpe_ranks:
|
|
||||||
break
|
|
||||||
first, second = bigram
|
|
||||||
new_word = []
|
|
||||||
i = 0
|
|
||||||
while i < len(word):
|
|
||||||
try:
|
|
||||||
j = word.index(first, i)
|
|
||||||
new_word.extend(word[i:j])
|
|
||||||
i = j
|
|
||||||
except:
|
|
||||||
new_word.extend(word[i:])
|
|
||||||
break
|
|
||||||
|
|
||||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
|
||||||
new_word.append(first+second)
|
|
||||||
i += 2
|
|
||||||
else:
|
|
||||||
new_word.append(word[i])
|
|
||||||
i += 1
|
|
||||||
new_word = tuple(new_word)
|
|
||||||
word = new_word
|
|
||||||
if len(word) == 1:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
pairs = get_pairs(word)
|
|
||||||
word = ' '.join(word)
|
|
||||||
self.cache[token] = word
|
|
||||||
return word
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
bpe_tokens = []
|
|
||||||
for token in re.findall(self.pat, text):
|
|
||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
|
||||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
|
||||||
return bpe_tokens
|
|
||||||
|
|
||||||
def decode(self, tokens):
|
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
|
||||||
return text
|
|
||||||
|
|
||||||
def get_encoder(model_name, models_dir):
|
|
||||||
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
|
||||||
encoder = json.load(f)
|
|
||||||
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
|
||||||
bpe_data = f.read()
|
|
||||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
|
||||||
return Encoder(
|
|
||||||
encoder=encoder,
|
|
||||||
bpe_merges=bpe_merges,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_vocab():
|
|
||||||
# Modified code from
|
|
||||||
subdir = 'gpt2_model'
|
|
||||||
if not os.path.exists(subdir):
|
|
||||||
os.makedirs(subdir)
|
|
||||||
subdir = subdir.replace('\\','/') # needed for Windows
|
|
||||||
|
|
||||||
for filename in ['encoder.json', 'vocab.bpe']:
|
|
||||||
|
|
||||||
r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/models/117M" + "/" + filename, stream=True)
|
|
||||||
|
|
||||||
with open(os.path.join(subdir, filename), 'wb') as f:
|
|
||||||
file_size = int(r.headers["content-length"])
|
|
||||||
chunk_size = 1000
|
|
||||||
with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
|
|
||||||
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
|
|
||||||
for chunk in r.iter_content(chunk_size=chunk_size):
|
|
||||||
f.write(chunk)
|
|
||||||
pbar.update(chunk_size)
|
|
||||||
-442
@@ -1,442 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "a9adc3bf-353c-411e-a471-0e92786e7103",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Using BytePair encodding from `tiktoken`"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "4036ffa3-0e5c-433a-a997-4ed7d33de0b2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# !pip install tiktoken"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "1c490fca-a48a-47fa-a299-322d1a08ad17",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"tiktoken version: 0.5.2\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import importlib.metadata\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"tiktoken version:\", importlib.metadata.version(\"tiktoken\"))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "0952667c-ce84-4f21-87db-59f52b44cec4",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import tiktoken\n",
|
|
||||||
"\n",
|
|
||||||
"tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
|
||||||
"\n",
|
|
||||||
"text = \"Hello, world. Is this-- a test?\""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "b039c350-18ad-48fb-8e6a-085702dfc330",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"integers = tik_tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n",
|
|
||||||
"\n",
|
|
||||||
"print(integers)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "7b152ba4-04d3-41cc-849f-adedcfb8cabb",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Hello, world. Is this-- a test?\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"strings = tik_tokenizer.decode(integers)\n",
|
|
||||||
"\n",
|
|
||||||
"print(strings)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "cf148a1a-316b-43ec-b7ba-1b6d409ce837",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"50257\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(tik_tokenizer.n_vocab)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "6a0b5d4f-2af9-40de-828c-063c4243e771",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Using the original Byte-pair encoding implementation used in GPT-2"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "0903108c-65cb-4ae1-967a-2155e25349c2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from bpe_openai_gpt2 import get_encoder, download_vocab"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Fetching encoder.json: 1.04Mit [00:28, 36.8kit/s] \n",
|
|
||||||
"Fetching vocab.bpe: 457kit [00:00, 458kit/s] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"download_vocab()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 14,
|
|
||||||
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"orig_tokenizer = get_encoder(model_name=\"gpt2_model\", models_dir=\".\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"integers = orig_tokenizer.encode(text)\n",
|
|
||||||
"\n",
|
|
||||||
"print(integers)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 16,
|
|
||||||
"id": "434d115e-990d-42ad-88dd-31323a96e10f",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Hello, world. Is this-- a test?\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"strings = orig_tokenizer.decode(integers)\n",
|
|
||||||
"\n",
|
|
||||||
"print(strings)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "4f63e8c6-707c-4d66-bcf8-dd790647cc86",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Using the BytePair Tokenizer in HuggingFace transformers"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# pip install transformers"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"'4.30.2'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 13,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import transformers\n",
|
|
||||||
"\n",
|
|
||||||
"transformers.__version__"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "16e06ee5-c4ca-4211-8e24-dbfd84b1d85b",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"设置为国内可访问"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "3e07ddc9-187e-4482-a7b5-7e4e9381d805",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"env: HF_ENDPOINT=https://hf-mirror.com\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"%env HF_ENDPOINT=https://hf-mirror.com"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "a9839137-b8ea-4a2c-85fc-9a63064cf8c8",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "afc151b540664287aa60a4cbe90cdfeb",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"vocab.json: 0.00B [00:00, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "9a5d584e4adf40bca215b409b693dc02",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"merges.txt: 0.00B [00:00, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "a126ee77a9f94e58b1dcccd68e6d5bb1",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"config.json: 0%| | 0.00/367 [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from transformers import GPT2Tokenizer\n",
|
|
||||||
"\n",
|
|
||||||
"hf_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"id": "222cbd69-6a3d-4868-9c1f-421ffc9d5fe1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"hf_tokenizer(strings)[\"input_ids\"]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "907a1ade-3401-4f2e-9017-7f58a60cbd98",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# A quick performance benchmark"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 17,
|
|
||||||
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with open('../01_main-chapter-code/the-verdict.txt', 'r', encoding='utf-8') as f:\n",
|
|
||||||
" raw_text = f.read()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 18,
|
|
||||||
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"9.14 ms ± 74.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"%timeit orig_tokenizer.encode(raw_text)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"%timeit tik_tokenizer.encode(raw_text)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"%timeit hf_tokenizer(raw_text)[\"input_ids\"]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "d81eaf6d-554b-44e3-aa19-2c3ae0030762",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
# Chapter 2: Working with Text Data
|
# 第2章:使用文本数据
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- [compare-bpe-tiktoken.ipynb](compare-bpe-tiktoken.ipynb) benchmarks various byte pair encoding implementations
|
- [compare-bpe-tiktoken.ipynb](compare-bpe-tiktoken.ipynb) 对各种字节对编码实现进行基准测试
|
||||||
- [bpe_openai_gpt2.py](bpe_openai_gpt2.py) is the original bytepair encoder code used by OpenAI
|
- [bpe_openai_gpt2.py](bpe_openai_gpt2.py) 是OpenAI使用的原始字节对编码器代码
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
@@ -41,6 +41,7 @@ import requests
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
# # 定义一个函数将字节转换为Unicode字符
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def bytes_to_unicode():
|
def bytes_to_unicode():
|
||||||
"""
|
"""
|
||||||
@@ -52,6 +53,15 @@ def bytes_to_unicode():
|
|||||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
"""
|
"""
|
||||||
|
'''
|
||||||
|
返回一组UTF-8字节和相应的Unicode字符串列表。
|
||||||
|
可逆的BPE编码适用于Unicode字符串。
|
||||||
|
这意味着如果想要避免UNK(未知标记),则需要在词汇表中包含大量的Unicode字符。
|
||||||
|
当处理大约100亿标记的数据集时,最终需要大约5000个字符以确保良好的覆盖率。
|
||||||
|
这相当于正常情况下使用的32,000个BPE词汇表的显著比例。
|
||||||
|
为了避免这种情况,我们希望在UTF-8字节和Unicode字符串之间建立查找表。
|
||||||
|
并且要避免将BPE代码映射到空格/控制字符上,以免出现问题。
|
||||||
|
'''
|
||||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||||
cs = bs[:]
|
cs = bs[:]
|
||||||
n = 0
|
n = 0
|
||||||
@@ -63,11 +73,16 @@ def bytes_to_unicode():
|
|||||||
cs = [chr(n) for n in cs]
|
cs = [chr(n) for n in cs]
|
||||||
return dict(zip(bs, cs))
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
# 定义一个函数获取单词中的符号对
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""Return set of symbol pairs in a word.
|
"""Return set of symbol pairs in a word.
|
||||||
|
|
||||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
"""
|
"""
|
||||||
|
'''
|
||||||
|
返回单词中的符号对集合。
|
||||||
|
单词以符号元组的形式表示(其中符号是可变长度的字符串)。
|
||||||
|
'''
|
||||||
pairs = set()
|
pairs = set()
|
||||||
prev_char = word[0]
|
prev_char = word[0]
|
||||||
for char in word[1:]:
|
for char in word[1:]:
|
||||||
@@ -75,8 +90,10 @@ def get_pairs(word):
|
|||||||
prev_char = char
|
prev_char = char
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
# 定义一个使用字节对编码(BPE)进行编码和解码的Encoder类
|
||||||
class Encoder:
|
class Encoder:
|
||||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
def __init__(self, encoder, bpe_merges, errors='replace'):
|
||||||
|
# # 使用编码器字典、BPE合并和错误处理策略初始化Encoder
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
self.errors = errors # how to handle errors in decoding
|
self.errors = errors # how to handle errors in decoding
|
||||||
@@ -89,6 +106,7 @@ class Encoder:
|
|||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
|
# 对给定的标记执行字节对编码
|
||||||
if token in self.cache:
|
if token in self.cache:
|
||||||
return self.cache[token]
|
return self.cache[token]
|
||||||
word = tuple(token)
|
word = tuple(token)
|
||||||
@@ -130,6 +148,7 @@ class Encoder:
|
|||||||
return word
|
return word
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
|
# 使用BPE对给定文本进行编码
|
||||||
bpe_tokens = []
|
bpe_tokens = []
|
||||||
for token in re.findall(self.pat, text):
|
for token in re.findall(self.pat, text):
|
||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
@@ -137,10 +156,11 @@ class Encoder:
|
|||||||
return bpe_tokens
|
return bpe_tokens
|
||||||
|
|
||||||
def decode(self, tokens):
|
def decode(self, tokens):
|
||||||
|
# 将一系列标记解码回文本
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
return text
|
return text
|
||||||
|
# 定义一个函数获取特定模型的编码器
|
||||||
def get_encoder(model_name, models_dir):
|
def get_encoder(model_name, models_dir):
|
||||||
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
||||||
encoder = json.load(f)
|
encoder = json.load(f)
|
||||||
@@ -152,7 +172,7 @@ def get_encoder(model_name, models_dir):
|
|||||||
bpe_merges=bpe_merges,
|
bpe_merges=bpe_merges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 定义一个函数下载GPT-2模型的词汇文件
|
||||||
def download_vocab():
|
def download_vocab():
|
||||||
# Modified code from
|
# Modified code from
|
||||||
subdir = 'gpt2_model'
|
subdir = 'gpt2_model'
|
||||||
|
|||||||
@@ -5,7 +5,8 @@
|
|||||||
"id": "a9adc3bf-353c-411e-a471-0e92786e7103",
|
"id": "a9adc3bf-353c-411e-a471-0e92786e7103",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Using BytePair encodding from `tiktoken`"
|
"# 使用来自 `tiktoken` 的字节对编码\n",
|
||||||
|
"tiktoken是一个用于OpenAI模型的快速BPE标记器。(BPE标记器是一种基于字节对编码(Byte Pair Encoding,简称BPE)的文本标记方法。字节对编码是一种数据压缩技术,但在自然语言处理中,它也被用于创建词汇表和对文本进行分词。)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -20,7 +21,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 1,
|
||||||
"id": "1c490fca-a48a-47fa-a299-322d1a08ad17",
|
"id": "1c490fca-a48a-47fa-a299-322d1a08ad17",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -28,33 +29,33 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"tiktoken version: 0.5.2\n"
|
"tiktoken version: 0.6.0\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import importlib.metadata\n",
|
"import importlib.metadata\n",
|
||||||
"\n",
|
"# 打印出当前系统中安装的 tiktoken 库的版本号\n",
|
||||||
"print(\"tiktoken version:\", importlib.metadata.version(\"tiktoken\"))"
|
"print(\"tiktoken version:\", importlib.metadata.version(\"tiktoken\"))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 2,
|
||||||
"id": "0952667c-ce84-4f21-87db-59f52b44cec4",
|
"id": "0952667c-ce84-4f21-87db-59f52b44cec4",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import tiktoken\n",
|
"import tiktoken\n",
|
||||||
"\n",
|
"# 创建一个使用 GPT-2 模型的编码器对象\n",
|
||||||
"tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
"tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||||
"\n",
|
"# ,定义一个包含文本的字符串变量,使用 tik_tokenizer 对象对文本进行编码\n",
|
||||||
"text = \"Hello, world. Is this-- a test?\""
|
"text = \"Hello, world. Is this-- a test?\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 3,
|
||||||
"id": "b039c350-18ad-48fb-8e6a-085702dfc330",
|
"id": "b039c350-18ad-48fb-8e6a-085702dfc330",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -67,6 +68,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# 参数 allowed_special,该参数指定哪些特殊字符允许出现在编码结果\n",
|
||||||
"integers = tik_tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n",
|
"integers = tik_tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(integers)"
|
"print(integers)"
|
||||||
@@ -74,7 +76,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 4,
|
||||||
"id": "7b152ba4-04d3-41cc-849f-adedcfb8cabb",
|
"id": "7b152ba4-04d3-41cc-849f-adedcfb8cabb",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -87,6 +89,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# 进行解码\n",
|
||||||
"strings = tik_tokenizer.decode(integers)\n",
|
"strings = tik_tokenizer.decode(integers)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(strings)"
|
"print(strings)"
|
||||||
@@ -94,7 +97,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 5,
|
||||||
"id": "cf148a1a-316b-43ec-b7ba-1b6d409ce837",
|
"id": "cf148a1a-316b-43ec-b7ba-1b6d409ce837",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -107,6 +110,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# 表示编码器的词汇表大小\n",
|
||||||
"print(tik_tokenizer.n_vocab)"
|
"print(tik_tokenizer.n_vocab)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -115,12 +119,12 @@
|
|||||||
"id": "6a0b5d4f-2af9-40de-828c-063c4243e771",
|
"id": "6a0b5d4f-2af9-40de-828c-063c4243e771",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Using the original Byte-pair encoding implementation used in GPT-2"
|
"# 使用在GPT-2中使用的原始字节对编码实现"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 6,
|
||||||
"id": "0903108c-65cb-4ae1-967a-2155e25349c2",
|
"id": "0903108c-65cb-4ae1-967a-2155e25349c2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -130,7 +134,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 7,
|
||||||
"id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45",
|
"id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -138,8 +142,8 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Fetching encoder.json: 1.04Mit [00:28, 36.8kit/s] \n",
|
"Fetching encoder.json: 1.04Mit [00:02, 502kit/s] \n",
|
||||||
"Fetching vocab.bpe: 457kit [00:00, 458kit/s] \n"
|
"Fetching vocab.bpe: 457kit [00:02, 212kit/s] \n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -149,7 +153,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 8,
|
||||||
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
|
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -159,7 +163,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 9,
|
||||||
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
|
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -179,7 +183,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 10,
|
||||||
"id": "434d115e-990d-42ad-88dd-31323a96e10f",
|
"id": "434d115e-990d-42ad-88dd-31323a96e10f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -202,12 +206,14 @@
|
|||||||
"id": "4f63e8c6-707c-4d66-bcf8-dd790647cc86",
|
"id": "4f63e8c6-707c-4d66-bcf8-dd790647cc86",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Using the BytePair Tokenizer in HuggingFace transformers"
|
"# 使用HuggingFace Transformers中的BytePair Tokenizer\n",
|
||||||
|
"\r\n",
|
||||||
|
"\r\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 11,
|
||||||
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
|
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -217,17 +223,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 12,
|
||||||
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'4.30.2'"
|
"'4.33.3'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 13,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -240,47 +246,19 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 13,
|
||||||
"id": "16e06ee5-c4ca-4211-8e24-dbfd84b1d85b",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"设置为国内可访问"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "3e07ddc9-187e-4482-a7b5-7e4e9381d805",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"env: HF_ENDPOINT=https://hf-mirror.com\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"%env HF_ENDPOINT=https://hf-mirror.com"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "a9839137-b8ea-4a2c-85fc-9a63064cf8c8",
|
"id": "a9839137-b8ea-4a2c-85fc-9a63064cf8c8",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "afc151b540664287aa60a4cbe90cdfeb",
|
"model_id": "ef488b57e4214b76a8913a4704de7e15",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"vocab.json: 0.00B [00:00, ?B/s]"
|
"vocab.json: 0%| | 0.00/1.04M [00:00<?, ?B/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -289,12 +267,12 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "9a5d584e4adf40bca215b409b693dc02",
|
"model_id": "9ab86eb5125640dba6d59a5744f2d927",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"merges.txt: 0.00B [00:00, ?B/s]"
|
"merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -303,12 +281,26 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "a126ee77a9f94e58b1dcccd68e6d5bb1",
|
"model_id": "073f03eb3ef541e092c8f344f65c34da",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"config.json: 0%| | 0.00/367 [00:00<?, ?B/s]"
|
"tokenizer_config.json: 0%| | 0.00/26.0 [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "b92abbedc99a4bb9ad8bf5f3b3e8b140",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"config.json: 0%| | 0.00/665 [00:00<?, ?B/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -317,13 +309,13 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from transformers import GPT2Tokenizer\n",
|
"from transformers import GPT2Tokenizer\n",
|
||||||
"\n",
|
"# 使用 HuggingFace Transformers 提供的 GPT2Tokenizer 类,创建一个预训练的 GPT-2 模型的标记器对象\n",
|
||||||
"hf_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")"
|
"hf_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 14,
|
||||||
"id": "222cbd69-6a3d-4868-9c1f-421ffc9d5fe1",
|
"id": "222cbd69-6a3d-4868-9c1f-421ffc9d5fe1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -333,7 +325,7 @@
|
|||||||
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
|
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 11,
|
"execution_count": 14,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -347,12 +339,12 @@
|
|||||||
"id": "907a1ade-3401-4f2e-9017-7f58a60cbd98",
|
"id": "907a1ade-3401-4f2e-9017-7f58a60cbd98",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# A quick performance benchmark"
|
"# 快速测试"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 15,
|
||||||
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -363,7 +355,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 16,
|
||||||
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -371,17 +363,18 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"9.14 ms ± 74.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"14.6 ms ± 201 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# 测量其运行时间,从而进行性能评估\n",
|
||||||
"%timeit orig_tokenizer.encode(raw_text)"
|
"%timeit orig_tokenizer.encode(raw_text)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 17,
|
||||||
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -389,7 +382,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"3.28 ms ± 2.66 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"2.9 ms ± 42.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -399,7 +392,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 18,
|
||||||
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -414,7 +407,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"19.1 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
"28.6 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -424,15 +417,15 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 20,
|
||||||
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
"id": "67159770-e131-411a-b9fe-037b4d931c9d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"18.8 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
"28.3 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -443,7 +436,15 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "d81eaf6d-554b-44e3-aa19-2c3ae0030762",
|
"id": "e1cf5928-b6c8-4493-9e51-cb2795b2482c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5f70f164-5f73-479e-bfaf-914243016439",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
@@ -465,7 +466,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.8.17"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
-486
@@ -1,486 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "063850ab-22b0-4838-b53a-9bb11757d9d0",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Embedding Layers and Linear Layers"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "0315c598-701f-46ff-8806-15813cad0e51",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Embedding layers in PyTorch accomplish the same as linear layers that perform matrix multiplications; the reason we use embedding layers is computational efficiency\n",
|
|
||||||
"- We will take a look at this relationship step by step using code examples in PyTorch"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "061720f4-f025-4640-82a0-15098fa94cf9",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"PyTorch version: 2.1.0.post301\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"PyTorch version:\", torch.__version__)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "a7895a66-7f69-4f62-9361-5c9da2eb76ef",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Using nn.Embedding"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Suppose we have the following 3 training examples,\n",
|
|
||||||
"# which may represent token IDs in a LLM context\n",
|
|
||||||
"idx = torch.tensor([2, 3, 1])\n",
|
|
||||||
"\n",
|
|
||||||
"# The number of rows in the embedding matrix can be determined\n",
|
|
||||||
"# by obtaining the largest token ID + 1.\n",
|
|
||||||
"# If the highest token ID is 3, then we want 4 rows, for the possible\n",
|
|
||||||
"# token IDs 0, 1, 2, 3\n",
|
|
||||||
"num_idx = max(idx)+1\n",
|
|
||||||
"\n",
|
|
||||||
"# The desired embedding dimension is a hyperparameter\n",
|
|
||||||
"out_dim = 5"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "93d83a6e-8543-40af-b253-fe647640bf36",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Let's implement a simple embedding layer:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# We use the random seed for reproducibility since\n",
|
|
||||||
"# weights in the embedding layer are initialized with\n",
|
|
||||||
"# small random values\n",
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"\n",
|
|
||||||
"embedding = torch.nn.Embedding(num_idx, out_dim)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "dd96c00a-3297-4a50-8bfc-247aaea7e761",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"We can optionally take a look at the embedding weights:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "595f603e-8d2a-4171-8f94-eac8106b2e57",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"Parameter containing:\n",
|
|
||||||
"tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 1.5810],\n",
|
|
||||||
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015],\n",
|
|
||||||
" [ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
|
|
||||||
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953]], requires_grad=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"embedding.weight"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "c86eb562-61e2-4171-ab6e-b410a1fd5c18",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- We can then use the embedding layers to obtain the vector representation of a training example with ID 1:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "8bbc0255-4805-4be9-9f4c-1d0d967ef9d5",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
|
|
||||||
" grad_fn=<EmbeddingBackward0>)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"embedding(torch.tensor([1]))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "6a4d47f2-4691-47b8-9855-2528b6c285c9",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Below is a visualization of what happens under the hood:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "12ffd155-7190-44b1-b6b6-45b11d6fe83b",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"<img src=\"images/1.png\" width=\"400px\">"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "87d1311b-cfb2-4afc-9e25-e4ecf35370d9",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Similarly, we can use embedding layers to obtain the vector representation of a training example with ID 2:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "c309266a-c601-4633-9404-2e10b1cdde8c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315]],\n",
|
|
||||||
" grad_fn=<EmbeddingBackward0>)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"embedding(torch.tensor([2]))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "7ad3b601-f68c-41b1-a28d-b624b94ef383",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"<img src=\"images/2.png\" width=\"400px\">"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "27dd54bd-85ae-4887-9c5e-3139da361cf4",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Now, let's convert all the training examples we have defined previously:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "0191aa4b-f6a8-4b0d-9c36-65e82b81d071",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
|
|
||||||
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
|
|
||||||
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
|
|
||||||
" grad_fn=<EmbeddingBackward0>)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"idx = torch.tensor([2, 3, 1])\n",
|
|
||||||
"embedding(idx)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "146cf8eb-c517-4cd4-aa91-0e818fed7651",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Under the hood, it's still the same look-up concept:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "b392eb43-0bda-4821-b446-b8dcbee8ae00",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"<img src=\"images/3.png\" width=\"450px\">"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Using nn.Linear"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Now, we will demonstrate that the embedding layer above accomplishes exactly the same as `nn.Linear` layer on a one-hot encoded representation in PyTorch\n",
|
|
||||||
"- First, let's convert the token IDs into a one-hot representation:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[0, 0, 1, 0],\n",
|
|
||||||
" [0, 0, 0, 1],\n",
|
|
||||||
" [0, 1, 0, 0]])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 8,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"onehot = torch.nn.functional.one_hot(idx)\n",
|
|
||||||
"onehot"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Next, we initialize a `Linear` layer, which caries out a matrix multiplication $X W^\\top$:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "ae04c1ed-242e-4dd7-b8f7-4b7e4caae383",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"linear = torch.nn.Linear(num_idx, out_dim, bias=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "63efb98e-5cc4-4e8d-9fe6-ef0ad29ae2d7",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Note that the linear layer in PyTorch is also initialized with small random weights; to directly compare it to the `Embedding` layer above, we have to use the same small random weights, which is why we reassign them here:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "a3b90d69-761c-486e-bd19-b38a2988fe62",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "9116482d-f1f9-45e2-9bf3-7ef5e9003898",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Now we can use the linear layer on the one-hot encoded representation of the inputs:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"id": "90d2b0dd-9f1d-4c0f-bb16-1f6ce6b8ac2c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
|
|
||||||
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
|
|
||||||
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]], grad_fn=<MmBackward0>)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"linear(onehot.float())"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "f6204bc8-92e2-4546-9cda-574fe1360fa2",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"As we can see, this is exactly the same as what we got when we used the embedding layer:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"id": "2b057649-3176-4a54-b58c-fd8fbf818c61",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
|
|
||||||
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
|
|
||||||
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
|
|
||||||
" grad_fn=<EmbeddingBackward0>)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"embedding(idx)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "0e447639-8952-460e-8c8f-cf9e23c368c9",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- What happens under the hood is the following computation for the first training example's token ID:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "1830eccf-a707-4753-a24a-9b103f55594a",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"<img src=\"images/4.png\" width=\"450px\">"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- And for the second training example's token ID:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "173f6026-a461-44da-b9c5-f571f8ec8bf3",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"<img src=\"images/5.png\" width=\"450px\">"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"- Since all but one index in each one-hot encoded row are 0 (by design), this matrix multiplication is essentially the same as a look-up of the one-hot elements\n",
|
|
||||||
"- This use of the matrix multiplication on one-hot encodings is equivalent to the embedding layer look-up but can be inefficient if we work with large embedding matrices, because there are a lot of wasteful multiplications by zero"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "5eacc005-86fc-490c-8f6a-dc37d8a0df7c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a1f63c81-1ee3-40a1-9ef2-14ff18fb4f05",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c71959bb-facf-44fd-8edb-b67f7752f034",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
# Chapter 2: Working with Text Data
|
# 第2章:使用文本数据
|
||||||
|
|
||||||
|
- [embeddings-and-linear-layers.ipynb](embeddings-and-linear-layers.ipynb) 包含可选(奖励)代码,以说明应用于独热编码向量的嵌入层和全连接层是等效的。
|
||||||
|
|
||||||
- [embeddings-and-linear-layers.ipynb](embeddings-and-linear-layers.ipynb) contains optional (bonus) code to explain that embedding layers and fully connected layers applied to one-hot encoded vectors are equivalent.
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
"id": "063850ab-22b0-4838-b53a-9bb11757d9d0",
|
"id": "063850ab-22b0-4838-b53a-9bb11757d9d0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Embedding Layers and Linear Layers"
|
"# Embedding层和 Linear层"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -13,8 +13,8 @@
|
|||||||
"id": "0315c598-701f-46ff-8806-15813cad0e51",
|
"id": "0315c598-701f-46ff-8806-15813cad0e51",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Embedding layers in PyTorch accomplish the same as linear layers that perform matrix multiplications; the reason we use embedding layers is computational efficiency\n",
|
"- 在PyTorch中,嵌入层(Embedding layers)实现了执行矩阵乘法的线性层的相同功能;我们使用嵌入层的原因是为了提高计算效率。\n",
|
||||||
"- We will take a look at this relationship step by step using code examples in PyTorch"
|
"- 我们将逐步使用PyTorch中的代码示例来查看这种关系。"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -27,13 +27,12 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"PyTorch version: 2.1.0.post301\n"
|
"PyTorch version: 1.12.1+cu113\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"\n",
|
|
||||||
"print(\"PyTorch version:\", torch.__version__)"
|
"print(\"PyTorch version:\", torch.__version__)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -47,22 +46,21 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
|
"id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Suppose we have the following 3 training examples,\n",
|
"# 假设我们有以下 3 个训练样本,\n",
|
||||||
"# which may represent token IDs in a LLM context\n",
|
"# 这些样本可能表示语言模型(LM)上下文中的标记ID\n",
|
||||||
"idx = torch.tensor([2, 3, 1])\n",
|
"idx = torch.tensor([2, 3, 1])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# The number of rows in the embedding matrix can be determined\n",
|
"# 嵌入矩阵的行数可以通过获取最大标记ID + 1 来确定。\n",
|
||||||
"# by obtaining the largest token ID + 1.\n",
|
"# 如果最高的标记ID是3,则我们希望有4行,对应可能的\n",
|
||||||
"# If the highest token ID is 3, then we want 4 rows, for the possible\n",
|
"# 标记ID 0, 1, 2, 3\n",
|
||||||
"# token IDs 0, 1, 2, 3\n",
|
"num_idx = max(idx) + 1\n",
|
||||||
"num_idx = max(idx)+1\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# The desired embedding dimension is a hyperparameter\n",
|
"# 所需的嵌入维度是一个超参数\n",
|
||||||
"out_dim = 5"
|
"out_dim = 5"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -71,21 +69,21 @@
|
|||||||
"id": "93d83a6e-8543-40af-b253-fe647640bf36",
|
"id": "93d83a6e-8543-40af-b253-fe647640bf36",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Let's implement a simple embedding layer:"
|
"- 实现一个简单的嵌入层"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
|
"id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# We use the random seed for reproducibility since\n",
|
"# 为了可重复性,我们使用随机种子,\n",
|
||||||
"# weights in the embedding layer are initialized with\n",
|
"# 因为嵌入层的权重是用小的随机值初始化的\n",
|
||||||
"# small random values\n",
|
|
||||||
"torch.manual_seed(123)\n",
|
"torch.manual_seed(123)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# 创建一个嵌入层,指定输入维度为 num_idx,输出维度为 out_dim\n",
|
||||||
"embedding = torch.nn.Embedding(num_idx, out_dim)"
|
"embedding = torch.nn.Embedding(num_idx, out_dim)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -94,12 +92,12 @@
|
|||||||
"id": "dd96c00a-3297-4a50-8bfc-247aaea7e761",
|
"id": "dd96c00a-3297-4a50-8bfc-247aaea7e761",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"We can optionally take a look at the embedding weights:"
|
"查看嵌入权重数据情况"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"id": "595f603e-8d2a-4171-8f94-eac8106b2e57",
|
"id": "595f603e-8d2a-4171-8f94-eac8106b2e57",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -113,7 +111,7 @@
|
|||||||
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953]], requires_grad=True)"
|
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953]], requires_grad=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -127,12 +125,12 @@
|
|||||||
"id": "c86eb562-61e2-4171-ab6e-b410a1fd5c18",
|
"id": "c86eb562-61e2-4171-ab6e-b410a1fd5c18",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- We can then use the embedding layers to obtain the vector representation of a training example with ID 1:"
|
"- 使用嵌入层来获取具有ID 1的训练样本的向量表示"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 6,
|
||||||
"id": "8bbc0255-4805-4be9-9f4c-1d0d967ef9d5",
|
"id": "8bbc0255-4805-4be9-9f4c-1d0d967ef9d5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -143,7 +141,7 @@
|
|||||||
" grad_fn=<EmbeddingBackward0>)"
|
" grad_fn=<EmbeddingBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -157,7 +155,7 @@
|
|||||||
"id": "6a4d47f2-4691-47b8-9855-2528b6c285c9",
|
"id": "6a4d47f2-4691-47b8-9855-2528b6c285c9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Below is a visualization of what happens under the hood:"
|
"- 下面是底层操作的可视化"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -173,12 +171,12 @@
|
|||||||
"id": "87d1311b-cfb2-4afc-9e25-e4ecf35370d9",
|
"id": "87d1311b-cfb2-4afc-9e25-e4ecf35370d9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Similarly, we can use embedding layers to obtain the vector representation of a training example with ID 2:"
|
"- 同样,我们可以使用嵌入层来获取具有ID 2的训练样本的向量表示:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
"id": "c309266a-c601-4633-9404-2e10b1cdde8c",
|
"id": "c309266a-c601-4633-9404-2e10b1cdde8c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -189,7 +187,7 @@
|
|||||||
" grad_fn=<EmbeddingBackward0>)"
|
" grad_fn=<EmbeddingBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -211,12 +209,12 @@
|
|||||||
"id": "27dd54bd-85ae-4887-9c5e-3139da361cf4",
|
"id": "27dd54bd-85ae-4887-9c5e-3139da361cf4",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Now, let's convert all the training examples we have defined previously:"
|
"- 现在,让我们将之前定义的所有训练样本转换:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 8,
|
||||||
"id": "0191aa4b-f6a8-4b0d-9c36-65e82b81d071",
|
"id": "0191aa4b-f6a8-4b0d-9c36-65e82b81d071",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -229,12 +227,13 @@
|
|||||||
" grad_fn=<EmbeddingBackward0>)"
|
" grad_fn=<EmbeddingBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# 将原先的第三行变成现在的第一行,第四行变成现在的第二行,第二行变成现在的第三行\n",
|
||||||
"idx = torch.tensor([2, 3, 1])\n",
|
"idx = torch.tensor([2, 3, 1])\n",
|
||||||
"embedding(idx)"
|
"embedding(idx)"
|
||||||
]
|
]
|
||||||
@@ -260,7 +259,7 @@
|
|||||||
"id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
|
"id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Using nn.Linear"
|
"## 使用 nn.Linear"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -268,13 +267,13 @@
|
|||||||
"id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
|
"id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Now, we will demonstrate that the embedding layer above accomplishes exactly the same as `nn.Linear` layer on a one-hot encoded representation in PyTorch\n",
|
"- 接下来,我们将使用One-Hot编码,与embedding 层一样,在 `nn.Linear` 层进行操作\n",
|
||||||
"- First, let's convert the token IDs into a one-hot representation:"
|
"- 首先,我们将标记ID转换为One-Hot表示:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 12,
|
||||||
"id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
|
"id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -286,7 +285,7 @@
|
|||||||
" [0, 1, 0, 0]])"
|
" [0, 1, 0, 0]])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -301,18 +300,33 @@
|
|||||||
"id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
|
"id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Next, we initialize a `Linear` layer, which caries out a matrix multiplication $X W^\\top$:"
|
"- 接下来,我们使用矩阵乘法$X W^\\top$ 来初始化一个Linear层"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 16,
|
||||||
"id": "ae04c1ed-242e-4dd7-b8f7-4b7e4caae383",
|
"id": "ae04c1ed-242e-4dd7-b8f7-4b7e4caae383",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Parameter containing:\n",
|
||||||
|
"tensor([[-0.2039, 0.0166, -0.2483, 0.1886],\n",
|
||||||
|
" [-0.4260, 0.3665, -0.3634, -0.3975],\n",
|
||||||
|
" [-0.3159, 0.2264, -0.1847, 0.1871],\n",
|
||||||
|
" [-0.4244, -0.3034, -0.1836, -0.0983],\n",
|
||||||
|
" [-0.3814, 0.3274, -0.1179, 0.1605]], requires_grad=True)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"torch.manual_seed(123)\n",
|
"torch.manual_seed(123)\n",
|
||||||
"linear = torch.nn.Linear(num_idx, out_dim, bias=False)"
|
"# 初始化一个Linear层,该层的权重矩阵是由 num_idx(输入维度)到 out_dim(输出维度)的一个线性层,而且没有偏置项\n",
|
||||||
|
"linear = torch.nn.Linear(num_idx, out_dim, bias=False)\n",
|
||||||
|
"print(linear.weight)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -320,16 +334,17 @@
|
|||||||
"id": "63efb98e-5cc4-4e8d-9fe6-ef0ad29ae2d7",
|
"id": "63efb98e-5cc4-4e8d-9fe6-ef0ad29ae2d7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Note that the linear layer in PyTorch is also initialized with small random weights; to directly compare it to the `Embedding` layer above, we have to use the same small random weights, which is why we reassign them here:"
|
"- 请注意,PyTorch中的`linear`层也是用小的随机权重进行初始化的。为了与上面的 `Embedding` 层进行直接比较,我们必须使用相同的小随机权重,这就是我们在这里重新分配它们的原因:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 17,
|
||||||
"id": "a3b90d69-761c-486e-bd19-b38a2988fe62",
|
"id": "a3b90d69-761c-486e-bd19-b38a2988fe62",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# linear 层的权重就被重新赋值为与 embedding 层相同的小随机权重,以确保它们具有相同的初始化。这是为了使它们在后续操作中可以进行直接比较。\n",
|
||||||
"linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
|
"linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -338,12 +353,12 @@
|
|||||||
"id": "9116482d-f1f9-45e2-9bf3-7ef5e9003898",
|
"id": "9116482d-f1f9-45e2-9bf3-7ef5e9003898",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Now we can use the linear layer on the one-hot encoded representation of the inputs:"
|
"- 现在,我们可以使用线性层处理输入的One-Hot编码表示:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 18,
|
||||||
"id": "90d2b0dd-9f1d-4c0f-bb16-1f6ce6b8ac2c",
|
"id": "90d2b0dd-9f1d-4c0f-bb16-1f6ce6b8ac2c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -355,7 +370,7 @@
|
|||||||
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]], grad_fn=<MmBackward0>)"
|
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]], grad_fn=<MmBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 11,
|
"execution_count": 18,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -369,12 +384,12 @@
|
|||||||
"id": "f6204bc8-92e2-4546-9cda-574fe1360fa2",
|
"id": "f6204bc8-92e2-4546-9cda-574fe1360fa2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"As we can see, this is exactly the same as what we got when we used the embedding layer:"
|
"正如我们所看到的,这与我们使用嵌入层时得到的结果完全相同:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 19,
|
||||||
"id": "2b057649-3176-4a54-b58c-fd8fbf818c61",
|
"id": "2b057649-3176-4a54-b58c-fd8fbf818c61",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -387,7 +402,7 @@
|
|||||||
" grad_fn=<EmbeddingBackward0>)"
|
" grad_fn=<EmbeddingBackward0>)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 12,
|
"execution_count": 19,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -401,7 +416,7 @@
|
|||||||
"id": "0e447639-8952-460e-8c8f-cf9e23c368c9",
|
"id": "0e447639-8952-460e-8c8f-cf9e23c368c9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- What happens under the hood is the following computation for the first training example's token ID:"
|
"- 底层发生的计算如下,针对第一个训练样本的标记ID:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -417,7 +432,7 @@
|
|||||||
"id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
|
"id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- And for the second training example's token ID:"
|
"- 以及对于第二个训练样本的标记ID:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -433,8 +448,9 @@
|
|||||||
"id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
|
"id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- Since all but one index in each one-hot encoded row are 0 (by design), this matrix multiplication is essentially the same as a look-up of the one-hot elements\n",
|
"- \n",
|
||||||
"- This use of the matrix multiplication on one-hot encodings is equivalent to the embedding layer look-up but can be inefficient if we work with large embedding matrices, because there are a lot of wasteful multiplications by zero"
|
"由于每个独热编码行中除了一个索引外都为0(设计如此),这个矩阵乘法本质上就是对独热编码元素的查找\n",
|
||||||
|
"- 。在独热编码上使用矩阵乘法与使用嵌入层查找是等效的,但如果我们使用大型嵌入矩阵,这种方法可能效率较低,因为有很多不必要的零乘法。"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -444,22 +460,6 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a1f63c81-1ee3-40a1-9ef2-14ff18fb4f05",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c71959bb-facf-44fd-8edb-b67f7752f034",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -478,7 +478,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.8.17"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Reference in New Issue
Block a user