llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-v6/RWKV_v6_demo.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
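{
"cell_type": "markdown",
"id": "tokenizer-demo-md",
"metadata": {},
"source": [
"A quick sanity check of the tokenizer before wiring it into the model: encode a short mixed English/Chinese string, inspect the token ids, and confirm the byte-level round-trip. This is a minimal sketch; it assumes `rwkv_vocab_v20230424.txt` sits next to this notebook (the same file loaded in the model cell below)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tokenizer-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: encode -> inspect -> decode round-trip.\n",
"# Assumes ./rwkv_vocab_v20230424.txt is available (same file used below).\n",
"demo_tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"sample = \"Hello, RWKV! 你好\"\n",
"ids = demo_tokenizer.encode(sample)\n",
"demo_tokenizer.printTokens(ids)\n",
"assert demo_tokenizer.decode(ids) == sample # byte-level vocab round-trips exactly"
]
},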
{
"cell_type": "code",
"execution_count": 7,
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
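{
"cell_type": "markdown",
"id": "sampling-demo-md",
"metadata": {},
"source": [
"A toy illustration of the nucleus (top-p) cutoff above, using a hypothetical 5-token vocabulary: only the smallest prefix of the sorted distribution whose cumulative mass exceeds `top_p=0.8` survives (here the top two tokens); everything past the cutoff is zeroed before sampling, so those tokens receive no draws at all."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sampling-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Toy logits over a hypothetical 5-token vocabulary.\n",
"toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])\n",
"print('full distribution:', F.softmax(toy_logits, dim=-1).numpy())\n",
"# Draw 1000 samples; tokens past the top-p cutoff are zeroed out and never drawn.\n",
"counts = np.bincount([sample_logits(toy_logits, top_p=0.8) for _ in range(1000)], minlength=5)\n",
"print('samples per token:', counts)"
]
},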
{
"cell_type": "raw",
"id": "812fac97-a6b8-423c-831d-fe7397883437",
"metadata": {},
"source": [
"模型下载地址https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'\n",
"args.n_layer = 24\n",
"args.n_embd = 2048\n",
"args.vocab_size = 65536\n",
"\n",
"context = \"\\nDatawhale is \"\n",
"# context = \"\\n我们发现\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3869854a-a4e3-4652-9698-b0d81bbbd645",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
"\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_faaaa'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" sx = state[i0] - x\n",
" xk = x + sx * time_maa_k\n",
" xr = x + sx * time_maa_r\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" sx = state[i1] - x\n",
" state[i1] = x\n",
" xxx = x + sx * x_maa\n",
" xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)\n",
" xxx = torch.bmm(xxx, tm_w2).view(5, -1)\n",
" mw, mk, mv, mr, mg = xxx.unbind(dim=0)\n",
"\n",
" xw = x + sx * (w_maa + mw)\n",
" xk = x + sx * (k_maa + mk)\n",
" xv = x + sx * (v_maa + mv)\n",
" xr = x + sx * (r_maa + mr)\n",
" xg = x + sx * (g_maa + mg)\n",
"\n",
" w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)\n",
" w = torch.exp(-torch.exp(w.float()))\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + w * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
" att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,\n",
" att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,\n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_maa_k, ffn.time_maa_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
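{
"cell_type": "markdown",
"id": "state-layout-md",
"metadata": {},
"source": [
"The forward pass above threads a single 2-D `state` tensor through every layer. Each layer owns `2 + head_size` rows of width `n_embd`: row 0 is the previous input for channel mixing (token shift), row 1 the previous input for time mixing, and the remaining `head_size` rows hold the flattened per-head `head_size x head_size` WKV matrices. Below is a minimal sketch of the resulting sizes, assuming the usual 64-dim heads of the v6 world models (the class itself reads the true value from the checkpoint)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "state-layout-code",
"metadata": {},
"outputs": [],
"source": [
"head_size = 64 # assumption: v6 world models use 64-dim heads\n",
"n_head = args.n_embd // head_size\n",
"rows_per_layer = 2 + head_size # two token-shift rows + flattened WKV state\n",
"print('heads per layer:', n_head)\n",
"print('state shape:', (args.n_layer * rows_per_layer, args.n_embd))\n",
"print('WKV floats per layer:', n_head * head_size * head_size) # == head_size * n_embd"
]
},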
{
"cell_type": "code",
"execution_count": 15,
"id": "5235be83-a574-41f6-8546-bc415e2aeacf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "035ce374-c1c0-43b8-9ba9-37df297baae6",
"metadata": {},
"outputs": [],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1e17d6ef-c02c-4d27-8cf6-9a262f75f77f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"Datawhale is 🤑💰. #datawhale #technology #programming #python #data #java #SQL #object #pythonprogramming #django #open source #software #learning #future #machinelearning #dataanalysis #datavisualization #AI #ML #DL #CS #business #career #learning #learningcurve #technology #productivity #development #science #sciencecommunication #sciencecommunicationtools #science #Science #tech #engineering #engineeringtools #engineeringtechniques #engineering\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"Datawhale is 👏 also 👏 an 👏 online 👏 trading 👏 platform 👏 and 👏 as 👏 mentioned 👏 above 👏, 👏 the 👏 reason 👏 why 👏 the 👏 team 👏 doesn't 👏 provide 👏 the 👏 open 👏 source\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"Datawhale is 👌. I have the following questions about its operations and use:\n",
"1. Why does Datawhale charge for its services? What is its value proposition?\n",
"2. What is the main use case of Datawhale?\n",
"3. How does Datawhale use customer data?\n",
"4. What are Datawhales data collection practices?\n",
"5. What are Datawhales data storage practices?\n",
"6. What is Datawhales data retention policy?\n",
"\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
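{
"cell_type": "markdown",
"id": "greedy-demo-md",
"metadata": {},
"source": [
"For comparison, a deterministic greedy continuation: always pick the argmax token instead of sampling, so repeated runs produce identical text. This is a minimal sketch reusing the cached `init_out`/`init_state` from above; `errors='replace'` guards against a trailing partial UTF-8 sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "greedy-demo-code",
"metadata": {},
"outputs": [],
"source": [
"out, state = init_out.clone(), init_state.clone()\n",
"greedy_tokens = []\n",
"for _ in range(50):\n",
" token = int(torch.argmax(out)) # greedy: take the highest-logit token\n",
" greedy_tokens.append(token)\n",
" out, state = model.forward(token, state)\n",
"print(context + tokenizer.decodeBytes(greedy_tokens).decode('utf-8', errors='replace'))"
]
},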
{
"cell_type": "markdown",
"id": "172c33e0-6d5b-4143-b85d-86777a2f5739",
"metadata": {},
"source": [
"v6和v5相比感觉更喜欢使用emoj了哈哈"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}