llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-v6/RWKV-v6-guide.ipynb
2024-06-10 22:49:34 +08:00

{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "markdown",
"id": "12f7fb18-23e7-44a1-883e-0bd673b8fb0b",
"metadata": {},
"source": [
"![](./img/01.png)\n",
"\n",
"Figure 1: Overview of the RWKV architecture. Left: time-mixing and channel-mixing blocks. Upper right: the RWKV time-mixing block as an RNN cell. Lower middle: the token-shift module as used in the feed-forward module and in Eagle time mixing. Lower right: the token-shift module in Finch time mixing. All shape annotations assume a single head for brevity. Dashed arrows (left, upper right) denote connections present in Finch but not in Eagle."
]
},
{
"cell_type": "markdown",
"id": "b3d66b1f-5c04-44d7-9bbc-76f844707327",
"metadata": {},
"source": [
"First, RWKV-6 improves on RWKV-5's Token Shift. See the lower-middle and lower-right panels of Figure 1, which show the RWKV-4/5 Token Shift and the RWKV-6 Token Shift, respectively."
]
},
{
"cell_type": "markdown",
"id": "20cd8fa2-c1a4-4fdc-ba5d-858b25df1bcf",
"metadata": {},
"source": [
"The details are as follows:\n",
"\n",
"### Formulas\n",
"\n",
"The data-dependent linear interpolation (ddlerp) used in Finch's Token Shift is defined as:\n",
"\n",
"\\begin{align*}\n",
"\\text{lora}_{\\Box}(x) &= \\lambda_{\\Box} + \\tanh(x A_{\\Box}) B_{\\Box} \\tag{14} \\\\\n",
"\\text{ddlerp}_{\\Box}(a, b) &= a + (b - a) \\odot \\text{lora}_{\\Box}(a + (b - a) \\odot \\mu_{x}) \\tag{15}\n",
"\\end{align*}\n",
"\n",
"### Explanation\n",
"\n",
"- **Learnable vectors and matrices**:\n",
" - $\\mu_{x}$ and each $\\lambda_{\\Box}$ introduce trainable vectors of dimension $D$.\n",
" - $A_{\\Box} \\in \\mathbb{R}^{D \\times 32}$ and $B_{\\Box} \\in \\mathbb{R}^{32 \\times D}$ introduce new trainable weight matrices.\n",
" - For the special case of $\\text{lora}_{\\omega}$ mentioned in the formulas, double-sized trainable weight matrices are used: $A_{\\omega} \\in \\mathbb{R}^{D \\times 64}$ and $B_{\\omega} \\in \\mathbb{R}^{64 \\times D}$.\n",
"\n",
"- **Future model scaling**:\n",
" - The lower-right panel of Figure 1 shows a schematic.\n",
" - Future Finch models at 7B parameters and above are expected to grow these weight matrices further, possibly doubling their size or more.\n",
"\n",
"### Function and purpose\n",
"\n",
"This new, data-dependent form of Token Shift is designed to extend the model beyond the RWKV-4/Eagle style of Token Shift: the amount of new versus old data allocated to each channel now depends on the inputs at both the current and the previous time step.\n",
"\n",
"### Detailed explanation\n",
"\n",
"- **Data-dependent linear interpolation (ddlerp)**:\n",
" - ddlerp is implemented by Equations 14 and 15; it combines information from the current and the previous time step to compute the interpolation.\n",
" - $\\text{lora}_{\\Box}(x)$ is formed from the vector $\\lambda_{\\Box}$ plus the product of $\\tanh(x A_{\\Box})$ with $B_{\\Box}$.\n",
"\n",
"- **Extended model capability**:\n",
" - With this data-dependent Token Shift, Finch passes information between time steps more flexibly, making the model more precise and efficient on complex sequence data.\n",
"\n",
"In summary, Finch introduces data-dependent linear interpolation into Token Shift, using trainable vectors and matrices to increase the model's flexibility and capacity: it mixes information across time steps more effectively, which improves overall performance."
]
},
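To make Equations 14 and 15 concrete, here is a minimal PyTorch sketch of $\text{lora}_{\Box}$ and $\text{ddlerp}_{\Box}$. The sizes (D=8, rank 4) and the randomly initialized parameters are made up for illustration; in the real model these are trained weights and the low-rank width is 32 (64 for the decay).

```python
import torch

# Made-up sizes for illustration; the paper uses a low-rank width of 32.
D, R = 8, 4

torch.manual_seed(0)
lam = torch.zeros(D)          # lambda_box: trainable bias vector
A = torch.randn(D, R) * 0.1   # A_box in R^{D x R}
B = torch.randn(R, D) * 0.1   # B_box in R^{R x D}
mu_x = torch.rand(D)          # mu_x: trainable mixing vector

def lora(x):
    # Eq. 14: lora(x) = lambda + tanh(x A) B
    return lam + torch.tanh(x @ A) @ B

def ddlerp(a, b):
    # Eq. 15: interpolate between a (current token) and b (previous token)
    # with per-channel weights that themselves depend on the data
    mix = lora(a + (b - a) * mu_x)
    return a + (b - a) * mix

x_t, x_prev = torch.randn(D), torch.randn(D)
print(ddlerp(x_t, x_prev).shape)  # torch.Size([8])
```

Note that when the two inputs coincide, ddlerp(a, a) = a exactly, since every (b − a) term vanishes.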
{
"cell_type": "code",
"execution_count": 6,
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except StopIteration:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except UnicodeDecodeError:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
"metadata": {},
"outputs": [],
"source": [
"# The sampling procedure is unchanged from the RWKV-5 demo\n",
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = np.power(probs, 1.0 / temperature) # probs is a numpy array, so np.power rather than Tensor.pow\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "raw",
"id": "812fac97-a6b8-423c-831d-fe7397883437",
"metadata": {},
"source": [
"Model download link: https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'\n",
"args.n_layer = 24\n",
"args.n_embd = 2048\n",
"args.vocab_size = 65536\n",
"\n",
"context = \"\\nDatawhale is \"\n",
"# context = \"\\n我们发现\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "markdown",
"id": "ee395717-e599-40fd-a4b9-ddfc800babea",
"metadata": {},
"source": [
"Compared with the Channel Mixing of RWKV-5 (see the code below), the Channel Mixing in RWKV-6 is unchanged. The `time_maa_k` here is a learnable parameter with the same shape as `time_mix_k` in RWKV-5: a tensor of dimension D (the model's hidden size)."
]
},
{
"cell_type": "markdown",
"id": "54571fa3-52c6-4b05-ab97-9f06d76a522e",
"metadata": {},
"source": [
"Finch makes the following improvements to Time Mixing:\n",
"\n",
"### Formulas\n",
"\n",
"The Finch time-mixing formulas are:\n",
"\n",
"\\begin{align*}\n",
"\\Box_t &= \\text{ddlerp}_{\\Box}(x_t, x_{t-1}) W_{\\Box}, \\quad \\Box \\in \\{r, k, v, g\\} \\tag{16} \\\\\n",
"d_t &= \\text{lora}_d(\\text{ddlerp}_d(x_t, x_{t-1})) \\tag{17} \\\\\n",
"w_t &= \\exp(-\\exp(d_t)) \\tag{18} \\\\\n",
"\\text{wk} \\mathbf{v}_t &= \\text{diag}(u) \\cdot k_t^\\top \\cdot v_t + \\sum_{i=1}^{t-1} \\text{diag}\\left( \\prod_{j=i+1}^{t-1} w_j \\right) \\cdot k_i^\\top \\cdot v_i \\in \\mathbb{R}^{(D/h) \\times (D/h)} \\tag{19} \\\\\n",
"o_t &= \\text{concat} \\left( \\text{SiLU}(g_t) \\odot \\text{LayerNorm}(r_t \\cdot \\text{wk} \\mathbf{v}_t) \\right) W_o \\in \\mathbb{R}^D \\tag{20}\n",
"\\end{align*}\n",
"\n",
"### Explanation\n",
"\n",
"- **Learnable vectors and matrices**:\n",
" - $\\Box_t$ is computed with data-dependent interpolation (ddlerp) and covers the receptance, key, value, and gate vectors.\n",
" - $d_t$ is obtained by applying $\\text{lora}_d$ to $\\text{ddlerp}_d(x_t, x_{t-1})$.\n",
" - $w_t$ is computed from $d_t$ and controls the dynamic decay.\n",
"\n",
"- **Time-mixing computation**:\n",
" - $\\text{wk} \\mathbf{v}_t$ is a weighted sum of the current outer product $k_t^\\top \\cdot v_t$ and the outer products $k_i^\\top \\cdot v_i$ of all previous time steps, with weights given by the accumulated per-step decays $w_j$.\n",
" - The output $o_t$ is obtained by concatenating, across heads, $\\text{SiLU}(g_t) \\odot \\text{LayerNorm}(r_t \\cdot \\text{wk} \\mathbf{v}_t)$ and projecting with $W_o$.\n",
"\n",
"- **Recurrent form**:\n",
" \\begin{align*}\n",
" \\text{wk} \\mathbf{v}' &= s + \\text{diag}(u) \\cdot k^\\top \\cdot v \\tag{21} \\\\\n",
" s' &= \\text{diag}(w) \\cdot s + k^\\top \\cdot v \\tag{22}\n",
" \\end{align*}\n",
"\n",
"### Function and purpose\n",
"\n",
"Unlike in Eagle, $w_t$ in Finch is not fixed across the whole sequence. Each channel's $w_t$ can change dynamically over time depending on the input; this is the core change to decay in Finch.\n",
"\n",
"### Detailed explanation\n",
"\n",
"- **Dynamic decay**:\n",
" - The data-dependent decay introduced by Finch lets each channel's $w_t$ vary with the current and previous inputs, instead of being a fixed learned vector.\n",
" - This dynamic decay is realized by applying a new LoRA mechanism on top of the learned vectors, which increases the model's flexibility.\n",
"\n",
"- **Enhanced Token Shift**:\n",
" - The new time decay $w_t$ additionally goes through the LoRA mechanism, allowing each channel's $w_t$ to vary with the mix of the current and previous tokens.\n",
"\n",
"In summary, Finch's time mixing introduces a data-dependent dynamic decay and an enhanced Token Shift; this extra flexibility and precision lets the model fuse information across time steps more effectively and improves overall performance."
]
},
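The recurrent form (Equations 21 and 22) can be sketched per head as follows. The head size and the u, w, k, v values below are made-up stand-ins for the real learned parameters and projections:

```python
import torch

torch.manual_seed(0)
S = 4  # head size D/h; made up for illustration

u = torch.rand(S)      # per-channel "bonus" weight for the current token
s = torch.zeros(S, S)  # recurrent state carried across tokens

def step(s, w, k, v):
    kv = k.unsqueeze(1) @ v.unsqueeze(0)  # outer product k^T v, shape (S, S)
    wkv = s + torch.diag(u) @ kv          # Eq. 21: state plus u-weighted current kv
    s_new = torch.diag(w) @ s + kv        # Eq. 22: decay state per channel, add kv
    return wkv, s_new

for _ in range(3):
    w = torch.sigmoid(torch.randn(S))     # stand-in for exp(-exp(d_t)); lies in (0, 1)
    k, v = torch.randn(S), torch.randn(S)
    wkv, s = step(s, w, k, v)

print(wkv.shape)  # torch.Size([4, 4])
```

Because w enters as diag(w), each channel of the state decays at its own data-dependent rate; with w = 0 the state forgets everything except the newest kᵀv.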
{
"cell_type": "code",
"execution_count": 9,
"id": "3869854a-a4e3-4652-9698-b0d81bbbd645",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
"\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_faaaa'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" sx = state[i0] - x\n",
" xk = x + sx * time_maa_k\n",
" xr = x + sx * time_maa_r\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" sx = state[i1] - x\n",
" state[i1] = x\n",
" xxx = x + sx * x_maa\n",
" xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)\n",
" xxx = torch.bmm(xxx, tm_w2).view(5, -1)\n",
" mw, mk, mv, mr, mg = xxx.unbind(dim=0)\n",
"\n",
" xw = x + sx * (w_maa + mw)\n",
" xk = x + sx * (k_maa + mk)\n",
" xv = x + sx * (v_maa + mv)\n",
" xr = x + sx * (r_maa + mr)\n",
" xg = x + sx * (g_maa + mg)\n",
"\n",
" w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)\n",
" w = torch.exp(-torch.exp(w.float()))\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + w * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state is None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
" att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,\n",
" att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,\n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_maa_k, ffn.time_maa_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5235be83-a574-41f6-8546-bc415e2aeacf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "035ce374-c1c0-43b8-9ba9-37df297baae6",
"metadata": {},
"outputs": [],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1e17d6ef-c02c-4d27-8cf6-9a262f75f77f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"Datawhale is ➡️‼️\n",
"https://twitter.com/datawhale_cn/status/1463997087819689985\n",
"#Data #AI #DataAnalytics #AIOps #DataOps #MachineLearning #DataScience #DataLakeAnalytics #Hadoop #Amazon #Google #AWS #Azure #Dataprep #DevOps #OSS #Linux #Unix #BigData #BigDataOps #DataArchitecture #DataScienceOps #MachineLearningOps\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"Datawhale is 🤓\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"Datawhale is 🤯. They have a solid team, a really good SaaS product and the tools to support their users. That said, I have to take a serious look at the privacy and security of their platform before I buy into their story. I think this is a case of big companies buying into the hype, and they're not taking into account all the realities that go into building a privacy-focused product.\n",
"P.S. You can still apply to Datawhale's Program.\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except UnicodeDecodeError:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "172c33e0-6d5b-4143-b85d-86777a2f5739",
"metadata": {},
"source": [
"Compared with v5, v6 seems fonder of emoji, haha."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}