Files
llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-v2/rwkv-v2-guide.ipynb
T
2024-06-10 17:00:23 +08:00

778 lines
33 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "12e4e650-e036-4ca1-83a7-806911fdf0c1",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import math, json, time, types, copy, sys, os\n",
"import torch\n",
"from torch.nn import functional as F\n",
"import torch.nn as nn\n",
"\n",
"from transformers import PreTrainedTokenizerFast\n",
"\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)"
]
},
{
"cell_type": "raw",
"id": "d189546c-4d6e-4643-91fb-aab36cfd1935",
"metadata": {},
"source": [
"模型下载地址:https://hf-mirror.com/BlinkDL/rwkv-2-pile-430m/resolve/main/20220615-10803.pth?download=true\n",
"\n",
"请修改模型路径\n",
"例如我的路径是/data1/ckw/20220615-10803\n",
"如果想使用cuda加速,请参考:https://github.com/BlinkDL/RWKV-v2-RNN-Pile"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9c02901b-63f9-41fd-bf0b-0c28a2ee57cb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"* running on cpu\n"
]
}
],
"source": [
"RUN_DEVICE = 'cpu'\n",
"ctx_len = 768\n",
"n_layer = 24\n",
"n_embd = 1024\n",
"\n",
"MODEL_NAME = '/data1/ckw/20220615-10803' #修改为自己的模型路径\n",
"\n",
"vocab_size = 50277\n",
"VOCAB_NAME = '20B_tokenizer.json'\n",
"\n",
"print(f'\\n* running on {RUN_DEVICE}')"
]
},
{
"cell_type": "markdown",
"id": "65ece3b5-1682-4bf0-a4b0-67192230cd47",
"metadata": {},
"source": [
"### 什么是RWKV\n",
"\n",
"RWKV是Receptance Weighted Key Value的缩写,是一种结合了RNN(循环神经网络)和Transformer优势的新型神经网络架构。它的设计目的是解决Transformer在处理长序列时的内存和计算复杂度问题,同时保留RNN在推理阶段的计算效率。RWKV利用线性注意力机制,可以将其形式化为Transformer或RNN,从而在训练期间实现并行计算,并在推理过程中保持恒定的计算和内存复杂度。"
]
},
{
"cell_type": "markdown",
"id": "f5130fdf-306d-404c-ad0b-35ec4bc07e7f",
"metadata": {},
"source": [
"RWKV 的 ChannelMix 实现方式结合了时间混合和通道混合的操作。下面是对代码和其对应公式的详细解释:\n",
"\n",
"1. **时间混合(Time Mixing**\n",
" 时间混合通过 `time_mix` 参数和 `time_shift` 操作来实现。这一步的目的是结合当前时间步的输入和前一个时间步的输入。\n",
"\n",
" 公式表示:\n",
"\n",
" \n",
" \\begin{align*}\n",
" x' = x \\cdot \\text{time\\_mix} + \\text{time\\_shift}(x) \\cdot (1 - \\text{time\\_mix})\n",
" \\end{align*}\n",
"\n",
" 其中,`time_shift` 操作是一个时间步的移位操作,`time_mix` 是一个可训练的参数。\n",
"\n",
"3. **键(Key)生成**\n",
" 使用一个线性层 `self.key` 将输入 `x'` 转换成键 `k`,然后应用 ReLU 激活函数和平方操作。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" k = \\text{ReLU}(\\text{key}(x'))^2\n",
" \\end{align*}\n",
"\n",
"4. **值(Value)生成**\n",
" 将键 `k` 输入到值线性层 `self.value` 以生成值 `kv`。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" kv = \\text{value}(k)\n",
" \\end{align*}\n",
"\n",
"5. **接收函数(Receptance Function**\n",
" 使用一个线性层 `self.receptance` 计算接收函数 `r`,然后应用 Sigmoid 激活函数。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" r = \\sigma(\\text{receptance}(x'))\n",
" \\end{align*}\n",
"\n",
"6. **最终输出**\n",
" 将接收函数 `r` 与值 `kv` 相乘生成最终输出 `rkv`。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" rkv = r \\cdot kv\n",
" \\end{align*}\n",
"\n",
"综合这些步骤,整个 ChannelMix 的计算过程可以用以下公式表示:\n",
"\n",
"\\begin{align*}\n",
"x' & = x \\cdot \\text{time\\_mix} + \\text{time\\_shift}(x) \\cdot (1 - \\text{time\\_mix}) \\\\\n",
"k & = \\text{ReLU}(\\text{key}(x'))^2 \\\\\n",
"kv & = \\text{value}(k) \\\\\n",
"r & = \\sigma(\\text{receptance}(x')) \\\\\n",
"\\text{output} & = r \\cdot kv\n",
"\\end{align*}\n",
"\n",
"\n",
"以上公式解释了 RWKV 的 ChannelMix 实现的细节。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4599dc46-75c5-4e47-af1f-012bb72954e6",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_ChannelMix(nn.Module):\n",
" def __init__(self, layer_id):\n",
" super().__init__()\n",
" self.layer_id = layer_id\n",
"\n",
" self.time_shift = nn.ZeroPad2d((0,0,1,-1))\n",
" self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))\n",
"\n",
" hidden_sz = 4 * n_embd\n",
" self.key = nn.Linear(n_embd, hidden_sz, bias=False)\n",
" self.receptance = nn.Linear(n_embd, n_embd, bias=False)\n",
" self.value = nn.Linear(hidden_sz, n_embd, bias=False)\n",
"\n",
" def forward(self, x):\n",
" x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)\n",
"\n",
" k = self.key(x)\n",
" k = torch.square(torch.relu(k))\n",
" kv = self.value(k)\n",
" \n",
" rkv = torch.sigmoid(self.receptance(x)) * kv\n",
" return rkv"
]
},
{
"cell_type": "markdown",
"id": "b6d3accc-60a8-4118-97ea-b20abe14b6ad",
"metadata": {},
"source": [
"在RWKV的实现中,`RWKV_TimeMix` 通过时间混合来处理输入数据。下面是具体的实现和对应的公式说明:\n",
"\n",
"1. **时间混合(Time Mixing**\n",
" 时间混合是通过 `time_mix` 参数和 `time_shift` 操作来实现的。这一步的目的是结合当前时间步的输入和前一个时间步的输入。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" x' = x \\cdot \\text{time\\_mix} + \\text{time\\_shift}(x) \\cdot (1 - \\text{time\\_mix})\n",
" \\end{align*}\n",
"\n",
"2. **键(Key)生成**\n",
" 使用一个线性层 `self.key` 将输入 `x'` 转换成键 `k`,然后对其进行转置。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" k = \\text{key}(x')^T\n",
" \\end{align*}\n",
"\n",
"3. **值(Value)生成**\n",
" 使用一个线性层 `self.value` 将输入 `x'` 转换成值 `v`,然后对其进行转置。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" v = \\text{value}(x')^T\n",
" \\end{align*}\n",
"\n",
"4. **接收函数(Receptance Function**\n",
" 使用一个线性层 `self.receptance` 计算接收函数 `r`。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" r = \\text{receptance}(x')\n",
" \\end{align*}\n",
"\n",
"5. **键值相乘**\n",
" 将键 `k` 和值 `v` 相乘得到 `kv`。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" kv = k \\cdot v\n",
" \\end{align*}\n",
"\n",
"6. **时间权重计算**\n",
" 计算时间权重 `w`,其中 `self.time_w` 是通过 `time_decay` 和 `time_curve` 计算得到的。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" \\text{self.time\\_w} &= \\exp(\\text{time\\_decay}) \\cdot \\text{time\\_curve} \\\\\n",
" w &= \\exp(\\text{self.time\\_w})\n",
" \\end{align*}\n",
"\n",
"7. **卷积操作**\n",
" 使用一维卷积计算加权键值和加权键。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" wkv &= \\text{conv1d}(\\text{ZeroPad2d}(kv), w, \\text{groups}=C) \\\\\n",
" wk &= \\text{conv1d}(\\text{ZeroPad2d}(k), w, \\text{groups}=C) + 1e-9\n",
" \\end{align*}\n",
"\n",
"8. **最终输出**\n",
" 将接收函数 `r` 与加权键值比 `wkv / wk` 相乘,并通过输出线性层得到最终输出 `rwkv`。\n",
"\n",
" 公式表示:\n",
" \\begin{align*}\n",
" rwkv &= \\sigma(r) \\cdot \\left( \\frac{wkv}{wk} \\right)^T \\\\\n",
" rwkv &= \\text{output}(rwkv)\n",
" \\end{align*}\n",
"\n",
"综合这些步骤,`RWKV_TimeMix` 的整个计算过程可以表示为:\n",
"\n",
"\\begin{align*}\n",
"x' &= x \\cdot \\text{time\\_mix} + \\text{time\\_shift}(x) \\cdot (1 - \\text{time\\_mix}) \\\\\n",
"k &= \\text{key}(x')^T \\\\\n",
"v &= \\text{value}(x')^T \\\\\n",
"r &= \\text{receptance}(x') \\\\\n",
"kv &= k \\cdot v \\\\\n",
"\\text{self.time\\_w} &= \\exp(\\text{time\\_decay}) \\cdot \\text{time\\_curve} \\\\\n",
"w &= \\exp(\\text{self.time\\_w}) \\\\\n",
"wkv &= \\text{conv1d}(\\text{ZeroPad2d}(kv), w, \\text{groups}=C) \\\\\n",
"wk &= \\text{conv1d}(\\text{ZeroPad2d}(k), w, \\text{groups}=C) + 1e-9 \\\\\n",
"rwkv &= \\sigma(r) \\cdot \\left( \\frac{wkv}{wk} \\right)^T \\\\\n",
"rwkv &= \\text{output}(rwkv)\n",
"\\end{align*}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "228bb098-eb18-4365-a69f-349a9d9709f2",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TimeMix(nn.Module):\n",
" def __init__(self, layer_id):\n",
" super().__init__()\n",
" self.layer_id = layer_id\n",
" self.time_decay = nn.Parameter(torch.ones(n_embd, 1))\n",
" self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)\n",
" self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))\n",
" \n",
" self.time_shift = nn.ZeroPad2d((0,0,1,-1))\n",
" self.time_mix = nn.Parameter(torch.ones(1,1,n_embd))\n",
"\n",
" self.key = nn.Linear(n_embd, n_embd, bias=False)\n",
" self.value = nn.Linear(n_embd, n_embd, bias=False)\n",
" self.receptance = nn.Linear(n_embd, n_embd, bias=False)\n",
"\n",
" self.output = nn.Linear(n_embd, n_embd, bias=False)\n",
"\n",
" def forward(self, x):\n",
" B, T, C = x.size()\n",
"\n",
" x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)\n",
"\n",
" k = self.key(x).transpose(-1, -2)\n",
" v = self.value(x).transpose(-1, -2)\n",
" r = self.receptance(x)\n",
"\n",
" k = torch.clamp(k, max=60)\n",
" k = torch.exp(k)\n",
"\n",
" kv = k * v\n",
"\n",
" self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)\n",
" w = torch.exp(self.time_w)\n",
" \n",
" w = w[:,-T:].unsqueeze(1)\n",
" wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)\n",
" wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + 1e-9\n",
"\n",
" rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)\n",
" \n",
" rwkv = self.output(rwkv)\n",
" return rwkv"
]
},
{
"cell_type": "markdown",
"id": "c5b3e83c-1fda-498b-92f9-0e98744fd999",
"metadata": {},
"source": [
"### RWKV的Block\n",
"\n",
"RWKV的Block是一个基本的模块,它结合了时间混合(TimeMix)和通道混合(ChannelMix)操作。Block中的每个模块(时间混合和通道混合)都通过归一化和残差连接来处理输入数据,从而增强模型的稳定性和性能。\n",
"\n",
"### 主要组件和操作\n",
"\n",
"1. **LayerNorm**:用于归一化输入,增强训练的稳定性。\n",
" - `self.ln1` 和 `self.ln2` 分别在时间混合和通道混合之前对输入进行归一化。\n",
" \n",
"2. **时间混合(TimeMix)**:结合当前时间步和前一个时间步的信息,捕获时间依赖性。\n",
" - `self.att = RWKV_TimeMix(layer_id)` 初始化时间混合模块。\n",
" \n",
"3. **通道混合(ChannelMix)**:在不同通道间进行混合,增强模型的表达能力。\n",
" - `self.ffn = RWKV_ChannelMix(layer_id)` 初始化通道混合模块。\n",
" \n",
"4. **残差连接**:通过将混合操作的输出加回到原始输入上,保持信息流动并增强模型的梯度传播能力。\n",
"\n",
"通过这种设计,RWKV的Block能够高效地处理序列数据,结合时间和通道信息,提高模型的表现。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2630a3ff-6a8c-49d9-a4d9-05098b424d02",
"metadata": {},
"outputs": [],
"source": [
"class Block(nn.Module):\n",
" def __init__(self, layer_id):\n",
" super().__init__()\n",
" self.layer_id = layer_id # 存储当前层的ID\n",
"\n",
" # 定义两个LayerNorm层,用于归一化输入\n",
" self.ln1 = nn.LayerNorm(n_embd)\n",
" self.ln2 = nn.LayerNorm(n_embd)\n",
" \n",
" # 定义时间混合和通道混合模块\n",
" self.att = RWKV_TimeMix(layer_id)\n",
" self.ffn = RWKV_ChannelMix(layer_id)\n",
"\n",
" def forward(self, x):\n",
" # 首先对输入进行LayerNorm归一化\n",
" x = self.ln1(x)\n",
" \n",
" # 进行时间混合操作,并通过残差连接将结果加回到输入上\n",
" x = x + self.att(x)\n",
" \n",
" # 再次对输入进行LayerNorm归一化\n",
" x = self.ln2(x)\n",
" \n",
" # 进行通道混合操作,并通过残差连接将结果加回到输入上\n",
" x = x + self.ffn(x)\n",
" \n",
" # 返回最终的输出\n",
" return x\n"
]
},
{
"cell_type": "markdown",
"id": "bed3d877-9847-450b-8c92-5fd5c19a1811",
"metadata": {},
"source": [
"接下来,实现了RWKV模型的主要部分:\n",
"\n",
"1. **模型加载和预处理**:代码中加载模型权重并进行时间相关权重的预处理。\n",
"2. **LayerNorm**:在`LN`方法中实现了层归一化,关于LayerNorm的使用。\n",
"3. **前馈网络(FF)和自注意力(SA)**:`FF`方法实现了前馈网络的计算,`SA`方法实现了自注意力机制的计算。这两部分对应TimeMix和ChannelMix的详细计算。\n",
"4. **运行模型**:`run`方法实现了模型的整体运行逻辑,依次通过每一层,并最终输出结果。即模型的运行和推理过程。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e1c04132-b0de-4a8e-9769-2b4619e2e70a",
"metadata": {},
"outputs": [],
"source": [
"time_buf = {} # 用于缓存时间相关信息的全局字典\n",
"\n",
"class RWKV_RNN():\n",
" def __init__(self, MODEL_NAME=MODEL_NAME):\n",
" print('\\nloading RWKV-RNN', MODEL_NAME)\n",
" self.ctx_len = ctx_len # 上下文长度\n",
" self.n_layer = n_layer # 网络层数\n",
" self.n_embd = n_embd # 嵌入维度\n",
" self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME) # 初始化分词器\n",
"\n",
" self.w = types.SimpleNamespace() # 用于存储模型权重的命名空间\n",
" \n",
" # 加载模型权重文件\n",
" w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))\n",
"\n",
" # 处理时间相关的权重\n",
" for x in w.keys():\n",
" if '.time_' in x:\n",
" w[x] = w[x].squeeze() # 压缩维度\n",
" if '.time_decay' in x:\n",
" w[x] = torch.exp(-torch.exp(w[x])) # 对时间衰减进行双重指数运算\n",
" if '.time_first' in x:\n",
" w[x] = torch.exp(w[x]) # 对时间初始值进行指数运算\n",
" \n",
" # 将权重存储在命名空间中\n",
" xx = x.split('.')\n",
" here = self.w\n",
" for i in range(len(xx)):\n",
" if xx[i].isdigit():\n",
" ii = int(xx[i])\n",
" if ii not in here:\n",
" here[ii] = types.SimpleNamespace() # 初始化命名空间\n",
" here = here[ii]\n",
" else:\n",
" if i == len(xx) - 1:\n",
" setattr(here, xx[i], w[x])\n",
" elif not hasattr(here, xx[i]):\n",
" if xx[i+1].isdigit():\n",
" setattr(here, xx[i], {})\n",
" else:\n",
" setattr(here, xx[i], types.SimpleNamespace())\n",
" here = getattr(here, xx[i])\n",
" \n",
" self.clear() # 初始化缓存\n",
" \n",
" def clear(self):\n",
" self.xx = {} # 清空缓存\n",
" self.aa = {}\n",
" self.bb = {}\n",
" \n",
" def save(self, target):\n",
" # 深拷贝当前状态到目标\n",
" target.xx = copy.deepcopy(self.xx)\n",
" target.aa = copy.deepcopy(self.aa)\n",
" target.bb = copy.deepcopy(self.bb)\n",
" \n",
" def load(self, target):\n",
" # 从目标深拷贝状态到当前实例\n",
" self.xx = copy.deepcopy(target.xx)\n",
" self.aa = copy.deepcopy(target.aa)\n",
" self.bb = copy.deepcopy(target.bb)\n",
"\n",
" def LN(self, xx, w):\n",
" # 执行LayerNorm归一化\n",
" return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" def FF(self, xx, w, name):\n",
" # 前馈网络计算\n",
" if name not in self.xx:\n",
" self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n",
" x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix) # 混合当前输入和缓存\n",
"\n",
" self.xx[name] = xx # 更新缓存\n",
"\n",
" r = torch.sigmoid(w.receptance.weight @ x) # 计算接收向量\n",
" k = torch.square(torch.relu(w.key.weight @ x)) # 计算键向量\n",
" kv = w.value.weight @ k # 计算值向量\n",
"\n",
" return r * kv # 返回前馈网络输出\n",
"\n",
" def SA(self, xx, w, name):\n",
" # 自注意力计算\n",
" if name not in self.xx:\n",
" self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n",
" self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n",
" self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n",
" x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix) # 混合当前输入和缓存\n",
" self.xx[name] = xx # 更新缓存\n",
"\n",
" r = torch.sigmoid(w.receptance.weight @ x) # 计算接收向量\n",
"\n",
" k = torch.exp(torch.clamp(w.key.weight @ x, max=60)) # 计算键向量\n",
" v = w.value.weight @ x # 计算值向量\n",
" kv = k * v # 计算键值对\n",
"\n",
" a = self.aa[name] + w.time_first * kv # 计算新的a值\n",
" b = self.bb[name] + w.time_first * k # 计算新的b值\n",
" self.aa[name] = w.time_decay * self.aa[name] + kv # 更新缓存中的a值\n",
" self.bb[name] = w.time_decay * self.bb[name] + k # 更新缓存中的b值\n",
"\n",
" rwkv = r * a / (b + 1e-9) # 计算自注意力输出\n",
"\n",
" return w.output.weight @ rwkv # 返回自注意力输出\n",
"\n",
" def run(self, ctx):\n",
" # 运行模型\n",
" w = self.w\n",
" x = w.emb.weight[ctx[-1]] # 获取当前token的嵌入\n",
"\n",
" # 依次通过每一层\n",
" for i in range(n_layer):\n",
" x = self.LN(x, w.blocks[i].ln1) # 归一化\n",
" x = x + self.SA(x, w.blocks[i].att, f'att.{i}') # 自注意力计算并残差连接\n",
" x = self.LN(x, w.blocks[i].ln2) # 归一化\n",
" x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}') # 前馈网络计算并残差连接\n",
"\n",
" x = self.LN(x, w.ln_out) # 最后一层归一化\n",
"\n",
" x = w.head.weight @ x # 计算输出\n",
" x = x.tolist() # 转换为列表\n",
"\n",
" return x # 返回最终结果"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "38decbf1-12ee-4a89-b97a-d5e580c10c74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"******************************************************************************\n",
"* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.\n",
"* It is NOT indicative of the final performance (which requires 300B tokens).\n",
"******************************************************************************\n"
]
}
],
"source": [
"print('''\n",
"******************************************************************************\n",
"* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.\n",
"* It is NOT indicative of the final performance (which requires 300B tokens).\n",
"******************************************************************************''')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "99e34d05-8296-4d2f-9f45-29912ddde8f1",
"metadata": {},
"outputs": [],
"source": [
"# Edit model.py to set CPU / CUDA mode. Runs on CPU by default.\n",
"\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7\n",
"\n",
"DEBUG_DEBUG = False\n",
"LENGTH_OF_EACH = 333\n",
"NUM_TRIALS = 3\n",
"\n",
"context = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n",
"\n",
"##############################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4a620d9b-e93e-4c6e-b015-99246c1e036c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"loading RWKV-RNN /data1/ckw/20220615-10803\n"
]
}
],
"source": [
"model = RWKV_RNN()"
]
},
{
"cell_type": "markdown",
"id": "0546efbe-600b-47de-8873-1174686beaa4",
"metadata": {},
"source": [
"下面我们从给定的输出logits中进行采样,以生成一个新的token。它实现了**温度调节采样**和**核采样(Top-p采样)**,具体步骤如下:\n",
"\n",
"1. **Softmax转换**:将模型输出的logits通过softmax函数转换为概率分布。\n",
"2. **排序和累积概率计算**:对概率从高到低进行排序,并计算累积概率分布。\n",
"3. **核采样**\n",
" - 计算累积概率超过`top_p`的最小值,确定截断值`cutoff`。\n",
" - 将所有低于截断值的概率置为0,从而保留最重要的`top_p`部分概率。\n",
"4. **温度调节**:如果`temperature`不为1,则调整概率分布,使得概率分布更平滑或更尖锐。\n",
"5. **采样**:从调整后的概率分布中采样一个值,返回对应的索引。\n",
"\n",
"这种方法在文本生成任务中尤为常用,通过调节`temperature`和`top_p`参数,可以控制生成文本的多样性和质量。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "46c528e4-ef2f-4136-9022-5c1fc08b185b",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=None):\n",
" # 将输出转化为概率分布(通过softmax函数)\n",
" probs = F.softmax(torch.tensor(out), dim=-1)\n",
" \n",
" # 按概率从高到低排序\n",
" sorted_probs, _ = torch.sort(probs, descending=True)\n",
"\n",
" # 计算累积概率分布\n",
" cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()\n",
" \n",
" # 根据累积概率和top_p计算截断值(cutoff\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" \n",
" # 将低于截断值的概率置为0\n",
" probs[probs < cutoff] = 0\n",
"\n",
" # 如果temperature不等于1,则对概率进行温度调节\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
"\n",
" # 从调整后的概率分布中采样一个值并返回\n",
" return torch.multinomial(probs, num_samples=1)[0]\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "14581d92-4590-4303-9405-d56f65a3c675",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. We want to change the way students learn artificial intelligence. We are very committed to the idea of artificial intelligence. We have already taken part in a joint research project with the European Research Council. This will bring us closer to our goal. We hope that this research project will give us the opportunity to develop a framework for a data-driven approach in artificial intelligence.\n",
"\n",
"I.C. Pfeifer\n",
"\n",
"Research Fellow\n",
"\n",
"Our work in data science and machine learning has a strong connection with the Human Brain Project, an initiative of the University of California at Berkeley. The focus of our work is on the research of language and language learning, with an emphasis on language acquisition. Our research is directed at the development of technology to improve the language acquisition skills of students and their parents.\n",
"\n",
"The Joint Data-Science-Technology Partnership between the Joint Data Science-Technology Partnership and the International Research Center for Education and Business in the Humanities and Social Sciences (Institute for Learning and Human-Computer Interaction, LBI, JIU) is a collaboration between the JIU and the University of California at Berkeley, the University of California at Los Angeles, the University of Oxford, and the University of Oxford.\n",
"\n",
"In the Humanities and Social Science-Technology Partnership, we have been working with our partners in the Humanities and Social Science-Technology Partnership, the Institute for Learning and Brain Science, the National Science Foundation, the National Science Foundation, and the United States Department of Education on a series of projects focused on the application of artificial intelligence and data mining for educational research.\n",
"\n",
"My wife and I have always believed that data\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. DataWhalechinas first project, called DataWidener, was started in 2011 to help undergraduate and graduate students develop their research skills.\n",
"\n",
"DataWhalechina.com\n",
"\n",
"In 2016, DataWidener joined the AI-enabled Project Checkup. It was set up to investigate the data used in AI projects and to help prevent missteps.\n",
"\n",
"The data scientists at DataWideners first project checkup are learning how to use artificial intelligence to analyze natural language data and how to understand the problem-solving ability of artificial intelligence.\n",
"\n",
"DataWideners second project checkup is using artificial intelligence to help students understand the information that they are learning. The AI-powered project checkup helps students understand how to use artificial intelligence to learn.\n",
"\n",
"DataWideners third project checkup was in 2016. The AI-powered project checkup helps learners understand how to analyze natural language data.\n",
"\n",
"DataWideners fourth project checkup was in 2016. The AI-powered project checkup helps learners understand how to use natural language data.\n",
"\n",
"DataWideners fifth project checkup was in 2016. The AI-powered project checkup helps learners understand how to use natural language data.\n",
"\n",
"DataWideners sixth project checkup was in 2017. The AI-powered project checkup helps learners understand how to use natural language data.\n",
"\n",
"The AI-powered project checkup helps learners understand how to use natural language data. The AI-powered project checkup helps learners understand how to use natural language data.\n",
"\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization focuses on promoting artificial intelligence as a way of learning, as well as on promoting skills that are essential for future job opportunities.\n",
"\n",
"DataWhalechina has more than 1.2 million members and reaches more than 50 million members. The company has been listed on the Shanghai Stock Exchange since 2010 and has a market capitalization of $26 billion.\n",
"\n",
"The company has invested $2.7 billion in more than 100 projects since 2009. The company also works on various areas such as industry standardization, artificial intelligence and data mining.\n",
"\n",
"The companys focus is to create a learning environment that is better for learning and more productive.\n",
"\n",
"DataWhalechinas AI applications and software are developed and used by over 30,000 teachers, students and learners in China. DataWhalechina has been awarded the “Top 10 Under 50 in China” by the Center for Training in China.\n",
"\n",
"DataWhalechina is the only company in China that has created and published AI solutions to solve real-world problems, such as energy efficiency, smart buildings, social media, and the Internet of Things.\n",
"\n",
"DataWhalechina is backed by an initial public offering in April. It is scheduled to be listed on the Shanghai Stock Exchange in early 2020.\n",
"\n",
"This article was originally published by iPro and was first published by iPro on April 10, 2020.\n",
"\n",
"Learn more about iPro at iPro.org.\n",
"\n",
"Learn more about datawhalechina at datawhalechina.org.\n",
"\n",
"© 2020 iPro. All Rights Reserved.\n",
"\n",
"Share\n",
"----------------------------------------------------------------------"
]
}
],
"source": [
"for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):\n",
" ctx = [model.tokenizer.encode(context)][0]\n",
" src_len = len(ctx)\n",
" print(context, end='')\n",
"\n",
" model.clear()\n",
" if TRIAL == 0:\n",
" init_state = types.SimpleNamespace()\n",
" for i in range(src_len if DEBUG_DEBUG else src_len):\n",
" x = ctx[:i+1]\n",
" if i == src_len - 1:\n",
" init_state.out = model.run(x)\n",
" else:\n",
" model.run(x)\n",
" model.save(init_state)\n",
" else:\n",
" model.load(init_state)\n",
"\n",
" if DEBUG_DEBUG:\n",
" out = init_state.out\n",
" print('\\n', np.array(x), '==>', np.array(\n",
" out), np.max(out), np.min(out))\n",
"\n",
" for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):\n",
" x = ctx[:i+1]\n",
" x = x[-model.ctx_len:]\n",
"\n",
" if i == src_len:\n",
" out = copy.deepcopy(init_state.out)\n",
" else:\n",
" out = model.run(x)\n",
"\n",
" out[0] = -999999999 # disable <|endoftext|>\n",
"\n",
" char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)\n",
" char = char.item()\n",
" print(model.tokenizer.decode(char), end='', flush=True)\n",
"\n",
" ctx += [char]\n",
" print('\\n' + '-' * 70, end='')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc37dfc5-71c7-44a5-a4d4-b71534574872",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}