llms-from-scratch-cn/Translated_Book/ch04/4.7 生成文本.ipynb
2024-08-06 23:07:20 +08:00


{
"cells": [
{
"cell_type": "markdown",
"id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
"metadata": {},
"source": [
"## 4.7 Generating text"
]
},
{
"cell_type": "markdown",
"id": "48da5deb-6ee0-4b9b-8dd2-abed7ed65172",
"metadata": {},
"source": [
"- LLMs such as the GPT model we implemented above are used to generate text one word (token) at a time"
]
},
{
"cell_type": "markdown",
"id": "caade12a-fe97-480f-939c-87d24044edff",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/16.webp\" width=\"100%\">"
]
},
{
"cell_type": "markdown",
"id": "a7061524-a3bd-4803-ade6-2e3b7b79ac13",
"metadata": {},
"source": [
"\n",
"- The `generate_text_simple` function below implements greedy decoding, a simple and fast text-generation method\n",
"- In greedy decoding, at each step the model selects the word (token) with the highest probability as its next output; since the highest logit corresponds to the highest probability, we don't even need to compute the softmax explicitly\n",
"- In the next chapter, we will implement a more advanced `generate_text` function\n",
"- The figure below shows how the GPT model generates the next word token given an input context\n"
]
},
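{
"cell_type": "markdown",
"id": "3f2a1b6c-9d4e-4f0a-8c1b-2e5d7a9f0c3d",
"metadata": {},
"source": [
"- A minimal sketch (using a hypothetical toy logits tensor, and assuming `torch` has been imported as in the earlier sections) verifying that `argmax` over the raw logits selects the same token as `argmax` over the softmax probabilities, which is why greedy decoding can skip the softmax:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c5e9d2a-1b4f-4e8c-9a3d-6f0b8c2e4d1a",
"metadata": {},
"outputs": [],
"source": [
"# Softmax is monotonic, so it preserves the ordering of the logits\n",
"logits = torch.tensor([[0.5, 2.0, -1.0, 1.5]])\n",
"probas = torch.softmax(logits, dim=-1)\n",
"print(torch.argmax(logits, dim=-1))  # tensor([1])\n",
"print(torch.argmax(probas, dim=-1))  # tensor([1])"
]
},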
{
"cell_type": "markdown",
"id": "7ee0f32c-c18c-445e-b294-a879de2aa187",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/17.webp\" width=\"100%\">"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
"metadata": {},
"outputs": [],
"source": [
"def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
" # idx is (batch, n_tokens) array of indices in the current context\n",
" for _ in range(max_new_tokens):\n",
" \n",
" # Crop current context if it exceeds the supported context size\n",
" # E.g., if LLM supports only 5 tokens, and the context size is 10\n",
" # then only the last 5 tokens are used as context\n",
" idx_cond = idx[:, -context_size:]\n",
" \n",
" # Get the predictions\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" \n",
" # Focus only on the last time step\n",
" # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n",
" logits = logits[:, -1, :] \n",
"\n",
" # Apply softmax to get probabilities\n",
" probas = torch.softmax(logits, dim=-1) # (batch, vocab_size)\n",
"\n",
" # Get the idx of the vocab entry with the highest probability value\n",
" idx_next = torch.argmax(probas, dim=-1, keepdim=True) # (batch, 1)\n",
"\n",
" # Append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)\n",
"\n",
" return idx"
]
},
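{
"cell_type": "markdown",
"id": "5d8b3e1f-2c6a-4d9b-8e0f-1a4c7b9d2e5f",
"metadata": {},
"source": [
"- A quick illustration (with a hypothetical toy index tensor) of the context-cropping step `idx[:, -context_size:]` used above: if the model supports only 5 tokens of context but 10 tokens have accumulated, only the last 5 are passed to the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e2f4a6b-3d7c-4b1e-a5f8-0c6d9b3e7a2c",
"metadata": {},
"outputs": [],
"source": [
"idx = torch.arange(10).unsqueeze(0)  # shape (batch=1, n_tokens=10)\n",
"context_size = 5\n",
"print(idx[:, -context_size:])  # tensor([[5, 6, 7, 8, 9]])"
]
},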
{
"cell_type": "markdown",
"id": "6515f2c1-3cc7-421c-8d58-cc2f563b7030",
"metadata": {},
"source": [
"\n",
"The `generate_text_simple` function above implements an iterative process in which it generates one token at a time\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/18.webp\" width=\"100%\">"
]
},
{
"cell_type": "markdown",
"id": "f682eac4-f9bd-438b-9dec-6b1cc7bc05ce",
"metadata": {},
"source": [
"- Let's prepare an input example:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoded: [15496, 11, 314, 716]\n",
"encoded_tensor.shape: torch.Size([1, 4])\n"
]
}
],
"source": [
"start_context = \"Hello, I am\"\n",
"\n",
"encoded = tokenizer.encode(start_context)\n",
"print(\"encoded:\", encoded)\n",
"\n",
"encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
"print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output: tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267]])\n",
"Output length: 10\n"
]
}
],
"source": [
"model.eval() # disable dropout\n",
"\n",
"out = generate_text_simple(\n",
" model=model,\n",
" idx=encoded_tensor, \n",
" max_new_tokens=6, \n",
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
")\n",
"\n",
"print(\"Output:\", out)\n",
"print(\"Output length:\", len(out[0]))"
]
},
{
"cell_type": "markdown",
"id": "1d131c00-1787-44ba-bec3-7c145497b2c3",
"metadata": {},
"source": [
"- Remove the batch dimension and convert back into text:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "053d99f6-5710-4446-8d52-117fb34ea9f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, I am Featureiman Byeswickattribute argue\n"
]
}
],
"source": [
"decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
"print(decoded_text)"
]
},
{
"cell_type": "markdown",
"id": "9a894003-51f6-4ccc-996f-3b9c7d5a1d70",
"metadata": {},
"source": [
"\n",
"- Note that the model is untrained; hence the output text above is random\n",
"- We will train the model in the next chapter\n"
]
},
{
"cell_type": "markdown",
"id": "a35278b6-9e5c-480f-83e5-011a1173648f",
"metadata": {},
"source": [
"## Summary and takeaways\n",
"\n",
"- See the ./gpt.py script, a standalone script containing the GPT model we implemented in this Jupyter notebook\n",
"- You can find the exercise solutions in ./exercise-solutions.ipynb"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}