llms-from-scratch-cn/ch03/01_main-chapter-code/multihead-attention.ipynb

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 多头注意力机制结合数据加载"
],
"metadata": {
"collapsed": false
},
"id": "6c73d8caae2916ec"
},
{
"cell_type": "markdown",
"source": [
"完整的章节代码位于 [ch03.ipynb](./ch03.ipynb)。\n",
"\n",
"这个 notebook 包含了主要的学习点,即多头注意力的实现(以及第二章中的数据加载流程)。"
],
"metadata": {
"collapsed": false
},
"id": "660c1de13c8952cf"
},
{
"cell_type": "markdown",
"source": [
"## 第二章中的数据加载器"
],
"metadata": {
"collapsed": false
},
"id": "62c346f743383e17"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"# 导入必要的库\n",
"import tiktoken # 假设这是一个自定义的分词库\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"# 定义 GPT 数据集类\n",
"class GPTDatasetV1(Dataset):\n",
" def __init__(self, txt, tokenizer, max_length, stride):\n",
" # 初始化分词器\n",
" self.tokenizer = tokenizer\n",
" # 初始化输入和目标 ID 列表\n",
" self.input_ids = []\n",
" self.target_ids = []\n",
"\n",
" # 对整个文本进行分词\n",
" token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
"\n",
" # 使用滑动窗口将文本分割成重叠的 max_length 长度的序列\n",
" for i in range(0, len(token_ids) - max_length, stride):\n",
" input_chunk = token_ids[i:i + max_length]\n",
" target_chunk = token_ids[i + 1: i + max_length + 1]\n",
" # 将分词 ID 转换为 PyTorch 张量并添加到列表\n",
" self.input_ids.append(torch.tensor(input_chunk))\n",
" self.target_ids.append(torch.tensor(target_chunk))\n",
"\n",
" def __len__(self):\n",
" # 返回数据集中的样本数量\n",
" return len(self.input_ids)\n",
"\n",
" def __getitem__(self, idx):\n",
" # 根据索引返回对应的输入和目标张量\n",
" return self.input_ids[idx], self.target_ids[idx]\n",
"\n",
"# 创建数据加载器的函数\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
" # 初始化分词器\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
" # 创建数据集\n",
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
"\n",
" # 创建数据加载器\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
"\n",
" return dataloader\n",
"\n",
"# 读取文本文件\n",
"with open(\"small-text-sample.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()\n",
"\n",
"# 初始化分词器\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"# 对原始文本进行编码\n",
"encoded_text = tokenizer.encode(raw_text)\n",
"\n",
"# 定义词汇表大小、输出维度、最大长度和块大小\n",
"vocab_size = 50257\n",
"output_dim = 256\n",
"max_len = 1024\n",
"block_size = max_len\n",
"\n",
"# 创建词嵌入层和位置嵌入层\n",
"token_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"\n",
"# 设置最大长度为 4这可能是一个错误因为通常最大长度会大于 4\n",
"max_length = 4\n",
"# 创建数据加载器\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.859707800Z",
"start_time": "2024-02-28T10:24:28.599202800Z"
}
},
"id": "d6020b82c7b6dd6b"
},
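{
"cell_type": "markdown",
"source": [
"As a quick sanity check (an illustrative addition, not part of the original chapter code), we can pull a single batch from the dataloader and confirm that each target sequence is simply its input sequence shifted one token to the right, which is exactly how `GPTDatasetV1` builds its chunks."
],
"metadata": {
"collapsed": false
},
"id": "b1c2d3e4f5a60711"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Minimal sanity check (uses the dataloader created above)\n",
"inputs, targets = next(iter(dataloader))\n",
"\n",
"print(\"inputs.shape: \", inputs.shape)   # expected: torch.Size([8, 4])\n",
"print(\"targets.shape:\", targets.shape)  # expected: torch.Size([8, 4])\n",
"\n",
"# Within each sample, the targets are the inputs shifted by one token\n",
"print(\"first input: \", inputs[0])\n",
"print(\"first target:\", targets[0])"
],
"metadata": {
"collapsed": false
},
"id": "c2d3e4f5a6b70812"
},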
{
"cell_type": "code",
"execution_count": 2,
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.875711300Z",
"start_time": "2024-02-28T10:24:30.861708300Z"
}
},
"outputs": [],
"source": [
"# 遍历数据加载器中的每个批次\n",
"for batch in dataloader:\n",
" # 从当前批次中解包输入和目标数据\n",
" x, y = batch\n",
"\n",
" # 使用词嵌入层计算输入序列的词嵌入\n",
" token_embeddings = token_embedding_layer(x)\n",
" # 使用位置嵌入层计算位置嵌入,这里使用 torch.arange 创建一个与 max_length 相同长度的序列\n",
" pos_embeddings = pos_embedding_layer(torch.arange(max_length))\n",
"\n",
" # 将词嵌入和位置嵌入相加,得到最终的输入嵌入\n",
" input_embeddings = token_embeddings + pos_embeddings\n",
"\n",
" # 跳出循环,这里可能是为了演示目的,实际训练中通常会继续循环直到遍历完所有数据\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.921720700Z",
"start_time": "2024-02-28T10:24:30.876711400Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"print(input_embeddings.shape)"
]
},
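{
"cell_type": "markdown",
"source": [
"To see where the `[8, 4, 256]` shape comes from, the following cell (an illustrative addition) prints the shapes of the two embedding components: the positional embeddings have shape `[4, 256]` and are broadcast across the batch dimension when they are added to the `[8, 4, 256]` token embeddings."
],
"metadata": {
"collapsed": false
},
"id": "d3e4f5a6b7c80913"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Shapes of the two embedding components (uses the variables from the cells above)\n",
"print(\"token_embeddings.shape:\", token_embeddings.shape)  # [8, 4, 256]\n",
"print(\"pos_embeddings.shape:  \", pos_embeddings.shape)    # [4, 256]\n",
"\n",
"# Broadcasting adds the same positional embeddings to every sequence in the batch\n",
"print(\"input_embeddings.shape:\", input_embeddings.shape)  # [8, 4, 256]"
],
"metadata": {
"collapsed": false
},
"id": "e4f5a6b7c8d91014"
},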
{
"cell_type": "markdown",
"source": [
"# 第三章中的多头注意力"
],
"metadata": {
"collapsed": false
},
"id": "a6589588dfab8a73"
},
{
"cell_type": "markdown",
"source": [
"## 方式A简单实现"
],
"metadata": {
"collapsed": false
},
"id": "3960118f9f988847"
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.925722100Z",
"start_time": "2024-02-28T10:24:30.891714900Z"
}
},
"outputs": [],
"source": [
"# 定义因果自注意力模块\n",
"class CausalSelfAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" # 调用父类构造函数\n",
" super().__init__()\n",
" # 输出维度\n",
" self.d_out = d_out\n",
" # 查询、键和值的线性变换层\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" # Dropout层用于正则化\n",
" self.dropout = nn.Dropout(dropout)\n",
" # 注册一个缓冲区,用于存储上三角掩码,用于因果自注意力\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" # 获取输入张量的批次大小、序列长度和输入维度\n",
" b, n_tokens, d_in = x.shape\n",
" # 分别计算查询、键和值\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" # 计算注意力分数,这里使用了转置操作\n",
" attn_scores = queries @ keys.transpose(1, 2)\n",
" # 使用掩码将未来位置的注意力分数置为负无穷,实现因果自注意力\n",
" attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)\n",
" # 归一化注意力分数\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
" # 应用dropout\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # 计算上下文向量\n",
" context_vec = attn_weights @ values\n",
" return context_vec\n",
"\n",
"# 定义多头注意力包装器模块\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" # 调用父类构造函数\n",
" super().__init__()\n",
" # 创建多个因果自注意力模块\n",
" self.heads = nn.ModuleList([\n",
" CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" for _ in range(num_heads)\n",
" ])\n",
" # 输出投影层,用于将多头的输出合并\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
"\n",
" def forward(self, x):\n",
" # 将所有头的输出沿着最后一个维度拼接\n",
" context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
" # 通过输出投影层\n",
" return self.out_proj(context_vec)"
]
},
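{
"cell_type": "markdown",
"source": [
"To make the masking step concrete, the following cell (an illustrative addition) builds the same upper-triangular mask for a 4-token sequence; the `True` entries mark future positions whose attention scores are set to `-inf` before the softmax, so each token can only attend to itself and to earlier tokens."
],
"metadata": {
"collapsed": false
},
"id": "f5a6b7c8d9e01115"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# The causal mask for a 4-token sequence (illustrative)\n",
"mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()\n",
"print(mask)"
],
"metadata": {
"collapsed": false
},
"id": "a6b7c8d9e0f11216"
},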
{
"cell_type": "code",
"execution_count": 5,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.927722100Z",
"start_time": "2024-02-28T10:24:30.910719200Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"context_vecs.shape: torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"# 设置随机种子以确保结果的可重复性\n",
"torch.manual_seed(123)\n",
"\n",
"# 定义块大小,这里与最大长度相同\n",
"block_size = max_length\n",
"# 输入维度\n",
"d_in = output_dim\n",
"\n",
"# 定义多头注意力中每个头的输出维度\n",
"num_heads = 2\n",
"# 输出维度是输入维度除以头数\n",
"d_out = d_in // num_heads\n",
"\n",
"# 初始化多头注意力包装器\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)\n",
"\n",
"# 假设 input_embeddings 是之前准备好的输入数据\n",
"batch = input_embeddings\n",
"# 使用多头注意力模块处理输入数据\n",
"context_vecs = mha(batch)\n",
"\n",
"# 打印上下文向量的维度\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
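{
"cell_type": "markdown",
"source": [
"In this wrapper approach each head produces its own `d_out`-dimensional context vectors, and the wrapper concatenates them along the last dimension. The following cell (an illustrative addition) runs a single head on its own to show that its output is 128-dimensional, which is why concatenating two heads recovers the 256-dimensional embedding size."
],
"metadata": {
"collapsed": false
},
"id": "b7c8d9e0f1a21317"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Output of a single head (illustrative; uses the mha wrapper instance created above)\n",
"single_head_out = mha.heads[0](batch)\n",
"print(\"number of heads:\", len(mha.heads))                  # 2\n",
"print(\"single head output shape:\", single_head_out.shape)  # [8, 4, 128]"
],
"metadata": {
"collapsed": false
},
"id": "c8d9e0f1a2b31418"
},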
{
"cell_type": "markdown",
"source": [
"## 方式B替代实现"
],
"metadata": {
"collapsed": false
},
"id": "6d5ad940a95ee6a7"
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2773c09d-c136-4372-a2be-04b58d292842",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.939725500Z",
"start_time": "2024-02-28T10:24:30.923721600Z"
}
},
"outputs": [],
"source": [
"# 定义多头自注意力模块\n",
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" # 调用父类构造函数\n",
" super().__init__()\n",
" # 确保输出维度可以被头数整除\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
"\n",
" # 初始化模块的属性\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # 计算每个头的维度\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # 查询线性层\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) # 键线性层\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 值线性层\n",
" self.out_proj = nn.Linear(d_out, d_out) # 输出投影层\n",
" self.dropout = nn.Dropout(dropout) # Dropout层\n",
" # 注册一个缓冲区,用于存储上三角掩码,用于因果自注意力\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" # 获取输入张量的批次大小、序列长度和输入维度\n",
" b, num_tokens, d_in = x.shape\n",
"\n",
" # 分别计算查询、键和值\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" # 将矩阵按头数分割,并添加一个维度\n",
" keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
"\n",
" # 转置以匹配多头注意力的维度\n",
" keys = keys.transpose(1, 2)\n",
" queries = queries.transpose(1, 2)\n",
" values = values.transpose(1, 2)\n",
"\n",
" # 计算缩放点积注意力分数,并应用因果掩码\n",
" attn_scores = queries @ keys.transpose(2, 3) # 对每个头进行点积\n",
" # 将掩码截断到与序列长度相匹配,并转换为布尔值\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # 扩展掩码以匹配维度\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
" # 使用扩展的掩码填充注意力分数\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n",
" # 归一化注意力分数\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # 计算上下文向量\n",
" context_vec = (attn_weights @ values).transpose(1, 2)\n",
" \n",
" # 合并头的输出\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
" # 可选的输出投影\n",
" context_vec = self.out_proj(context_vec)\n",
"\n",
" return context_vec"
]
},
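{
"cell_type": "markdown",
"source": [
"Before running the module, the following small cell (an illustrative addition) walks through the reshaping trick used in `forward`: a `[batch, num_tokens, d_out]` tensor is viewed as `[batch, num_tokens, num_heads, head_dim]` and then transposed to `[batch, num_heads, num_tokens, head_dim]`, so that the attention scores can be computed for all heads in parallel."
],
"metadata": {
"collapsed": false
},
"id": "d9e0f1a2b3c41519"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Illustrative walk-through of the head-splitting reshape used in forward\n",
"example = torch.randn(8, 4, 256)      # [batch, num_tokens, d_out]\n",
"example = example.view(8, 4, 2, 128)  # split d_out into [num_heads, head_dim]\n",
"example = example.transpose(1, 2)     # -> [batch, num_heads, num_tokens, head_dim]\n",
"print(example.shape)"
],
"metadata": {
"collapsed": false
},
"id": "e0f1a2b3c4d51620"
},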
{
"cell_type": "code",
"execution_count": 7,
"id": "779fdd04-0152-4308-af08-840800a7f395",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-28T10:24:30.984735500Z",
"start_time": "2024-02-28T10:24:30.941725600Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"context_vecs.shape: torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"# 设置随机种子以确保结果的可重复性\n",
"torch.manual_seed(123)\n",
"\n",
"# 定义块大小,这里与最大长度相同\n",
"block_size = max_length\n",
"# 输入和输出维度\n",
"d_in = output_dim\n",
"# 输出维度设置为与输入维度相同\n",
"d_out = d_in\n",
"\n",
"# 初始化多头自注意力模块\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, dropout=0.0, num_heads=2)\n",
"\n",
"# 假设 input_embeddings 是之前准备好的输入数据\n",
"batch = input_embeddings\n",
"# 使用多头自注意力模块处理输入数据\n",
"context_vecs = mha(batch)\n",
"\n",
"# 打印上下文向量的维度\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
}
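,
{
"cell_type": "markdown",
"source": [
"As a final comparison (an illustrative addition, not part of the original chapter code), the cell below counts the parameters of the two implementations as configured above. The wrapper uses two heads with 256x128 projections while `MultiHeadAttention` uses single 256x256 projections, so the totals happen to match for this configuration, and both modules produce context vectors of the same shape."
],
"metadata": {
"collapsed": false
},
"id": "f1a2b3c4d5e61721"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Parameter-count comparison (illustrative sketch; re-instantiates both modules)\n",
"mha_wrapper = MultiHeadAttentionWrapper(d_in, d_in // 2, block_size, 0.0, num_heads=2)\n",
"mha_combined = MultiHeadAttention(d_in, d_in, block_size, dropout=0.0, num_heads=2)\n",
"\n",
"def count_params(module):\n",
"    return sum(p.numel() for p in module.parameters())\n",
"\n",
"print(\"Wrapper parameters: \", count_params(mha_wrapper))\n",
"print(\"Combined parameters:\", count_params(mha_combined))\n",
"\n",
"# Both produce context vectors of the same shape for this configuration\n",
"print(mha_wrapper(batch).shape, mha_combined(batch).shape)"
],
"metadata": {
"collapsed": false
},
"id": "a2b3c4d5e6f71822"
}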
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}