增加第六章的翻译

This commit is contained in:
kjq_glb 2024-05-22 15:53:12 +08:00
parent d665982073
commit 310cdb21f5
11 changed files with 4440 additions and 0 deletions

View File

@ -0,0 +1,7 @@
# 第 6 章:用于文本分类的微调
- [ch06.ipynb](ch06.ipynb) 包含本章中出现的所有代码
- [previous_chapters.py](previous_chapters.py) 是一个 Python 模块,其中包含我们在前面的章节中编码和训练的 GPT 模型,以及我们在本章中重用的许多实用函数
- [gpt-class-finetune.py](gpt-class-finetune.py) 是一个独立的 Python 脚本文件,其中包含我们在 [ch06.ipynb](ch06.ipynb) 中实现的代码,用于微调 GPT 模型(您可以将其视为章节摘要)
- [gpt_download.py](gpt_download.py) 包含用于下载预训练 GPT 模型权重的实用函数
- [exercise-solutions.ipynb](exercise-solutions.ipynb) 包含本章的练习题

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ba450fb1-8a26-4894-ab7a-5d7bfefe90ce",
"metadata": {},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
]
},
{
"cell_type": "markdown",
"id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
"metadata": {},
"source": [
"# 第6章 练习题解答"
]
},
{
"cell_type": "markdown",
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"## 练习 6.1:增加上下文长度"
]
},
{
"cell_type": "markdown",
"id": "5860ba9f-2db3-4480-b96b-4be1c68981eb",
"metadata": {},
"source": [
"我们可以通过将最大长度设置为 1024 将输入填充到模型支持的最大标记数:\n",
"\n",
"```python\n",
"max_length = 1024\n",
"\n",
"train_dataset = SpamDataset(base_path / \"train.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"val_dataset = SpamDataset(base_path / \"validation.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"test_dataset = SpamDataset(base_path / \"test.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"\n",
"```\n",
"或者,等效地,我们可以通过以下方式定义`max_length`:\n",
"\n",
"```python\n",
"max_length = model.pos_emb.weight.shape[0]\n",
"```\n",
"\n",
"或者\n",
"\n",
"```python\n",
"max_length = BASE_CONFIG[\"context_length\"]\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "2b0f4d5d-17fd-4265-93d8-ea08a22fdaf8",
"metadata": {},
"source": [
"为了方便起见,您可以通过以下方式运行此实验\n",
"\n",
"```\n",
"python additional-experiments.py --context_length \"model_context_length\"\n",
"```\n",
"\n",
"使用 [../02_bonus_additional-experiments](../02_bonus_additional-experiments) 的代码会导致测试准确率大幅下降,为 78.33%(相对于主章节中的 95.67%)。"
]
},
{
"cell_type": "markdown",
"id": "5a780455-f52a-48d1-ab82-6afd40bcad8b",
"metadata": {},
"source": [
"## 练习 6.2:微调整个模型"
]
},
{
"cell_type": "markdown",
"id": "56aa5208-aa29-4165-a0ec-7480754e2a18",
"metadata": {},
"source": [
"我们可以通过从代码中删除以下行来微调整个模型而不是仅微调最终的transformer块\n",
"\n",
"```python\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"```\n",
"\n",
"为了方便起见,您可以通过以下方式运行此实验\n",
"\n",
"```\n",
"python additional-experiments.py --trainable_layers all\n",
"```\n",
"\n",
"使用 [../02_bonus_additional-experiments](../02_bonus_additional-experiments) 的代码会导致测试准确率提高 1%,达到 96.67%(相对于主章节中的 95.67%)。"
]
},
{
"cell_type": "markdown",
"id": "2269bce3-f2b5-4a76-a692-5977c75a57b6",
"metadata": {},
"source": [
"## 练习 6.3:微调第一个和最后一个标记"
]
},
{
"cell_type": "markdown",
"id": "7418a629-51b6-4aa2-83b7-bc0261bc370f",
"metadata": {},
"source": [
"除了微调最后一个输出标记之外,我们还可以微调第一个输出标记,通过更改代码中的\n",
"\n",
"```python\n",
"model(input_batch)[:, -1, :]\n",
"```\n",
"\n",
"为\n",
"\n",
"```python\n",
"model(input_batch)[:, 0, :]\n",
"```\n",
"\n",
".\n",
"\n",
"为了方便起见,您可以通过以下方式运行此实验\n",
"\n",
"```\n",
"python additional-experiments.py --trainable_token first\n",
"```\n",
"\n",
"使用 [../02_bonus_additional-experiments](../02_bonus_additional-experiments) 文件夹中的代码会导致测试准确率大幅下降,为 75.00%(相对于主章节中的 95.67%)。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5e6188a-f182-4f26-b9e5-ccae3ecadae0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,418 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# This is a summary file containing the main takeaways from chapter 6.
import urllib.request
import zipfile
import os
from pathlib import Path
import time
import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
if data_file_path.exists():
print(f"{data_file_path} already exists. Skipping download and extraction.")
return
# Downloading the file
with urllib.request.urlopen(url) as response:
with open(zip_path, "wb") as out_file:
out_file.write(response.read())
# Unzipping the file
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extracted_path)
# Add .tsv file extension
original_file_path = Path(extracted_path) / "SMSSpamCollection"
os.rename(original_file_path, data_file_path)
print(f"File downloaded and saved as {data_file_path}")
def create_balanced_dataset(df):
# Count the instances of "spam"
num_spam = df[df["Label"] == "spam"].shape[0]
# Randomly sample "ham" instances to match the number of "spam" instances
ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
# Combine ham "subset" with "spam"
balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
return balanced_df
def random_split(df, train_frac, validation_frac):
# Shuffle the entire DataFrame
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Calculate split indices
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
# Split the DataFrame
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
return train_df, validation_df, test_df
class SpamDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text) for text in self.data["Text"]
]
if max_length is None:
self.max_length = self._longest_encoded_length()
else:
self.max_length = max_length
# Truncate sequences if they are longer than max_length
self.encoded_texts = [
encoded_text[:self.max_length]
for encoded_text in self.encoded_texts
]
# Pad sequences to the longest sequence
self.encoded_texts = [
encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
for encoded_text in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["Label"]
return (
torch.tensor(encoded, dtype=torch.long),
torch.tensor(label, dtype=torch.long)
)
def __len__(self):
return len(self.data)
def _longest_encoded_length(self):
max_length = 0
for encoded_text in self.encoded_texts:
encoded_length = len(encoded_text)
if encoded_length > max_length:
max_length = encoded_length
return max_length
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)
total_loss += loss.item()
else:
break
return total_loss / num_batches
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# Main training loop
for epoch in range(num_epochs):
model.train() # Set model to training mode
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous epoch
loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
# Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
val_accs.append(val_accuracy)
return train_losses, val_losses, train_accs, val_accs, examples_seen
def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
fig, ax1 = plt.subplots(figsize=(5, 3))
# Plot training and validation loss against epochs
ax1.plot(epochs_seen, train_values, label=f"Training {label}")
ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
ax1.set_xlabel("Epochs")
ax1.set_ylabel(label.capitalize())
ax1.legend()
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
ax2.plot(examples_seen, train_values, alpha=0) # Invisible plot for aligning ticks
ax2.set_xlabel("Examples seen")
fig.tight_layout() # Adjust layout to make room
plt.savefig(f"{label}-plot.pdf")
# plt.show()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Finetune a GPT model for classification"
)
parser.add_argument(
"--test_mode",
action="store_true",
help=("This flag runs the model in test mode for internal testing purposes. "
"Otherwise, it runs the model as it is used in the chapter (recommended).")
)
args = parser.parse_args()
########################################
# Download and prepare dataset
########################################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
balanced_df = create_balanced_dataset(df)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)
########################################
# Create data loaders
########################################
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = SpamDataset(
csv_file="train.csv",
max_length=None,
tokenizer=tokenizer
)
val_dataset = SpamDataset(
csv_file="validation.csv",
max_length=train_dataset.max_length,
tokenizer=tokenizer
)
test_dataset = SpamDataset(
csv_file="test.csv",
max_length=train_dataset.max_length,
tokenizer=tokenizer
)
num_workers = 0
batch_size = 8
torch.manual_seed(123)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
########################################
# Load pretrained model
########################################
# Small GPT model for testing purposes
if args.test_mode:
BASE_CONFIG = {
"vocab_size": 50257,
"context_length": 120,
"drop_rate": 0.0,
"qkv_bias": False,
"emb_dim": 12,
"n_layers": 1,
"n_heads": 2
}
model = GPTModel(BASE_CONFIG)
model.eval()
device = "cpu"
model.to(device)
# Code as it is used in the main chapter
else:
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
########################################
# Modify and pretrained model
########################################
for param in model.parameters():
param.requires_grad = False
torch.manual_seed(123)
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
########################################
# Finetune modified model
########################################
start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=num_epochs, eval_freq=50, eval_iter=5,
tokenizer=tokenizer
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
########################################
# Plot results
########################################
# loss plot
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)
# accuracy plot
epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")

View File

@ -0,0 +1,99 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import requests
import json
import numpy as np
import tensorflow as tf
from tqdm import tqdm
def download_and_load_gpt2(model_size, models_dir):
# Validate model size
allowed_sizes = ("124M", "355M", "774M", "1558M")
if model_size not in allowed_sizes:
raise ValueError(f"Model size not in {allowed_sizes}")
# Define paths
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
filenames = [
"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
# Download files
os.makedirs(model_dir, exist_ok=True)
for filename in filenames:
file_url = os.path.join(base_url, model_size, filename)
file_path = os.path.join(model_dir, filename)
download_file(file_url, file_path)
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params
def download_file(url, destination):
# Send a GET request to download the file in streaming mode
response = requests.get(url, stream=True)
# Get the total file size from headers, defaulting to 0 if not present
file_size = int(response.headers.get("content-length", 0))
# Check if file exists and has the same size
if os.path.exists(destination):
file_size_local = os.path.getsize(destination)
if file_size == file_size_local:
print(f"File already exists and is up-to-date: {destination}")
return
# Define the block size for reading the file
block_size = 1024 # 1 Kilobyte
# Initialize the progress bar with total file size
progress_bar_description = url.split("/")[-1] # Extract filename from URL
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
# Open the destination file in binary write mode
with open(destination, "wb") as file:
# Iterate over the file data in chunks
for chunk in response.iter_content(block_size):
progress_bar.update(len(chunk)) # Update progress bar
file.write(chunk) # Write the chunk to the file
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
# Initialize parameters dictionary with empty blocks for each layer
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
# Iterate over each variable in the checkpoint
for name, _ in tf.train.list_variables(ckpt_path):
# Load the variable and remove singleton dimensions
variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
# Process the variable name to extract relevant parts
variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
# Identify the target dictionary for the variable
target_dict = params
if variable_name_parts[0].startswith("h"):
layer_number = int(variable_name_parts[0][1:])
target_dict = params["blocks"][layer_number]
# Recursively access or create nested dictionaries
for key in variable_name_parts[1:-1]:
target_dict = target_dict.setdefault(key, {})
# Assign the variable array to the last key
last_key = variable_name_parts[-1]
target_dict[last_key] = variable_array
return params

View File

@ -0,0 +1,321 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.
import numpy as np
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.tokenizer = tokenizer
self.input_ids = []
self.target_ids = []
# Tokenize the entire text
token_ids = tokenizer.encode(txt)
# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids[i:i + max_length]
target_chunk = token_ids[i + 1: i + max_length + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return self.input_ids[idx], self.target_ids[idx]
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
# Create dataset
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
return dataloader
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
#####################################
# Chapter 4
#####################################
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
return self.layers(x)
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_resid = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_resid(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_resid(x)
x = x + shortcut # Add the original input back
return x
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx
#####################################
# Chapter 5
#####################################
def assign(left, right):
if left.shape != right.shape:
raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
return torch.nn.Parameter(torch.tensor(right))
def load_weights_into_gpt(gpt, params):
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
for b in range(len(params["blocks"])):
q_w, k_w, v_w = np.split(
(params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
gpt.trf_blocks[b].att.W_query.weight = assign(
gpt.trf_blocks[b].att.W_query.weight, q_w.T)
gpt.trf_blocks[b].att.W_key.weight = assign(
gpt.trf_blocks[b].att.W_key.weight, k_w.T)
gpt.trf_blocks[b].att.W_value.weight = assign(
gpt.trf_blocks[b].att.W_value.weight, v_w.T)
q_b, k_b, v_b = np.split(
(params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
gpt.trf_blocks[b].att.W_query.bias = assign(
gpt.trf_blocks[b].att.W_query.bias, q_b)
gpt.trf_blocks[b].att.W_key.bias = assign(
gpt.trf_blocks[b].att.W_key.bias, k_b)
gpt.trf_blocks[b].att.W_value.bias = assign(
gpt.trf_blocks[b].att.W_value.bias, v_b)
gpt.trf_blocks[b].att.out_proj.weight = assign(
gpt.trf_blocks[b].att.out_proj.weight,
params["blocks"][b]["attn"]["c_proj"]["w"].T)
gpt.trf_blocks[b].att.out_proj.bias = assign(
gpt.trf_blocks[b].att.out_proj.bias,
params["blocks"][b]["attn"]["c_proj"]["b"])
gpt.trf_blocks[b].ff.layers[0].weight = assign(
gpt.trf_blocks[b].ff.layers[0].weight,
params["blocks"][b]["mlp"]["c_fc"]["w"].T)
gpt.trf_blocks[b].ff.layers[0].bias = assign(
gpt.trf_blocks[b].ff.layers[0].bias,
params["blocks"][b]["mlp"]["c_fc"]["b"])
gpt.trf_blocks[b].ff.layers[2].weight = assign(
gpt.trf_blocks[b].ff.layers[2].weight,
params["blocks"][b]["mlp"]["c_proj"]["w"].T)
gpt.trf_blocks[b].ff.layers[2].bias = assign(
gpt.trf_blocks[b].ff.layers[2].bias,
params["blocks"][b]["mlp"]["c_proj"]["b"])
gpt.trf_blocks[b].norm1.scale = assign(
gpt.trf_blocks[b].norm1.scale,
params["blocks"][b]["ln_1"]["g"])
gpt.trf_blocks[b].norm1.shift = assign(
gpt.trf_blocks[b].norm1.shift,
params["blocks"][b]["ln_1"]["b"])
gpt.trf_blocks[b].norm2.scale = assign(
gpt.trf_blocks[b].norm2.scale,
params["blocks"][b]["ln_2"]["g"])
gpt.trf_blocks[b].norm2.shift = assign(
gpt.trf_blocks[b].norm2.shift,
params["blocks"][b]["ln_2"]["b"])
gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
def text_to_token_ids(text, tokenizer):
encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
return encoded_tensor
def token_ids_to_text(token_ids, tokenizer):
flat = token_ids.squeeze(0) # remove batch dimension
return tokenizer.decode(flat.tolist())

View File

@ -0,0 +1,16 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# File for internal use (unit tests)
import subprocess
def test_gpt_class_finetune():
command = ["python", "ch06/01_main-chapter-code/gpt-class-finetune.py", "--test_mode"]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0, f"Script exited with errors: {result.stderr}"

View File

@ -0,0 +1,66 @@
# 额外实验
下表添加了一些实验来回答有关各种设计选择的其他问题。 第一行使用与主要章节相同的设置并用作参考。
例如,
- 比较第 1 行和第 2 行回答了以下问题:“当我们训练最后一个或第一个标记时,性能差异是什么?”;
- 比较第 1 行和第 3 行回答了以下问题:“当我们只训练最后一层而不是最后一个块时,性能差异是什么?”;
- 等等。
&nbsp;
| | Model | Weights | Trainable token | Trainable layers | Context length | Training acc | Validation acc | Test acc | Training time | CPU/GPU |
| ---- | ------------------ | ---------- | --------------- | ---------------- | ----------------------- | ------------ | -------------- | -------- | ------------- | ------- |
| 1 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
| 2 | gpt2-small (124M) | pretrained | first | last_block | longest train ex. (120) | 78.46% | 80.54% | 75.00% | 0.28 min | A100 |
| 3 | gpt2-small (124M) | pretrained | last | last_layer | longest train ex. (120) | 78.65% | 79.87% | 72.00% | 0.25 min | A100 |
| 4 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | 99.62% | 96.64% | 96.67% | 0.69 min | A100 |
| 5 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 |
| 6 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 |
| 7 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 |
| 8 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 |
| 9 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 99.52% | 97.99% | 97.67% | 0.75 min | A100 |
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
&nbsp;
## 使用方法
您可以使用以下代码来重现实验:
- Row 1: `python additional-experiments.py`
- Row 2: `python additional-experiments.py --trainable_token first`
- Row 3: `python additional-experiments.py --trainable_layers last_layer`
- Row 4: `python additional-experiments.py --trainable_layers all`
- Row 5: `python additional-experiments.py --model_size "gpt2-medium (355M)"`
- Row 6: `python additional-experiments.py --model_size "gpt2-large (774M)"`
- Row 7: `python additional-experiments.py --model_size "gpt2-xl (1558M)"`
- Row 8: `python additional-experiments.py --weights random --trainable_layers all`
- Row 9: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8`
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
- Row 11: `python additional-experiments.py --no_padding --batch_size 1`
- Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
我特意将 LLM 和数据集保持得较小,因此,如果您无法使用 GPU您可以在 MacBook Air M3 等普通笔记本电脑上运行大约 15 分钟的训练。
&nbsp;
## 解释
1. **训练最后一个输出标记与第一个输出标记(第 1 行与第 2 行)**:与第一个输出标记相比,训练最后一个输出标记会带来更好的性能。由于因果自注意力掩模,这种改进是可以预期的。
2. **训练最后一个 Transformer 块与最后一层(第 1 行与第 3 行)**:训练整个最后一个 Transformer 块也比仅训练最后一层获得更好的结果。
3. **训练所有层与最后一个 Transformer 块(第 1 行与第 4 行)**:训练所有层比仅训练最后一个 Transformer 块显示出约 2% 的适度改进,但它需要的时间几乎是三倍的训练时间。
4. **使用更大的预训练模型(第 1 行与第 5 行,以及第 1 行与第 6 行和第 7 行)**:采用 3 倍大的预训练模型会导致更差的结果。 然而,正如预期的那样,与初始模型相比,使用大 5 倍的模型可以提高性能。 同样12 倍大的模型进一步提高了预测性能。(中等模型可能没有经过很好的预训练,或者特定的微调配置对该模型效果不佳。)
5. **使用具有随机权重的模型与预训练权重(第 1 行与第 8 行)**:使用具有随机权重的模型产生的结果仅比使用预训练权重稍差 1.3%。
6. **使用 LoRA低阶适应与训练所有层第 9 行与第 4 行)**:保持模型冻结并添加可训练的 LoRA 层是训练所有模型参数的可行替代方案,甚至可以将性能提高 1%(请参阅[附录 E](../../appendix-E/01_main-chapter-code/appendix-E.ipynb)查看更多细节)。 从使用 LoRA 时训练和验证准确率之间的差距降低 1% 可以看出,这可能是由于过度拟合较少。 此外,使用 LoRA 的速度也稍快一些,因为需要更新的参数较少。
7. **将输入填充到完整上下文长度与最长训练示例(第 1 行与第 10 行)**:将输入填充到完整支持的上下文长度结果明显更差。
8. **填充与无填充(第 1 行与第 11 行和第 12 行)**`--no_padding` 选项禁用数据集中的填充,这需要使用批量大小 1 来训练模型,因为输入具有可变长度。 这会带来更好的测试准确率,但需要更长的训练时间。 在第 12 行中,我们另外启用了 8 个步骤的梯度累积,以实现与其他实验相同的批量大小,这有助于减少过度拟合并略微提高测试集的准确性。

View File

@ -0,0 +1,541 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import argparse
import os
from pathlib import Path
import time
import urllib.request
import zipfile
import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
class LoRALayer(torch.nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha):
super().__init__()
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
self.alpha = alpha
def forward(self, x):
x = self.alpha * (x @ self.A @ self.B)
return x
class LinearWithLoRA(torch.nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(
linear.in_features, linear.out_features, rank, alpha
)
def forward(self, x):
return self.linear(x) + self.lora(x)
class SpamDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
self.data = pd.read_csv(csv_file)
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text)[:self.max_length]
for text in self.data["Text"]
]
if not no_padding:
# Pad sequences to the longest sequence
self.encoded_texts = [
et + [pad_token_id] * (self.max_length - len(et))
for et in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["Label"]
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
def __len__(self):
return len(self.data)
def _longest_encoded_length(self, tokenizer):
max_length = 0
for text in self.data["Text"]:
encoded_length = len(tokenizer.encode(text))
if encoded_length > max_length:
max_length = encoded_length
return max_length
def download_and_unzip(url, zip_path, extract_to, new_file_path):
if new_file_path.exists():
print(f"{new_file_path} already exists. Skipping download and extraction.")
return
# Downloading the file
with urllib.request.urlopen(url) as response:
with open(zip_path, "wb") as out_file:
out_file.write(response.read())
# Unzipping the file
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
# Renaming the file to indicate its format
original_file = Path(extract_to) / "SMSSpamCollection"
os.rename(original_file, new_file_path)
print(f"File downloaded and saved as {new_file_path}")
def random_split(df, train_frac, validation_frac):
# Shuffle the entire DataFrame
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Calculate split indices
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
# Split the DataFrame
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
return train_df, validation_df, test_df
def create_dataset_csvs(data_file_path):
df = pd.read_csv(new_file_path, sep="\t", header=None, names=["Label", "Text"])
# Create balanced dataset
n_spam = df[df["Label"] == "spam"].shape[0]
ham_sampled = df[df["Label"] == "ham"].sample(n_spam, random_state=123)
balanced_df = pd.concat([ham_sampled, df[df["Label"] == "spam"]])
balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
# Sample and save csv files
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)
def instantiate_model(choose_model, load_weights):
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[choose_model])
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
load_weights_into_gpt(model, params)
model.eval()
return model
def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
# Reduce the number of batches to match the total number of batches in the data loader
# if num_batches exceeds the number of batches in the data loader
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
total_loss += loss.item()
else:
break
return total_loss / num_batches
@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
accumulation_steps=1):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# Main training loop
for epoch in range(num_epochs):
model.train() # Set model to training mode
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
# Use gradient accumulation if accumulation_steps > 1
# See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
# for an explanation
loss /= accumulation_steps
loss.backward() # Calculate loss gradients
# Use gradient accumulation if accumulation_steps > 1
if batch_idx % accumulation_steps == 0:
optimizer.step() # Update model weights using loss gradients
optimizer.zero_grad() # Reset loss gradients from previous epoch
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
if max_steps is not None and global_step > max_steps:
break
# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
val_accs.append(val_accuracy)
if max_steps is not None and global_step > max_steps:
break
return train_losses, val_losses, train_accs, val_accs, examples_seen
def replace_linear_with_lora(model, rank, alpha):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
# Replace the Linear layer with LinearWithLoRA
setattr(model, name, LinearWithLoRA(module, rank, alpha))
else:
# Recursively apply the same function to child modules
replace_linear_with_lora(module, rank, alpha)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_size",
type=str,
default="gpt2-small (124M)",
help=(
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
)
)
parser.add_argument(
"--weights",
type=str,
default="pretrained",
help=(
"Whether to use 'pretrained' or 'random' weights."
)
)
parser.add_argument(
"--trainable_layers",
type=str,
default="last_block",
help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer', 'lora'."
)
)
parser.add_argument(
"--trainable_token",
type=str,
default="last",
help=(
"Which token to train. Options: 'first', 'last'."
)
)
parser.add_argument(
"--context_length",
type=str,
default="longest_training_example",
help=(
"The context length of the data inputs."
"Options: 'longest_training_example', 'model_context_length' or integer value."
)
)
parser.add_argument(
"--lora_rank",
type=int,
default=8,
help=(
"The LoRA rank when choosing `--trainable_layers lora`"
)
)
parser.add_argument(
"--lora_alpha",
type=int,
default=8,
help=(
"The LoRA alpha value when choosing `--trainable_layers lora`"
)
)
parser.add_argument(
"--no_padding",
action='store_true',
default=False,
help=(
"Disable padding, which means each example may have a different lenght."
" This requires setting `--batch_size 1`."
)
)
parser.add_argument(
"--num_epochs",
type=int,
default=5,
help=(
"Number of training epochs."
)
)
parser.add_argument(
"--batch_size",
type=int,
default=8,
help=(
"The batch size used for training."
)
)
parser.add_argument(
"--accumulation_steps",
type=int,
default=1,
help=(
"Accumulation steps to allow for gradient accumulation."
" See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html for explanation."
" For example, setting `batch_size=8` and `accumulation_steps=1` compute the exact same"
" loss and weight updates as setting `batch_size=1` and `accumulation_steps=8`, however,"
" the latter setting uses more iterations."
)
)
args = parser.parse_args()
if args.trainable_token == "first":
args.trainable_token = 0
elif args.trainable_token == "last":
args.trainable_token = -1
else:
raise ValueError("Invalid --trainable_token argument")
###############################
# Load model
###############################
if args.weights == "pretrained":
load_weights = True
elif args.weights == "random":
load_weights = False
else:
raise ValueError("Invalid --weights argument.")
model = instantiate_model(args.model_size, load_weights)
for param in model.parameters():
param.requires_grad = False
if args.model_size == "gpt2-small (124M)":
in_features = 768
elif args.model_size == "gpt2-medium (355M)":
in_features = 1024
elif args.model_size == "gpt2-large (774M)":
in_features = 1280
elif args.model_size == "gpt2-xl (1558M)":
in_features = 1600
else:
raise ValueError("Invalid --model_size argument")
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
elif args.trainable_layers == "lora":
replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
else:
raise ValueError("Invalid --trainable_layers argument.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
###############################
# Instantiate dataloaders
###############################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_to = "sms_spam_collection"
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
base_path = Path(".")
file_names = ["train.csv", "validation.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
if not all_exist:
download_and_unzip(url, zip_path, extract_to, new_file_path)
create_dataset_csvs(new_file_path)
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = None
if args.no_padding:
max_length = None
else:
if args.context_length == "model_context_length":
max_length = model.pos_emb.weight.shape[0]
elif args.context_length == "longest_training_example":
train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
max_length = train_dataset.max_length
else:
try:
max_length = int(args.context_length)
except ValueError:
raise ValueError("Invalid --context_length argument")
if train_dataset is None:
train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
tokenizer = tiktoken.get_encoding("gpt2")
num_workers = 0
train_loader = DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=args.batch_size,
num_workers=num_workers,
drop_last=False,
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=args.batch_size,
num_workers=num_workers,
drop_last=False,
)
###############################
# Train model
###############################
start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token,
accumulation_steps=args.accumulation_steps
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
###############################
# Evaluate model
###############################
train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

View File

@ -0,0 +1,99 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import requests
import json
import numpy as np
import tensorflow as tf
from tqdm import tqdm
def download_and_load_gpt2(model_size, models_dir):
# Validate model size
allowed_sizes = ("124M", "355M", "774M", "1558M")
if model_size not in allowed_sizes:
raise ValueError(f"Model size not in {allowed_sizes}")
# Define paths
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
filenames = [
"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
# Download files
os.makedirs(model_dir, exist_ok=True)
for filename in filenames:
file_url = os.path.join(base_url, model_size, filename)
file_path = os.path.join(model_dir, filename)
download_file(file_url, file_path)
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params
def download_file(url, destination):
# Send a GET request to download the file in streaming mode
response = requests.get(url, stream=True)
# Get the total file size from headers, defaulting to 0 if not present
file_size = int(response.headers.get("content-length", 0))
# Check if file exists and has the same size
if os.path.exists(destination):
file_size_local = os.path.getsize(destination)
if file_size == file_size_local:
print(f"File already exists and is up-to-date: {destination}")
return
# Define the block size for reading the file
block_size = 1024 # 1 Kilobyte
# Initialize the progress bar with total file size
progress_bar_description = url.split("/")[-1] # Extract filename from URL
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
# Open the destination file in binary write mode
with open(destination, "wb") as file:
# Iterate over the file data in chunks
for chunk in response.iter_content(block_size):
progress_bar.update(len(chunk)) # Update progress bar
file.write(chunk) # Write the chunk to the file
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
# Initialize parameters dictionary with empty blocks for each layer
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
# Iterate over each variable in the checkpoint
for name, _ in tf.train.list_variables(ckpt_path):
# Load the variable and remove singleton dimensions
variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
# Process the variable name to extract relevant parts
variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
# Identify the target dictionary for the variable
target_dict = params
if variable_name_parts[0].startswith("h"):
layer_number = int(variable_name_parts[0][1:])
target_dict = params["blocks"][layer_number]
# Recursively access or create nested dictionaries
for key in variable_name_parts[1:-1]:
target_dict = target_dict.setdefault(key, {})
# Assign the variable array to the last key
last_key = variable_name_parts[-1]
target_dict[last_key] = variable_array
return params

View File

@ -0,0 +1,345 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.
import numpy as np
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.tokenizer = tokenizer
self.input_ids = []
self.target_ids = []
# Tokenize the entire text
token_ids = tokenizer.encode(txt)
# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids[i:i + max_length]
target_chunk = token_ids[i + 1: i + max_length + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return self.input_ids[idx], self.target_ids[idx]
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
# Create dataset
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
return dataloader
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
#####################################
# Chapter 4
#####################################
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
return self.layers(x)
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
return x
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx
#####################################
# Chapter 5
#####################################
def assign(left, right):
if left.shape != right.shape:
raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
return torch.nn.Parameter(torch.tensor(right))
def load_weights_into_gpt(gpt, params):
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
for b in range(len(params["blocks"])):
q_w, k_w, v_w = np.split(
(params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
gpt.trf_blocks[b].att.W_query.weight = assign(
gpt.trf_blocks[b].att.W_query.weight, q_w.T)
gpt.trf_blocks[b].att.W_key.weight = assign(
gpt.trf_blocks[b].att.W_key.weight, k_w.T)
gpt.trf_blocks[b].att.W_value.weight = assign(
gpt.trf_blocks[b].att.W_value.weight, v_w.T)
q_b, k_b, v_b = np.split(
(params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
gpt.trf_blocks[b].att.W_query.bias = assign(
gpt.trf_blocks[b].att.W_query.bias, q_b)
gpt.trf_blocks[b].att.W_key.bias = assign(
gpt.trf_blocks[b].att.W_key.bias, k_b)
gpt.trf_blocks[b].att.W_value.bias = assign(
gpt.trf_blocks[b].att.W_value.bias, v_b)
gpt.trf_blocks[b].att.out_proj.weight = assign(
gpt.trf_blocks[b].att.out_proj.weight,
params["blocks"][b]["attn"]["c_proj"]["w"].T)
gpt.trf_blocks[b].att.out_proj.bias = assign(
gpt.trf_blocks[b].att.out_proj.bias,
params["blocks"][b]["attn"]["c_proj"]["b"])
gpt.trf_blocks[b].ff.layers[0].weight = assign(
gpt.trf_blocks[b].ff.layers[0].weight,
params["blocks"][b]["mlp"]["c_fc"]["w"].T)
gpt.trf_blocks[b].ff.layers[0].bias = assign(
gpt.trf_blocks[b].ff.layers[0].bias,
params["blocks"][b]["mlp"]["c_fc"]["b"])
gpt.trf_blocks[b].ff.layers[2].weight = assign(
gpt.trf_blocks[b].ff.layers[2].weight,
params["blocks"][b]["mlp"]["c_proj"]["w"].T)
gpt.trf_blocks[b].ff.layers[2].bias = assign(
gpt.trf_blocks[b].ff.layers[2].bias,
params["blocks"][b]["mlp"]["c_proj"]["b"])
gpt.trf_blocks[b].norm1.scale = assign(
gpt.trf_blocks[b].norm1.scale,
params["blocks"][b]["ln_1"]["g"])
gpt.trf_blocks[b].norm1.shift = assign(
gpt.trf_blocks[b].norm1.shift,
params["blocks"][b]["ln_1"]["b"])
gpt.trf_blocks[b].norm2.scale = assign(
gpt.trf_blocks[b].norm2.scale,
params["blocks"][b]["ln_2"]["g"])
gpt.trf_blocks[b].norm2.shift = assign(
gpt.trf_blocks[b].norm2.shift,
params["blocks"][b]["ln_2"]["b"])
gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
# For-loop is the same as before: Get logits, and only focus on last time step
for _ in range(max_new_tokens):
idx_cond = idx[:, -context_size:]
with torch.no_grad():
logits = model(idx_cond)
logits = logits[:, -1, :]
# New: Filter logits with top_k sampling
if top_k is not None:
# Keep only top_k values
top_logits, _ = torch.topk(logits, top_k)
min_val = top_logits[:, -1]
logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
# New: Apply temperature scaling
if temperature > 0.0:
logits = logits / temperature
# Apply softmax to get probabilities
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
# Sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
# Otherwise same as before: get idx of the vocab entry with the highest logits value
else:
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
# Same as before: append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
return idx