mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-05-01 11:58:17 +08:00
229 lines
6.5 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
   "metadata": {},
   "source": [
    "## 4.7 Generating text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48da5deb-6ee0-4b9b-8dd2-abed7ed65172",
   "metadata": {},
   "source": [
    "- LLMs such as the GPT model we implemented above are used to generate one word at a time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caade12a-fe97-480f-939c-87d24044edff",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/16.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7061524-a3bd-4803-ade6-2e3b7b79ac13",
   "metadata": {},
   "source": [
    "- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
    "- In greedy decoding, at each step the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we technically wouldn't even have to compute the softmax function explicitly)\n",
    "- In the next chapter, we will implement a more advanced `generate_text` function\n",
    "- The figure below depicts how the GPT model, given an input context, generates the next word token\n"
   ]
  },
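  {
   "cell_type": "markdown",
   "id": "greedy-argmax-note",
   "metadata": {},
   "source": [
    "- As a quick sanity check of the claim above (this cell is an illustrative addition, not part of the original chapter): since softmax is monotonic, `torch.argmax` picks the same index from the raw logits as from their softmax:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "greedy-argmax-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # one row of made-up logits\n",
    "probas = torch.softmax(logits, dim=-1)  # softmax preserves the ordering\n",
    "\n",
    "# Both calls select index 3, the entry with the largest logit\n",
    "print(torch.argmax(logits, dim=-1))\n",
    "print(torch.argmax(probas, dim=-1))"
   ]
  },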
  {
   "cell_type": "markdown",
   "id": "7ee0f32c-c18c-445e-b294-a879de2aa187",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/17.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
    "    # idx is (batch, n_tokens) array of indices in the current context\n",
    "    for _ in range(max_new_tokens):\n",
    "\n",
    "        # Crop current context if it exceeds the supported context size\n",
    "        # E.g., if LLM supports only 5 tokens, and the context size is 10\n",
    "        # then only the last 5 tokens are used as context\n",
    "        idx_cond = idx[:, -context_size:]\n",
    "\n",
    "        # Get the predictions\n",
    "        with torch.no_grad():\n",
    "            logits = model(idx_cond)\n",
    "\n",
    "        # Focus only on the last time step\n",
    "        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n",
    "        logits = logits[:, -1, :]\n",
    "\n",
    "        # Apply softmax to get probabilities\n",
    "        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)\n",
    "\n",
    "        # Get the idx of the vocab entry with the highest probability value\n",
    "        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)\n",
    "\n",
    "        # Append sampled index to the running sequence\n",
    "        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)\n",
    "\n",
    "    return idx"
   ]
  },
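  {
   "cell_type": "markdown",
   "id": "generate-shape-note",
   "metadata": {},
   "source": [
    "- As an illustrative sketch (not part of the original chapter), we can check the loop's bookkeeping with a hypothetical stand-in model: an untrained `nn.Embedding` over a toy 10-token vocabulary, which maps `(batch, n_tokens)` token ids to `(batch, n_tokens, vocab_size)` \"logits\"; the output sequence should grow by exactly `max_new_tokens` tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "generate-shape-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "toy_vocab_size = 10  # tiny made-up vocabulary\n",
    "dummy_model = nn.Embedding(toy_vocab_size, toy_vocab_size)\n",
    "\n",
    "dummy_idx = torch.tensor([[1, 2, 3]])  # (batch=1, n_tokens=3)\n",
    "dummy_out = generate_text_simple(dummy_model, dummy_idx, max_new_tokens=4, context_size=5)\n",
    "print(dummy_out.shape)  # 3 input + 4 generated tokens -> torch.Size([1, 7])"
   ]
  },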
  {
   "cell_type": "markdown",
   "id": "6515f2c1-3cc7-421c-8d58-cc2f563b7030",
   "metadata": {},
   "source": [
    "The `generate_text_simple` function above implements an iterative process in which it creates one token at a time\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/18.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f682eac4-f9bd-438b-9dec-6b1cc7bc05ce",
   "metadata": {},
   "source": [
    "- Let's prepare an input example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoded: [15496, 11, 314, 716]\n",
      "encoded_tensor.shape: torch.Size([1, 4])\n"
     ]
    }
   ],
   "source": [
    "start_context = \"Hello, I am\"\n",
    "\n",
    "encoded = tokenizer.encode(start_context)\n",
    "print(\"encoded:\", encoded)\n",
    "\n",
    "encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
    "print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
   ]
  },
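  {
   "cell_type": "markdown",
   "id": "unsqueeze-note",
   "metadata": {},
   "source": [
    "- A small illustrative aside (not part of the original chapter): `.unsqueeze(0)` adds the batch dimension the model expects, and `.squeeze(0)` removes it again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unsqueeze-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "t = torch.tensor([15496, 11, 314, 716])\n",
    "print(t.shape)  # torch.Size([4])\n",
    "print(t.unsqueeze(0).shape)  # torch.Size([1, 4])\n",
    "print(t.unsqueeze(0).squeeze(0).shape)  # torch.Size([4])"
   ]
  },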
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267]])\n",
      "Output length: 10\n"
     ]
    }
   ],
   "source": [
    "model.eval()  # disable dropout\n",
    "\n",
    "out = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=encoded_tensor,\n",
    "    max_new_tokens=6,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"Output:\", out)\n",
    "print(\"Output length:\", len(out[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d131c00-1787-44ba-bec3-7c145497b2c3",
   "metadata": {},
   "source": [
    "- Remove the batch dimension and convert back into text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "053d99f6-5710-4446-8d52-117fb34ea9f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello, I am Featureiman Byeswickattribute argue\n"
     ]
    }
   ],
   "source": [
    "decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
    "print(decoded_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a894003-51f6-4ccc-996f-3b9c7d5a1d70",
   "metadata": {},
   "source": [
    "- Note that the model is untrained; hence the random output text above\n",
    "- We will train the model in the next chapter\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a35278b6-9e5c-480f-83e5-011a1173648f",
   "metadata": {},
   "source": [
    "## Summary and takeaways\n",
    "\n",
    "- See the ./gpt.py script, a standalone script containing the GPT model we implemented in this Jupyter notebook\n",
    "- You can find the exercise solutions in ./exercise-solutions.ipynb"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}