This commit is contained in:
kewei 2024-06-10 22:49:34 +08:00
parent 894b20d3c3
commit e2360bc0a2
10 changed files with 5666 additions and 315 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,528 @@
{
"cells": [
{
"cell_type": "raw",
"id": "bcd88fb5-6a0f-4c4b-81fd-34be59ea7903",
"metadata": {},
"source": [
"模型下载链接https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5b78b7ef-acc6-46cf-88c2-f90a2835e4b3",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"from torch.nn import functional as F\n",
"from tokenizers import Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "deacc22b-2896-4b77-b595-3284b0c13544",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer.from_file(\"20B_tokenizer.json\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'\n",
"args.n_layer = 24\n",
"args.n_embd = 1024\n",
"\n",
"context = \"\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.85\n",
"########################################################################################################"
]
},
{
"cell_type": "markdown",
"id": "3c85bca7-1342-4d8b-870c-baddcf2661d6",
"metadata": {},
"source": [
"### RWKV 的时间混合实现\n",
"\n",
"在 RWKV 模型中时间混合Time Mixing是一个关键步骤用于处理输入序列随时间的变化。以下是 `time_mixing` 函数的详细公式说明和代码注释。\n",
"\n",
"#### 公式说明\n",
"\n",
"时间混合的核心思想是通过时间混合系数将当前输入与先前的状态混合,以生成新的键、值和门控信号。这一过程涉及如下步骤:\n",
"\n",
"1. **混合输入**\n",
" - 对当前输入 \\( x \\) 和前一状态进行加权平均:\n",
" $$ x_k = x \\cdot \\text{time\\_mix\\_k} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_k}) $$\n",
" $$ x_v = x \\cdot \\text{time\\_mix\\_v} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_v}) $$\n",
" $$ x_r = x \\cdot \\text{time\\_mix\\_r} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_r}) $$\n",
"\n",
"2. **状态更新**\n",
" - 更新状态:\n",
" $$ \\text{state}[5i+1] = x $$\n",
"\n",
"3. **计算门控信号**\n",
" - 使用 sigmoid 激活函数计算门控信号 \\( r \\)\n",
" $$ r = \\sigma(\\text{rw} @ x_r) $$\n",
"\n",
"4. **计算键和值**\n",
" - 通过线性变换生成键 \\( k \\) 和值 \\( v \\)\n",
" $$ k = \\text{kw} @ x_k $$\n",
" $$ v = \\text{vw} @ x_v $$\n",
"\n",
"5. **加权和计算**\n",
" - 根据加权和公式计算加权和 \\( wkv \\)\n",
" $$ a = e1 \\cdot aa + e2 \\cdot v $$\n",
" $$ b = e1 \\cdot bb + e2 $$\n",
" $$ \\text{wkv} = a / b $$\n",
"\n",
"代码如下:\n",
"\n",
"```python\n",
"@torch.jit.script_method\n",
"def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
" # 混合当前输入和先前的状态\n",
" xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
"\n",
" # 更新状态\n",
" state[5*i+1] = x\n",
"\n",
" # 计算门控信号\n",
" r = torch.sigmoid(rw @ xr)\n",
" \n",
" # 计算键和值\n",
" k = kw @ xk\n",
" v = vw @ xv\n",
"\n",
" # 从状态中读取先前的累积值\n",
" aa = state[5*i+2]\n",
" bb = state[5*i+3]\n",
" pp = state[5*i+4]\n",
"\n",
" # 计算加权和的第一部分\n",
" ww = time_first + k\n",
" qq = torch.maximum(pp, ww)\n",
" e1 = torch.exp(pp - qq)\n",
" e2 = torch.exp(ww - qq)\n",
" a = e1 * aa + e2 * v\n",
" b = e1 * bb + e2\n",
" wkv = a / b\n",
"\n",
" # 计算新的权重和状态\n",
" ww = pp + time_decay\n",
" qq = torch.maximum(ww, k)\n",
" e1 = torch.exp(ww - qq)\n",
" e2 = torch.exp(k - qq)\n",
" state[5*i+2] = e1 * aa + e2 * v\n",
" state[5*i+3] = e1 * bb + e2\n",
" state[5*i+4] = qq\n",
"\n",
" # 计算最终的输出\n",
" return ow @ (r * wkv)\n",
"```\n",
"\n",
"### 详细解释\n",
"\n",
"1. **混合输入**\n",
" - `xk`, `xv`, `xr` 是输入 `x` 与状态 `state` 的加权混合,分别用于计算键、值和门控信号。\n",
"\n",
"2. **状态更新**\n",
" - 将当前输入 `x` 存储在状态数组中,供下一步计算使用。\n",
"\n",
"3. **计算门控信号**\n",
" - 使用 `torch.sigmoid` 计算门控信号 `r`,它决定了多少信息将被传递。\n",
"\n",
"4. **计算键和值**\n",
" - 使用矩阵乘法计算键 `k` 和值 `v`。\n",
"\n",
"5. **加权和计算**\n",
" - 通过指数加权平均计算加权和 `wkv`,这涉及处理数值稳定性问题(通过 `torch.maximum` 和指数运算)。\n",
"\n",
"6. **更新状态**\n",
" - 更新状态数组中的累积值,以便后续时间步使用。\n",
"\n",
"7. **计算最终输出**\n",
" - 使用门控信号 `r` 和加权和 `wkv` 计算最终输出。\n",
"\n",
"这样通过逐步混合当前输入和先前的状态RWKV 模型实现了时间序列数据的有效处理。"
]
},
{
"cell_type": "markdown",
"id": "1f0d47e4-1792-47e4-a506-3071f510526e",
"metadata": {},
"source": [
"### RWKV 的通道混合Channel Mixing实现与代码注释\n",
"\n",
"在 RWKV 模型中通道混合Channel Mixing是另一个关键步骤用于处理不同通道之间的信息交换。以下是 `channel_mixing` 函数的详细公式说明和代码注释。\n",
"\n",
"#### 公式说明\n",
"\n",
"通道混合的核心思想是通过通道混合系数将当前输入与先前的状态混合,以生成新的键和门控信号。这一过程涉及如下步骤:\n",
"\n",
"1. **混合输入**\n",
" - 对当前输入 \\( x \\) 和前一状态进行加权平均:\n",
" $$ x_k = x \\cdot \\text{time\\_mix\\_k} + \\text{state}[5i+0] \\cdot (1 - \\text{time\\_mix\\_k}) $$\n",
" $$ x_r = x \\cdot \\text{time\\_mix\\_r} + \\text{state}[5i+0] \\cdot (1 - \\text{time\\_mix\\_r}) $$\n",
"\n",
"2. **状态更新**\n",
" - 更新状态:\n",
" $$ \\text{state}[5i+0] = x $$\n",
"\n",
"3. **计算门控信号**\n",
" - 使用 sigmoid 激活函数计算门控信号 \\( r \\)\n",
" $$ r = \\sigma(\\text{rw} @ x_r) $$\n",
"\n",
"4. **计算键**\n",
" - 通过 ReLU 和平方变换生成键 \\( k \\)\n",
" $$ k = (\\text{ReLU}(\\text{kw} @ x_k))^2 $$\n",
"\n",
"5. **计算输出**\n",
" - 使用门控信号和键计算最终的输出:\n",
" $$ \\text{output} = r \\cdot (\\text{vw} @ k) $$\n",
"\n",
"代码如下:\n",
"\n",
"```python\n",
"@torch.jit.script_method\n",
"def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" # 混合当前输入和先前的状态\n",
" xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
"\n",
" # 更新状态\n",
" state[5*i+0] = x\n",
"\n",
" # 计算门控信号\n",
" r = torch.sigmoid(rw @ xr)\n",
"\n",
" # 计算键并通过ReLU和平方变换\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
"\n",
" # 计算最终的输出\n",
" return r * (vw @ k)\n",
"```\n",
"\n",
"\n",
"1. **混合输入**\n",
" - `xk`, `xr` 是输入 `x` 与状态 `state` 的加权混合,分别用于计算键和门控信号。\n",
"\n",
"2. **状态更新**\n",
" - 将当前输入 `x` 存储在状态数组中,供下一步计算使用。\n",
"\n",
"3. **计算门控信号**\n",
" - 使用 `torch.sigmoid` 计算门控信号 `r`,它决定了多少信息将被传递。\n",
"\n",
"4. **计算键**\n",
" - 使用 `torch.relu` 计算键 `k`,然后进行平方变换以增加非线性特性。\n",
"\n",
"5. **计算最终输出**\n",
" - 使用门控信号 `r` 和键 `k` 计算最终输出。\n",
"\n",
"通过这些步骤RWKV 模型实现了通道间的信息有效交换,增强了模型对输入数据的处理能力。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f1b2e2b-9f0d-4db3-b9d9-d43e3e2537ee",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(torch.jit.ScriptModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
" for k in w.keys():\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}\n",
" else: w[k] = w[k].float() # convert to f32 type\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @torch.jit.script_method\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
" state[5*i+0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @torch.jit.script_method\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
" xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
" state[5*i+1] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = kw @ xk\n",
" v = vw @ xv\n",
" \n",
" aa = state[5*i+2]\n",
" bb = state[5*i+3]\n",
" pp = state[5*i+4]\n",
" ww = time_first + k\n",
" qq = torch.maximum(pp, ww)\n",
" e1 = torch.exp(pp - qq)\n",
" e2 = torch.exp(ww - qq)\n",
" a = e1 * aa + e2 * v\n",
" b = e1 * bb + e2\n",
" wkv = a / b\n",
" ww = pp + time_decay\n",
" qq = torch.maximum(ww, k)\n",
" e1 = torch.exp(ww - qq)\n",
" e2 = torch.exp(k - qq)\n",
" state[5*i+2] = e1 * aa + e2 * v\n",
" state[5*i+3] = e1 * bb + e2\n",
" state[5*i+4] = qq\n",
" return ow @ (r * wkv)\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)\n",
" for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
" att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, \n",
" att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_mix_k, ffn.time_mix_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state\n",
"\n",
"##########################################################################################################"
]
},
{
"cell_type": "markdown",
"id": "f1b457af-77a3-4b5e-a6f3-034b0fc6708d",
"metadata": {},
"source": [
"采样方法和v2、v3版本相比没有发生变化代码做了一点优化调整而已。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fdf027b6-7df9-4c0f-818e-013e7c49e3cd",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "298dbbde-6535-406b-bd43-f2d886799f8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-4-Pile-430M-20220808-8066 ...\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7d366a89-02cb-4b5e-95ef-52f6376d3607",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None\n",
"for token in tokenizer.encode(context).ids:\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5273e7a8-875e-4998-b98e-f81951a7af32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The machine learning solutions applied to the class are called Persona, which consist of several categories:\n",
"\n",
"\\begin{tabular}{|c|c|c|}\n",
"\\hline\n",
" Name & Description \\\\ \\hline\n",
"\\hline\n",
" \\end{tabular}\n",
"\n",
"DataWhalechina organizes the data in two ways:\n",
"\n",
"\\begin{tabular}{|c|c|c|}\n",
"\\hline\n",
" \\multicolumn{2}{|c}{\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The main goal is to allow learners to learn how to use artificial intelligence in an integrated fashion, by using both AI and deep learning techniques. Datawhalechina aims to teach AI algorithms from scratch and teach them from scratch to become competent with many algorithms that humans could not have.\n",
"\n",
"Applications\n",
"\n",
"Projects \n",
" DeeplearningAI : Encourage AI algorithms to become competent with many algorithms that humans could not have. Datawhalechina aims to be able to combine knowledge from multiple AI\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The company was founded in 2016. The company has graduated 1,000 engineers, who work from the companies headquarters in Shanghai.\n",
"\n",
"In September 2019, the team was reported to have learned over 400,000 artificial intelligence.\n",
"\n",
"In August 2019, the company was reported to have sold 600,000 artificial intelligence to clients in Singapore.\n",
"\n",
"References\n",
"\n",
"External links\n",
" \n",
"\n",
"Category:Human machine interaction\n",
"Category:Learning management systems\n",
"Category:Learning management systemsTechnologies, industry,\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "f1cdc809-c64d-4861-b540-460bc1097e38",
"metadata": {},
"source": [
"### 备注RWKV 的Scaling Law缩放定律\n",
"\n",
"RWKV 的缩放定律描述了模型性能随着各种因素变化的数学关系。这些因素包括模型大小($N$)、数据集大小($D$)或最优计算预算($C_{\\min}$)。缩放定律的重要性体现在以下两个方面:\n",
"1. **预测与规划**:它们允许我们在训练大型模型之前,通过插值和外推来预测和规划成本和性能。\n",
"2. **反馈与研究**:它们提供了关于模型失效情况下的重要反馈,指引未来研究方向。\n",
"\n",
"#### 关键内容总结:\n",
"- **与之前的RNN研究对比**之前的工作指出LSTM不完全遵循与Transformer相同的对数线性缩放定律。然而RWKV模型的训练结果表明RWKV遵循与Transformer相同的一般缩放定律形式。\n",
"- **实验验证**:在[v4的论文](https://arxiv.org/abs/2305.13048)通过训练45个RWKV模型验证了其损失与计算量之间的线性关系线性拟合的 $r^2$ 值为0.994,即使外推一个数量级,拟合度仍然很好($r^2$为0.875)。\n",
"\n",
"这些结果显示了RWKV模型在缩放时的优越性和与Transformer相似的性能缩放行为。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13f6025d-faea-4647-be05-8fb4cce05991",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,297 +0,0 @@
{
"cells": [
{
"cell_type": "raw",
"id": "bcd88fb5-6a0f-4c4b-81fd-34be59ea7903",
"metadata": {},
"source": [
"模型下载链接https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b78b7ef-acc6-46cf-88c2-f90a2835e4b3",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"from torch.nn import functional as F\n",
"from tokenizers import Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "deacc22b-2896-4b77-b595-3284b0c13544",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer.from_file(\"20B_tokenizer.json\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'\n",
"args.n_layer = 24\n",
"args.n_embd = 1024\n",
"\n",
"context = \"\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.85\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0f1b2e2b-9f0d-4db3-b9d9-d43e3e2537ee",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(torch.jit.ScriptModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
" for k in w.keys():\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}\n",
" else: w[k] = w[k].float() # convert to f32 type\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @torch.jit.script_method\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
" state[5*i+0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @torch.jit.script_method\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
" xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
" state[5*i+1] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = kw @ xk\n",
" v = vw @ xv\n",
" \n",
" aa = state[5*i+2]\n",
" bb = state[5*i+3]\n",
" pp = state[5*i+4]\n",
" ww = time_first + k\n",
" qq = torch.maximum(pp, ww)\n",
" e1 = torch.exp(pp - qq)\n",
" e2 = torch.exp(ww - qq)\n",
" a = e1 * aa + e2 * v\n",
" b = e1 * bb + e2\n",
" wkv = a / b\n",
" ww = pp + time_decay\n",
" qq = torch.maximum(ww, k)\n",
" e1 = torch.exp(ww - qq)\n",
" e2 = torch.exp(k - qq)\n",
" state[5*i+2] = e1 * aa + e2 * v\n",
" state[5*i+3] = e1 * bb + e2\n",
" state[5*i+4] = qq\n",
" return ow @ (r * wkv)\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)\n",
" for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
" att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, \n",
" att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_mix_k, ffn.time_mix_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state\n",
"\n",
"##########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fdf027b6-7df9-4c0f-818e-013e7c49e3cd",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "298dbbde-6535-406b-bd43-f2d886799f8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-4-Pile-430M-20220808-8066 ...\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7d366a89-02cb-4b5e-95ef-52f6376d3607",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None\n",
"for token in tokenizer.encode(context).ids:\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5273e7a8-875e-4998-b98e-f81951a7af32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Founded in 2015 by AI graduate student Yawei Li, DataWhalechina aims to help people learn to think more naturally about data. Learn more.\n",
"\n",
"50% of U.S. high school graduates who take data science courses go on to pursue masters degrees, which cost about $7,000, according to The American Council for an Energy Efficient Economy. Learn more.\n",
"\n",
"More than 600 startups compete for the same creative talent awards in 2016. If there is one award\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Datawhalechina was established in 2013. The company was created to serve the needs of companies that seek to increase the utilization of machine learning technology in their environments. This aims to create a platform that will help increase the adoption of machine learning technology in organizations by creating better decision support tools.\n",
"\n",
"As of 2017, Datawhalechina's team of specialists are spread over the United States, Europe, Asia, Africa and Canada. Their strategy includes providing low-cost software solutions to the\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\n",
"The main objective of the organization is to provide diverse students with the information, skills, knowledge and ideas needed to tackle big challenges in their future. The success of the program has prompted the city government to give them more resources to bring more students in the program.\n",
"\n",
"Ethereum\n",
"\n",
"The Ethereum (ETH) blockchain, designed by XRP, is a decentralised ledger technology that enables Bitcoin (BTC) and other cryptocurrencies to be used as payment. It aims to be the largest\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d546bfd9-cf80-49bf-8c76-f3918d7d67e4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,869 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "1fb76974-93ea-4b9c-81b1-55f826e7a361",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "markdown",
"id": "9c97049c-d3ae-4c72-bff4-d99416f8d650",
"metadata": {},
"source": [
"rwkv5又叫eagal"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b1059eca-db4f-4c0b-ae3e-37af49ec7fa1",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1c8d8009-7ee7-4419-aacb-cdc45f287010",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "63a4e8ba-a291-4fdc-aef1-ebfca21840d4",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "raw",
"id": "cb8c7d5e-08cb-4780-b6d9-ab8bad1417d4",
"metadata": {},
"source": [
"可以从这个链接下载模型:\n",
"https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files\n",
"https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "94d7d6db-e89e-4209-ae72-6625ba85ef5b",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096' #这里不用有后缀.pth\n",
"args.n_layer = 24\n",
"args.n_embd = 1024\n",
"args.vocab_size = 65536"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c8dcf39a-7838-454b-85fc-ec9bd75fa243",
"metadata": {},
"outputs": [],
"source": [
"# N_LAYER=\"12\"\n",
"# N_EMBD=\"768\"\n",
"N_LAYER=\"24\"\n",
"N_EMBD=\"1024\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "74d7c96a-6fbc-401c-8078-fefb1a6ec5c3",
"metadata": {},
"outputs": [],
"source": [
"# context = \"\\nElon Musk has\"\n",
"# context = \"\\n我们发现\"\n",
"context = \"Q:Do you know datawhalechina?\\nA:\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"LENGTH_PER_TRIAL = 4096\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "markdown",
"id": "5ac1c244-71e3-4263-8ad8-4d0cf681ebfd",
"metadata": {},
"source": [
"Eagle (RWKV-5) 和 Finch (RWKV-6) 相较于基础的RWKV-4架构在建模上的改进\n",
"\n",
"1. **改进步骤**\n",
" - **Eagle的改进**Eagle模型在RWKV-4的基础上进行了多项改进包括引入矩阵值的注意力状态matrix-valued attention states、在注意力头上应用LayerNorm层归一化、使用SiLUSigmoid-Weighted Linear Unit进行注意力门控、并改进了初始化方法。此外Eagle移除了接受度receptance函数中的Sigmoid激活函数。\n",
" - **Finch的改进**Finch模型进一步引入了对衰减计划decay schedule和令牌移位token-shift的数据依赖性data-dependence使模型在处理时间和令牌数据时更加灵活和精确。\n",
"\n",
"2. **核心架构**\n",
" - 这些模型的核心架构依然类似于RWKV-4由一系列堆叠的残差块组成形状类似于传统的Transformer架构。\n",
" - 每个块包含一个预LayerNorm时间混合子层Pre-LayerNorm Time-Mixing sub-layer和一个预LayerNorm通道混合子层Pre-LayerNorm Channel-Mixing sub-layer对应于Transformer中的注意力子层和前馈网络子层。\n"
]
},
{
"cell_type": "markdown",
"id": "bd3d56d6-59af-41d0-9ac3-2cc5b4fb54ed",
"metadata": {},
"source": [
"这个是RWKV 5的Channel Mixing的代码实现可以对比一下RWKV 4的实现。\n",
"\n",
"\n",
"```python\n",
"@MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"```\n",
"\n",
"RWKV 4的Channel Mixing的代码实现为\n",
"\n",
"\n",
"```python\n",
"@torch.jit.script_method\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
" state[5*i+0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"```\n",
"\n",
"这里的`i`表示的是RWKV有多少层在RWKV4的每一层中Channel Mixing记录一个状态而每一个Time Mixing则记录4个状态所以一共是5个状态。而RWKV 5中每一层现在记录了`2+self.head_size`个状态Channel Mixing记录的状态以及计算过程和RWKV 4是完全一样的。"
]
},
{
"cell_type": "markdown",
"id": "976f399a-78ba-4fb2-9b52-d19afda8c5d0",
"metadata": {},
"source": [
"![](./img/01.png)\n",
"\n",
"图1RWKV架构概述。左侧时间混合和通道混合块右上角作为RNN单元的RWKV时间混合块中下部前馈模块中的令牌移位模块和Eagle时间混合右下角Finch时间混合中的令牌移位模块。所有形状注释为简洁起见假设为单头。虚线箭头左侧右上角表示在Finch中有连接但在Eagle中没有。"
]
},
{
"cell_type": "markdown",
"id": "ea37d6cd-1348-450b-9f6e-198d7c1d8368",
"metadata": {},
"source": [
"Eagle模型中采用的Token Shift技术\n",
"\n",
"1. **Token Shift**\n",
" - Eagle模型从之前的RWKV模型中采用了Token Shift技术这类似于大小为2的一维因果卷积1D causal convolution。\n",
" - 在图1的中心底部可以看到该技术的示意图。\n",
"\n",
"2. **线性插值定义**\n",
" - 为了更好地介绍Token Shift技术定义了一些符号。\n",
" - 线性插值lerp在时间步$t$和$t-1$之间用于RWKV-4和Eagle Token Shift定义如下\n",
" \\begin{align*}\n",
" \\text{lerp}_{\\Box}(a, b) = a + (b - a) \\odot \\mu_{\\Box}\n",
" \\end{align*}\n",
" - 其中,每个$\\mu_{\\Box} \\in \\mathbb{R}^D$是一个可学习的向量。\n",
"\n",
"3. **Token Shift的功能**\n",
" - Token Shift允许模型学习在每个时间步中分配新信息和旧信息的比例适用于接受度receptance、键key、值value和门控向量gate vectors中的每个通道$r, k, v, g$且每个头部head独立且唯一地应用这些向量。\n",
" - 这使得即使在单层内一个单独的头部也可以直接将过去和当前的令牌数据累积到这些向量的不同子空间中从而形成感应头induction heads。\n"
]
},
{
"cell_type": "markdown",
"id": "40ca7c1b-8faa-42d5-8c7d-c4f233063856",
"metadata": {},
"source": [
"在Eagle和Finch模型中通道混合模块Channel Mixing module的设置及其与RWKV-4架构的异同如下\n",
"\n",
"1. **模块一致性**\n",
" - 在Eagle和Finch模型中通道混合模块与之前的RWKV-4架构基本相同。\n",
" - 唯一的区别在于Eagle模型中通道混合模块的隐藏维度hidden dimension从原来的4D减少到了3.5D。\n",
"\n",
"2. **减少维度的原因**\n",
" - 这个隐藏维度的减少是为了在Eagle时间混合Eagle Time Mixing中引入新的门控权重gating weights并确保与之前模型在相同层数和嵌入维度下的参数数量相等。\n",
"\n",
"3. **Finch模型中的处理**\n",
" - 尽管Finch模型中增加了一些新的LoRA权重参数但并没有进一步减少隐藏维度。\n",
"\n",
"4. **公式一致性**\n",
" - 通道混合的公式与RWKV-4模型相同为了符号一致性notational consistency再次列出这些公式\n",
"\n",
"\\begin{align*}\n",
"r'_t &= \\text{lerp}_{r'}(x'_t, x'_{t-1}) W_{r'} \\in \\mathbb{R}^D \\quad \\text{(公式10)} \\\\\n",
"k'_t &= \\text{lerp}_{k'}(x'_t, x'_{t-1}) W_{k'} \\in \\mathbb{R}^{3.5D} \\quad \\text{(公式11)} \\\\\n",
"v'_t &= \\text{ReLU}(k'_t)^2 W_{v'} \\in \\mathbb{R}^D \\quad \\text{(公式12)} \\\\\n",
"o'_t &= \\sigma(r'_t) \\odot v'_t \\in \\mathbb{R}^D \\quad \\text{(公式13)}\n",
"\\end{align*}\n",
"\n",
"这些公式描述了在时间步 \\( t \\) 的通道混合操作:\n",
"- 使用线性插值lerp计算 \\( r'_t \\) 和 \\( k'_t \\)。\n",
"- \\( v'_t \\) 通过 \\( k'_t \\) 的ReLU平方值乘以权重矩阵 \\( W_{v'} \\) 得到。\n",
"- \\( o'_t \\) 是 \\( r'_t \\) 的激活函数 \\( \\sigma \\) 的输出与 \\( v'_t \\) 的逐元素乘积。\n",
"\n",
"其中3.5D 指的是一种表示维度的方式。在深度学习模型中D 通常代表模型的隐藏层维度即嵌入维度或特征空间的维度。例如如果模型的隐藏维度是256那么4D表示这个维度被扩展为4倍也就是1024。\n",
"\n",
"然而3.5D 是一个不常见的表示方法通常情况下我们会看到整数倍的表示如2D, 4D等。在这里3.5D代表的是隐藏维度的3.5倍。\n",
"\n",
"具体来说如果模型的基础维度是D那么3.5D就表示:\n",
"\\begin{align*} 3.5D = 3.5 \\times D \\end{align*}\n",
"\n",
"假设D是256那么3.5D就是:\n",
"\\begin{align*} 3.5 \\times 256 = 896 \\end{align*}\n",
"\n",
"所以3.5D就是指模型在特定层中使用的特征维度是基础维度的3.5倍。在这个文档中作者提到从4D减少到3.5D,意味着他们减少了某个层或模块的特征维度,以便引入新的门控权重并保持参数数量的一致性。"
]
},
{
"cell_type": "markdown",
"id": "67a5c0d4-57d5-4f9b-a2b7-bddd2250f08a",
"metadata": {},
"source": [
"Eagle时间混合Eagle Time Mixing的公式及其操作方法如下\n",
"\n",
"### 公式部分\n",
"\n",
"Eagle时间混合的公式如下\n",
"\n",
"\\begin{align*}\n",
"\\Box_t &= \\text{lerp}_{\\Box}(x_t, x_{t-1}) W_{\\Box}, \\quad \\Box \\in \\{r, k, v, g\\} \\tag{4} \\\\\n",
"w &= \\exp(-\\exp(\\omega)) \\tag{5} \\\\\n",
"\\text{wk} \\mathbf{v}_t &= \\text{diag}(u) \\cdot k_t^\\top \\cdot v_t + \\sum_{i=1}^{t-1} \\text{diag}(w)^{t-1-i} \\cdot k_i^\\top \\cdot v_i \\in \\mathbb{R}^{(D/h) \\times (D/h)} \\tag{6} \\\\\n",
"o_t &= \\text{concat} \\left( \\text{SiLU}(g_t) \\odot \\text{LayerNorm}(r_t \\cdot \\text{wk} \\mathbf{v}_t) \\right) W_o \\in \\mathbb{R}^D \\tag{7}\n",
"\\end{align*}\n",
"\n",
"### 解释部分\n",
"\n",
"- **LayerNorm的操作**LayerNorm在每个头部head上独立操作这相当于在h个组上执行GroupNormWu & He2018。值得注意的是$w$ 是由 $\\omega \\in \\mathbb{R}^{D/h}$ 通过公式 $w = \\exp(-\\exp(\\omega))$ 计算得到的,$\\omega$ 是实际的头部可训练参数。这确保了 $w$ 在区间 (0,1) 内,从而保证 $\\text{diag}(w)$ 是一个收缩矩阵。\n",
"\n",
"- **wkv_t 计算**wkv_t 的注意力计算可以用递归形式写为:\n",
" \\begin{align*}\n",
" \\text{wk} \\mathbf{v}' &= s + \\text{diag}(u) \\cdot k^\\top \\cdot v \\tag{8} \\\\\n",
" s' &= \\text{diag}(w) \\cdot s + k^\\top \\cdot v \\tag{9}\n",
" \\end{align*}\n",
"\n",
"- **解释RWKV的 wkv_t 项**RWKV的 wk\\mathbf{v}_t 项可以被认为是归一化 $k^\\top v$ 项的基于衰减的等价物。值得注意的是,对于给定的头部 $j$,递归状态 $s$ 是 $k^\\top v$ 的和,其中 $s$ 的每个通道在每个时间步通过相应的 $w$ 通道单独衰减。在应用接受度向量、门控和输出权重之前,当前令牌的 $k^\\top v$ 被乘以一个每通道的学习提升 $u$ 并与状态相加见图1右上角。这给当前令牌相对于包含在衰减状态历史中的过去令牌的和一个特殊的处理。接受度乘以这个和类似于线性注意力中的查询项。\n"
]
},
{
"cell_type": "markdown",
"id": "78a86c2f-962d-4826-baf4-bc19bc40b6e3",
"metadata": {},
"source": [
"这里的最大的改进应该是现在的计算是分成了`H = self.n_head`个头然后每个头的计算结果都被存到了state里。相比于RWKV-4这种改进可以类比于Transformer的单头自注意力机制改到多头注意力机制。\n",
"```python\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
" xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
" state[i1] = x\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + time_decay * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bd093a96-fdc5-460d-b39f-fe3735795b42",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_decay'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
" xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
" state[i1] = x\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + time_decay * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" # print(i)\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
" att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay, \n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_mix_k, ffn.time_mix_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "a330cd34-7ed0-4a6c-92a3-19797d34ee77",
"metadata": {},
"outputs": [],
"source": [
"# context = \"Q:Do you know datawhalechina?\\nA:\"\n",
"context = '\\nQ:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?'"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "ad824161-413d-460c-9ffe-9dbfb739f86b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"args.MODEL_NAME"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "f0e2f841-4cda-48d4-b055-7adf00f2fe73",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(24, 1024)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"args.n_layer,args.n_embd"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "aba8a4d4-9a77-4191-a7ef-d5e6100ca3c1",
"metadata": {},
"outputs": [],
"source": [
"# args.n_layer = 24\n",
"# args.n_embd = 1024"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "dd44f7bc-e8d6-4242-beb5-89a866990751",
"metadata": {},
"outputs": [],
"source": [
"# args.n_layer = 12\n",
"# args.n_embd = 768"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "2a96d9dc-8b5e-40cc-bb36-24c9bdeac29e",
"metadata": {},
"outputs": [],
"source": [
"# args.MODEL_NAME='../models/rwkv-5-world-1b5'"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "b7d07606-31b4-4c21-9f89-554d89c2c866",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "ce42cfad-0274-4d5d-950d-fb89ff11ed2c",
"metadata": {},
"outputs": [],
"source": [
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "3e02a81c-1447-4936-a241-4d00ecf8e862",
"metadata": {},
"outputs": [],
"source": [
"LENGTH_PER_TRIAL=1024"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "4a00ea05-d6fd-4052-b13a-8107fb268420",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
"QI: I think that the group of students is actually the whole AI community.\n",
"Q: In the first episode, how do you think you, a student, can use AI to solve a problem?\n",
"QI: It's a great opportunity to help develop and build knowledge, so that if we see AI problems, we can help solve them.\n",
"Q: How do you think that students can also participate in the teaching of AI?\n",
"QI: It is very important to let the students to think that there is an AI problem, and we can solve it by teaching AI.\n",
"Q: How do you think the research that we did on AI can be used to develop AI technologies?\n",
"QI: The research is interesting and it can be used to develop AI technologies.\n",
"Q: Do you think that students can learn from your research?\n",
"QI: I think so.\n",
"Q: You also talk about the use of AI in real-life applications. What do you think of that?\n",
"QI: I think it's a good thing to see.\n",
"Q: What are the major challenges that you see as being faced by the AI community?\n",
"QI: One is how to find data that can help us solve problems. The other is how to find a good dataset.\n",
"Q: You also talk about how we should deal with the big data problem. How do you think about that?\n",
"QI: We should not think that it is impossible to handle big data. There are a lot of big data, but there is a problem of how to handle them.\n",
"Q: What is the role of AI in industry?\n",
"QI: AI plays an important role in industry. AI has helped us improve the quality of services. We have a lot of new applications that we are using AI to solve.\n",
"Q: How do you think about AI and humans in the future?\n",
"QI: AI is not just for humans. It is also used for us to learn, for example.\n",
"Q: What do you think about the use of AI in the field of tourism?\n",
"QI: It's not that easy to use AI in tourism. There are so many problems.\n",
"Q: Do you think that AI will be a part of tourism in the future?\n",
"QI: I think so. It is very important for us to see.\n",
"Q: What do you think about AI and data sharing?\n",
"QI: It is not that easy to use AI in data sharing.\n",
"Q: What are the ways that you see AI in tourism?\n",
"QI: AI can be used to solve problems.\n",
"Q: How do you think about the relationship between AI and data?\n",
"QI: We need to use AI in the future to help us solve problems.\n",
"Q: How do you think about the relationship between AI and data sharing?\n",
"Q: What do you think about the future of AI?\n",
"Q: What are the main issues that AI is facing?\n",
"Q: What are the biggest challenges that you see in the field of AI?\n",
"Q: What do you think about the future of AI?\n",
"Q: What are the biggest challenges that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main challenges that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"Q: What are the main trends that you see in the field of AI?\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
"M: We are always looking for data to make sure that we are doing the right thing. We are currently looking at how to do this through the webinars and learning events. We have had speakers from different areas, such as from Silicon Valley, who have participated in the series. The current speaker, Marco Aurelio, was from Hong Kong. He was doing a presentation on Artificial Intelligence.\n",
"Q:How do you think of the audience that you are aiming to reach?\n",
"M: We are aiming at the general audience. We are also targeting people in the financial industry, who are also interested in artificial intelligence.\n",
"Q:What are your biggest challenges?\n",
"M: One of the biggest challenges is that the audience is very educated. They know about artificial intelligence and data. But the difficulty is that we have to explain the whole technology to them.\n",
"Q:How do you see the future of artificial intelligence?\n",
"M: It is an interesting future. It is really interesting. We are starting to see many different developments. The technology is really getting better and better. There are different ways of data that are being created. We have the development of machines to pick words and sentences and machines to make the machines think.\n",
"Q:How do you see the future of Artificial Intelligence?\n",
"M: We are constantly working on how to make the future of artificial intelligence more human-like.\n",
"Tags: dataWhalechina\n",
"Previous PostFuture is one of the hottest topics in Artificial Intelligence\n",
"Next PostOpinions about the future of Artificial Intelligence are changing\n",
"Cotton Developer News: Hands-On With Artificial Intelligence\n",
"Headlines from the data Whalechina Network: October 6, 2019\n",
"DataWhalechina Network: July 30, 2019\n",
"Cotton Developer News: July 22, 2019\n",
"Headlines from the data Whalechina Network: June 22, 2019\n",
"DataWhalechina Network: May 18, 2019\n",
"Archives Select Month July 2019 June 2019 May 2019 April 2019 March 2019 February 2019 January 2019 December 2018 November 2018 October 2018 September 2018 August 2018 July 2018 June 2018 May 2018 April 2018 March 2018 February 2018 January 2018 December 2017 November 2017 October 2017 September 2017 August 2017 July 2017 June 2017 May 2017 April 2017 March 2017 February 2017 January 2017 December 2016 November 2016 October 2016 September 2016 August 2016 July 2016 June 2016 May 2016 April 2016 March 2016 February 2016 January 2016 December 2015 November 2015 October 2015 September 2015 August 2015 July 2015 June 2015 May 2015 April 2015 March 2015 February 2015 January 2015 December 2014 November 2014 October 2014 September 2014 August 2014 July 2014 June 2014 May 2014 April 2014 March 2014 February 2014 January 2014 December 2013 November 2013 October 2013 September 2013 August 2013 July 2013 June 2013 May 2013 April 2013 March 2013 February 2013 January 2013 December 2012 November 2012 October 2012 September 2012 August 2012 July 2012 June 2012 May 2012 April 2012 March 2012 February 2012 January 2012 December 2011 November 2011 October 2011 September 2011 August 2011 July 2011 June 2011 May 2011 April 2011 March 2011 February 2011 January 2011 December 2010 November 2010 October 2010 September 2010 August 2010 July 2010 June 2010 May 2010 April 2010 March 2010 February 2010 January 2010 December 2009 November 2009 October 2009 September 2009 August 2009 July 2009 June 2009 May 2009 April 2009 March 2009 February 2009 January 2009 December 2008 November 2008 October 2008 September 2008 August 2008 July 2008 June 2008 May 2008 April 2008 March 2008 February 2008 January 2008 December 2007 November 2007 October 2007 September 2007 August 2007 July 2007 June 2007 May 2007 April 2007 March 2007 February 2007 January 2007 December 2006 November 2006 October 2006 September 2006 August 2006 July 2006 June 2006 May 2006 April 2006 March 2006 February 2006 January 2006 December 2005 November 2005 October 2005 September 2005 August 2005 July 2005 June 2005 May 2005 April 2005 March 2005 February 2005 January 2005 December 2004 November 2004 October 2004 September 2004 August 2004 July 2004 June 2004 May 2004 April 2004 March 2004 February 2004 January 2004 December 2003 November 2003 October 2003 September 2003 August 2003 July 2003 June 2003 May 2003 April 2003 March 2003 February 2003 January 2003 December 2002 November 2002 October 2002 September 2002 August 2002 July 2002\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
"Q:As AI continues to grow, what are some of the most promising applications of artificial intelligence?\n",
"Q:How do you think artificial intelligence will affect the future of AI?\n",
"Q:How does AI's role in education differ from the way it was used in the past?\n",
"Q:What are some of the challenges AI will face in the future?\n",
"Q:What is your vision for AI?\n",
"Q:What are your key takeaways from this conference?\n",
"Q:What do you hope to accomplish with AI?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are your current trends and plans for AI?\n",
"Q:How can AI be applied in education?\n",
"Q:What do you think will be the biggest impact of AI on education?\n",
"Q:What is your vision for AI in the future?\n",
"Q:How does AI change the way we teach and learn?\n",
"Q:What are your hopes for the future of AI?\n",
"Q:How does AI's role in education differ from the way it was used in the past?\n",
"Q:What are some of the challenges AI will face in the future?\n",
"Q:What is your vision for the future of AI?\n",
"Q:What are your key takeaways from this conference?\n",
"Q:What is your vision for the future of AI?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are your hopes for the future of AI?\n",
"Q:What are your key takeaways from this conference?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q:What is your vision for AI's role in education?\n",
"Q:What are your goals for the future of AI?\n",
"Q:What are some of the biggest challenges AI will face in the future?\n",
"Q\n",
"\n"
]
}
],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)\n",
"\n",
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3d3eaf3-252a-43da-9414-e1c6f6c681fc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -21,6 +21,16 @@
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e6c9297-472f-4fd8-ad19-d8072b5040f8",
"metadata": {},
"outputs": [],
"source": [
"rwkv5又叫eagal"
]
},
{
"cell_type": "code",
"execution_count": 2,

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

@ -0,0 +1,514 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "markdown",
"id": "12f7fb18-23e7-44a1-883e-0bd673b8fb0b",
"metadata": {},
"source": [
"![](./img/01.png)\n",
"\n",
"图1RWKV架构概述。左侧时间混合和通道混合块右上角作为RNN单元的RWKV时间混合块中下部前馈模块中的令牌移位模块和Eagle时间混合右下角Finch时间混合中的令牌移位模块。所有形状注释为简洁起见假设为单头。虚线箭头左侧右上角表示在Finch中有连接但在Eagle中没有。"
]
},
{
"cell_type": "markdown",
"id": "b3d66b1f-5c04-44d7-9bbc-76f844707327",
"metadata": {},
"source": [
"首先RWKV 6相比于RWKV 5在Token Shift上进行了改进具体看下面的中间底部和右下角的图分别是RWKV 4/5的Token Shift方式和RWKV 6的Token Shift方式。"
]
},
{
"cell_type": "markdown",
"id": "20cd8fa2-c1a4-4fdc-ba5d-858b25df1bcf",
"metadata": {},
"source": [
"具体内容如下:\n",
"\n",
"### 公式部分\n",
"\n",
"Finch Token Shift中使用的数据依赖线性插值ddlerp定义如下\n",
"\n",
"\\begin{align*}\n",
"\\text{lora}_{\\Box}(x) &= \\lambda_{\\Box} + \\tanh(x A_{\\Box}) B_{\\Box} \\tag{14} \\\\\n",
"\\text{ddlerp}_{\\Box}(a, b) &= a + (b - a) \\odot \\text{lora}_{\\Box}(a + (b - a) \\odot \\mu_{x}) \\tag{15}\n",
"\\end{align*}\n",
"\n",
"### 解释部分\n",
"\n",
"- **可学习向量和矩阵**\n",
" - $\\mu_{x}$ 和每个 $\\lambda_{\\Box}$ 引入了维度为 $D$ 的可训练向量。\n",
" - $A_{\\Box} \\in \\mathbb{R}^{D \\times 32}$ 和 $B_{\\Box} \\in \\mathbb{R}^{32 \\times D}$ 引入了新的可训练权重矩阵。\n",
" - 对于公式中提到的LoRA$_{\\omega}$的特殊情况,引入了双倍大小的可训练权重矩阵:$A_{\\omega} \\in \\mathbb{R}^{D \\times 64}$ 和 $B_{\\omega} \\in \\mathbb{R}^{64 \\times D}$。\n",
"\n",
"- **未来模型扩展**\n",
" - 图1中右下角显示了一个示意图。\n",
" - 未来7B及更大规模的Finch模型预计将进一步增加这些权重矩阵的大小可能翻倍或更多。\n",
"\n",
"### 功能与作用\n",
"\n",
"这种带有数据依赖性的Token Shift新形式旨在扩展模型超越RWKV-4/Eagle风格的Token Shift的能力使得每个通道分配的新旧数据量现在依赖于当前和前一个时间步的输入。\n",
"\n",
"### 详细解释\n",
"\n",
"- **数据依赖线性插值ddlerp**\n",
" - ddlerp通过公式14和公式15实现它结合了当前时间步和前一个时间步的信息来计算插值。\n",
" - $\\text{lora}_{\\Box}(x)$利用了一个$\\lambda_{\\Box}$向量和通过$\\tanh$函数处理的$x A_{\\Box}$与$B_{\\Box}$的乘积来生成。\n",
"\n",
"- **模型能力扩展**\n",
" - 通过这种数据依赖的Token ShiftFinch模型能够更灵活地处理时间步之间的信息传递使得模型在处理复杂序列数据时更加精确和高效。\n",
"\n",
"总结来说Finch在Token Shift上引入了数据依赖的线性插值利用可训练的向量和矩阵来增强模型的灵活性和能力使其能够更好地处理时间步之间的信息从而提高了模型的整体性能。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
"metadata": {},
"outputs": [],
"source": [
"#采样方式没有变化\n",
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "raw",
"id": "812fac97-a6b8-423c-831d-fe7397883437",
"metadata": {},
"source": [
"模型下载地址https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'\n",
"args.n_layer = 24\n",
"args.n_embd = 2048\n",
"args.vocab_size = 65536\n",
"\n",
"context = \"\\nDatawhale is \"\n",
"# context = \"\\n我们发现\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "markdown",
"id": "ee395717-e599-40fd-a4b9-ddfc800babea",
"metadata": {},
"source": [
"相比于RWKV 5的Channel Mixing见下面代码来说RWKV6的Channel Mixing没有变化这里的`time_maa_k`和RWKV 5中的`time_mix_k`是相同形状的可学习参数都是一个维度为D模型的隐藏层维度的张量。"
]
},
{
"cell_type": "markdown",
"id": "54571fa3-52c6-4b05-ab97-9f06d76a522e",
"metadata": {},
"source": [
"Finch在时间混合Time Mixing上做了以下改进具体内容如下\n",
"\n",
"### 公式部分\n",
"\n",
"Finch时间混合的公式如下\n",
"\n",
"\\begin{align*}\n",
"\\Box_t &= \\text{lerp}_{\\Box}(x_t, x_{t-1}) W_{\\Box}, \\quad \\Box \\in \\{r, k, v, g\\} \\tag{16} \\\\\n",
"d_t &= \\text{lora}_d(\\text{ddlerp}_d(x_t, x_{t-1})) \\tag{17} \\\\\n",
"w_t &= \\exp(-\\exp(d_t)) \\tag{18} \\\\\n",
"\\text{wk} \\mathbf{v}_t &= \\text{diag}(u) \\cdot k_t^\\top \\cdot v_t + \\sum_{i=1}^{t-1} \\left( \\prod_{j=1}^{i-1} w_j \\right) \\cdot k_i^\\top \\cdot v_i \\in \\mathbb{R}^{(D/h) \\times (D/h)} \\tag{19} \\\\\n",
"o_t &= \\text{concat} \\left( \\text{SiLU}(g_t) \\odot \\text{LayerNorm}(r_t \\cdot \\text{wk} \\mathbf{v}_t) \\right) W_o \\in \\mathbb{R}^D \\tag{20}\n",
"\\end{align*}\n",
"\n",
"### 解释部分\n",
"\n",
"- **可学习向量和矩阵**\n",
" - $\\Box_t$ 是通过线性插值lerp计算得到的适用于接受度receptance、键key、值value和门控向量gate vectors。\n",
" - $d_t$ 是通过 $\\text{lora}_d$ 函数对 $\\text{ddlerp}_d(x_t, x_{t-1})$ 进行处理得到的。\n",
" - $w_t$ 是由 $d_t$ 计算得到的,用于控制衰减的动态变化。\n",
"\n",
"- **时间混合计算**\n",
" - $\\text{wk} \\mathbf{v}_t$ 是通过当前键值对 $k_t^\\top \\cdot v_t$ 和所有之前时间步的键值对 $k_i^\\top \\cdot v_i$ 的加权和计算得到的,权重由 $w_t$ 控制。\n",
" - 输出 $o_t$ 是通过连接concat $\\text{SiLU}(g_t)$ 和 $\\text{LayerNorm}(r_t \\cdot \\text{wk} \\mathbf{v}_t)$ 的结果得到的。\n",
"\n",
"- **递归形式**\n",
" \\begin{align*}\n",
" \\text{wk} \\mathbf{v}' &= s + \\text{diag}(u) \\cdot k^\\top \\cdot v \\tag{21} \\\\\n",
" s' &= \\text{diag}(w) \\cdot s + k^\\top \\cdot v \\tag{22}\n",
" \\end{align*}\n",
"\n",
"### 功能与作用\n",
"\n",
"与Eagle不同Finch中的 $w_t$ 不是在整个序列中固定的。每个通道的 $w_t$ 可以随时间动态变化具体取决于数据输入这也是Finch中衰减的核心变化。\n",
"\n",
"### 详细解释\n",
"\n",
"- **动态衰减**\n",
" - Finch引入的数据依赖衰减使得每个通道的 $w_t$ 可以根据当前和之前的输入动态变化,而不是固定的学习向量。\n",
" - 这种动态衰减机制通过新的LoRA机制应用到学习向量上增加了模型的灵活性。\n",
"\n",
"- **高级Token-Shift**\n",
" - 新的时间衰减 $w_t$ 进一步应用了LoRA机制允许每个通道的 $w_t$ 基于当前和之前的令牌混合来变化。\n",
"\n",
"总结来说Finch在时间混合上通过引入数据依赖的动态衰减机制和高级Token-Shift实现了更高的灵活性和精确度使模型能够更好地处理和融合时间步之间的信息从而提高了整体性能。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3869854a-a4e3-4652-9698-b0d81bbbd645",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
"\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_faaaa'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" sx = state[i0] - x\n",
" xk = x + sx * time_maa_k\n",
" xr = x + sx * time_maa_r\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" sx = state[i1] - x\n",
" state[i1] = x\n",
" xxx = x + sx * x_maa\n",
" xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)\n",
" xxx = torch.bmm(xxx, tm_w2).view(5, -1)\n",
" mw, mk, mv, mr, mg = xxx.unbind(dim=0)\n",
"\n",
" xw = x + sx * (w_maa + mw)\n",
" xk = x + sx * (k_maa + mk)\n",
" xv = x + sx * (v_maa + mv)\n",
" xr = x + sx * (r_maa + mr)\n",
" xg = x + sx * (g_maa + mg)\n",
"\n",
" w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)\n",
" w = torch.exp(-torch.exp(w.float()))\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + w * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
" att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,\n",
" att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,\n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_maa_k, ffn.time_maa_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5235be83-a574-41f6-8546-bc415e2aeacf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "035ce374-c1c0-43b8-9ba9-37df297baae6",
"metadata": {},
"outputs": [],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1e17d6ef-c02c-4d27-8cf6-9a262f75f77f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"Datawhale is ➡️‼️\n",
"https://twitter.com/datawhale_cn/status/1463997087819689985\n",
"#Data #AI #DataAnalytics #AIOps #DataOps #MachineLearning #DataScience #DataLakeAnalytics #Hadoop #Amazon #Google #AWS #Azure #Dataprep #DevOps #OSS #Linux #Unix #BigData #BigDataOps #DataArchitecture #DataScienceOps #MachineLearningOps\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"Datawhale is 🤓\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"Datawhale is 🤯. They have a solid team, a really good SaaS product and the tools to support their users. That said, I have to take a serious look at the privacy and security of their platform before I buy into their story. I think this is a case of big companies buying into the hype, and they're not taking into account all the realities that go into building a privacy-focused product.\n",
"P.S. You can still apply to Datawhale's Program.\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "172c33e0-6d5b-4143-b85d-86777a2f5739",
"metadata": {},
"source": [
"v6和v5相比感觉更喜欢使用emoj了哈哈"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
"metadata": {},
"outputs": [],
@ -23,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
"metadata": {},
"outputs": [],
@ -109,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
"metadata": {},
"outputs": [],
@ -139,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 13,
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
"metadata": {},
"outputs": [],
@ -160,18 +160,6 @@
"TOP_P = 0.7"
]
},
{
"cell_type": "markdown",
"id": "b275ed7e-6708-4e5e-b76e-60c3a2a4a6b6",
"metadata": {},
"source": [
"首先RWKV 6相比于RWKV 5在Token Shift上进行了改进具体看下面的中间底部和右下角的图分别是RWKV 4/5的Token Shift方式和RWKV 6的Token Shift方式。\n",
"\n",
"![](./img/01.png)\n",
"\n",
"相比于RWKV 5的Channel Mixing见下面来说RWKV6的Channel Mixing没有变化这里的`time_maa_k`和RWKV 5中的`time_mix_k`是相同形状的可学习参数都是一个维度为D模型的隐藏层维度的张量。"
]
},
{
"cell_type": "code",
"execution_count": 14,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 232 KiB

After

Width:  |  Height:  |  Size: 100 KiB

View File

@ -82,8 +82,12 @@
| --- | --- | --- |
| ChatGLM3 | [chatglm3.ipynb](./Model_Architecture_Discussions/ChatGLM3/加载模型权重.ipynb) | [@Tangent-90C](https://github.com/Tangent-90C) |
| Llama3 | [llama3.ipynb](./Model_Architecture_Discussions/llama3/llama3-from-scratch.ipynb) | [@A10-research](https://www.aaaaaaaaaa.org/) |
| RWKV V2 | [rwkv-v2.ipynb](./Model_Architecture_Discussions/rwkv-v2/rwkv-v2.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V3 | [rwkv-v3.ipynb](./Model_Architecture_Discussions/rwkv-v3/rwkv-v3.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V2 | [rwkv-v2.ipynb](./Model_Architecture_Discussions/rwkv-v2/rwkv-v2-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V3 | [rwkv-v3.ipynb](./Model_Architecture_Discussions/rwkv-v3/rwkv-v3-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V4 | [rwkv-v4.ipynb](./Model_Architecture_Discussions/rwkv-v4/rwkv-v4-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V5 | [rwkv-v5.ipynb](./Model_Architecture_Discussions/rwkv-v5/rwkv-v5-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| RWKV V6 | [rwkv-v6.ipynb](./Model_Architecture_Discussions/rwkv-v6/rwkv-v6-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
| ChatGLM4 | [rwkv-v3.ipynb](./Model_Architecture_Discussions/ChatGLM4/chatglm4-guide.ipynb) | [@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus) |
---