mirror of https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-05-01 11:58:17 +08:00
397 lines · 15 KiB · Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"########################################################################################################\n",
|
||
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
|
||
"########################################################################################################\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
|
||
"import types, torch\n",
|
||
"import torch.nn as nn\n",
|
||
"from torch.nn import functional as F\n",
|
||
"\n",
|
||
"MyModule = torch.jit.ScriptModule\n",
|
||
"MyFunction = torch.jit.script_method"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class RWKV_TOKENIZER():\n",
|
||
" table: list[list[list[bytes]]]\n",
|
||
" good: list[set[int]]\n",
|
||
" wlen: list[int]\n",
|
||
" def __init__(self, file_name):\n",
|
||
" self.idx2token = {}\n",
|
||
" sorted = [] # must be already sorted\n",
|
||
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
|
||
" for l in lines:\n",
|
||
" idx = int(l[:l.index(' ')])\n",
|
||
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
|
||
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
|
||
" assert isinstance(x, bytes)\n",
|
||
" assert len(x) == int(l[l.rindex(' '):])\n",
|
||
" sorted += [x]\n",
|
||
" self.idx2token[idx] = x\n",
|
||
"\n",
|
||
" self.token2idx = {}\n",
|
||
" for k, v in self.idx2token.items():\n",
|
||
" self.token2idx[v] = int(k)\n",
|
||
"\n",
|
||
" # precompute some tables for fast matching\n",
|
||
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
|
||
" self.good = [set() for i in range(256)]\n",
|
||
" self.wlen = [0 for i in range(256)]\n",
|
||
"\n",
|
||
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
|
||
" s = sorted[i]\n",
|
||
" if len(s) >= 2:\n",
|
||
" s0 = int(s[0])\n",
|
||
" s1 = int(s[1])\n",
|
||
" self.table[s0][s1] += [s]\n",
|
||
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
|
||
" self.good[s0].add(s1)\n",
|
||
"\n",
|
||
" def encodeBytes(self, src: bytes) -> list[int]:\n",
|
||
" src_len: int = len(src)\n",
|
||
" tokens: list[int] = []\n",
|
||
" i: int = 0\n",
|
||
" while i < src_len:\n",
|
||
" s: bytes = src[i : i + 1]\n",
|
||
"\n",
|
||
" if i < src_len - 1:\n",
|
||
" s1: int = int(src[i + 1])\n",
|
||
" s0: int = int(src[i])\n",
|
||
" if s1 in self.good[s0]:\n",
|
||
" sss: bytes = src[i : i + self.wlen[s0]]\n",
|
||
" try:\n",
|
||
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
|
||
" except:\n",
|
||
" pass\n",
|
||
" tokens.append(self.token2idx[s])\n",
|
||
" i += len(s)\n",
|
||
"\n",
|
||
" return tokens\n",
|
||
"\n",
|
||
" def decodeBytes(self, tokens):\n",
|
||
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
|
||
"\n",
|
||
" def encode(self, src: str):\n",
|
||
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
|
||
"\n",
|
||
" def decode(self, tokens):\n",
|
||
" return self.decodeBytes(tokens).decode('utf-8')\n",
|
||
"\n",
|
||
" def printTokens(self, tokens):\n",
|
||
" for i in tokens:\n",
|
||
" s = self.idx2token[i]\n",
|
||
" try:\n",
|
||
" s = s.decode('utf-8')\n",
|
||
" except:\n",
|
||
" pass\n",
|
||
" print(f'{repr(s)}{i}', end=' ')\n",
|
||
" # print(repr(s), i)\n",
|
||
" print()\n",
|
||
"\n",
|
||
"########################################################################################################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
|
||
" probs = F.softmax(out, dim=-1).numpy()\n",
|
||
" sorted_probs = np.sort(probs)[::-1]\n",
|
||
" cumulative_probs = np.cumsum(sorted_probs)\n",
|
||
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
|
||
" probs[probs < cutoff] = 0\n",
|
||
" if temperature != 1.0:\n",
|
||
" probs = probs.pow(1.0 / temperature)\n",
|
||
" probs = probs / np.sum(probs)\n",
|
||
" out = np.random.choice(a=len(probs), p=probs)\n",
|
||
" return out\n",
|
||
"\n",
|
||
"########################################################################################################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "raw",
|
||
"id": "812fac97-a6b8-423c-831d-fe7397883437",
|
||
"metadata": {},
|
||
"source": [
|
||
"模型下载地址:https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
|
||
"\n",
|
||
"args = types.SimpleNamespace()\n",
|
||
"args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'\n",
|
||
"args.n_layer = 24\n",
|
||
"args.n_embd = 2048\n",
|
||
"args.vocab_size = 65536\n",
|
||
"\n",
|
||
"context = \"\\nDatawhale is \"\n",
|
||
"# context = \"\\n我们发现\"\n",
|
||
"NUM_TRIALS = 3\n",
|
||
"LENGTH_PER_TRIAL = 100\n",
|
||
"TEMPERATURE = 1.0\n",
|
||
"TOP_P = 0.7"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "3869854a-a4e3-4652-9698-b0d81bbbd645",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class RWKV_RNN(MyModule):\n",
|
||
" def __init__(self, args):\n",
|
||
" super().__init__()\n",
|
||
" self.args = args\n",
|
||
" self.eval() # set torch to inference mode\n",
|
||
" \n",
|
||
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
|
||
"\n",
|
||
" for k in w.keys():\n",
|
||
" w[k] = w[k].float() # convert to f32 type\n",
|
||
" if '.time_' in k: w[k] = w[k].squeeze()\n",
|
||
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
|
||
"\n",
|
||
" self.n_head = w['blocks.0.att.time_faaaa'].shape[0]\n",
|
||
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
|
||
" \n",
|
||
" self.w = types.SimpleNamespace() # set self.w from w\n",
|
||
" self.w.blocks = {}\n",
|
||
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
|
||
" parts = k.split('.')\n",
|
||
" last = parts.pop()\n",
|
||
" here = self.w\n",
|
||
" for p in parts:\n",
|
||
" if p.isdigit():\n",
|
||
" p = int(p)\n",
|
||
" if p not in here: here[p] = types.SimpleNamespace()\n",
|
||
" here = here[p]\n",
|
||
" else:\n",
|
||
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
|
||
" here = getattr(here, p)\n",
|
||
" setattr(here, last, w[k])\n",
|
||
"\n",
|
||
" def layer_norm(self, x, w):\n",
|
||
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
|
||
"\n",
|
||
" @MyFunction\n",
|
||
" def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):\n",
|
||
" i0 = (2+self.head_size)*i+0\n",
|
||
" sx = state[i0] - x\n",
|
||
" xk = x + sx * time_maa_k\n",
|
||
" xr = x + sx * time_maa_r\n",
|
||
" state[i0] = x\n",
|
||
" r = torch.sigmoid(rw @ xr)\n",
|
||
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
|
||
" return r * (vw @ k)\n",
|
||
"\n",
|
||
" @MyFunction\n",
|
||
" def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
|
||
" H = self.n_head\n",
|
||
" S = self.head_size\n",
|
||
"\n",
|
||
" i1 = (2+S)*i+1\n",
|
||
" sx = state[i1] - x\n",
|
||
" state[i1] = x\n",
|
||
" xxx = x + sx * x_maa\n",
|
||
" xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)\n",
|
||
" xxx = torch.bmm(xxx, tm_w2).view(5, -1)\n",
|
||
" mw, mk, mv, mr, mg = xxx.unbind(dim=0)\n",
|
||
"\n",
|
||
" xw = x + sx * (w_maa + mw)\n",
|
||
" xk = x + sx * (k_maa + mk)\n",
|
||
" xv = x + sx * (v_maa + mv)\n",
|
||
" xr = x + sx * (r_maa + mr)\n",
|
||
" xg = x + sx * (g_maa + mg)\n",
|
||
"\n",
|
||
" w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)\n",
|
||
" w = torch.exp(-torch.exp(w.float()))\n",
|
||
"\n",
|
||
" r = (rw @ xr).view(H, 1, S)\n",
|
||
" k = (kw @ xk).view(H, S, 1)\n",
|
||
" v = (vw @ xv).view(H, 1, S)\n",
|
||
" g = F.silu(gw @ xg)\n",
|
||
"\n",
|
||
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
|
||
"\n",
|
||
" x = torch.zeros(H, S)\n",
|
||
" a = k @ v\n",
|
||
" x = r @ (time_first * a + s)\n",
|
||
" s = a + w * s\n",
|
||
" \n",
|
||
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
|
||
" x = x.flatten()\n",
|
||
"\n",
|
||
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
|
||
" return ow @ x\n",
|
||
"\n",
|
||
" def forward(self, token, state):\n",
|
||
" with torch.no_grad():\n",
|
||
" if state == None:\n",
|
||
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
|
||
" \n",
|
||
" x = self.w.emb.weight[token]\n",
|
||
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
|
||
" for i in range(self.args.n_layer):\n",
|
||
" att = self.w.blocks[i].att\n",
|
||
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
|
||
" att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,\n",
|
||
" att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,\n",
|
||
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
|
||
" att.ln_x.weight, att.ln_x.bias)\n",
|
||
" ffn = self.w.blocks[i].ffn\n",
|
||
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
|
||
" ffn.time_maa_k, ffn.time_maa_r, \n",
|
||
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
|
||
" \n",
|
||
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
|
||
" return x.float(), state"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "5235be83-a574-41f6-8546-bc415e2aeacf",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...\n",
|
||
"\n",
|
||
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
|
||
"model = RWKV_RNN(args)\n",
|
||
"\n",
|
||
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
|
||
"init_state = None"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "035ce374-c1c0-43b8-9ba9-37df297baae6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for token in tokenizer.encode(context):\n",
|
||
" init_out, init_state = model.forward(token, init_state)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "1e17d6ef-c02c-4d27-8cf6-9a262f75f77f",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"--[ Trial 0 ]----------------- \n",
|
||
"Datawhale is 🤑💰. #datawhale #technology #programming #python #data #java #SQL #object #pythonprogramming #django #open source #software #learning #future #machinelearning #dataanalysis #datavisualization #AI #ML #DL #CS #business #career #learning #learningcurve #technology #productivity #development #science #sciencecommunication #sciencecommunicationtools #science #Science #tech #engineering #engineeringtools #engineeringtechniques #engineering\n",
|
||
"\n",
|
||
"--[ Trial 1 ]----------------- \n",
|
||
"Datawhale is 👏 also 👏 an 👏 online 👏 trading 👏 platform 👏 and 👏 as 👏 mentioned 👏 above 👏, 👏 the 👏 reason 👏 why 👏 the 👏 team 👏 doesn't 👏 provide 👏 the 👏 open 👏 source\n",
|
||
"\n",
|
||
"--[ Trial 2 ]----------------- \n",
|
||
"Datawhale is 👌. I have the following questions about its operations and use:\n",
|
||
"1. Why does Datawhale charge for its services? What is its value proposition?\n",
|
||
"2. What is the main use case of Datawhale?\n",
|
||
"3. How does Datawhale use customer data?\n",
|
||
"4. What are Datawhale’s data collection practices?\n",
|
||
"5. What are Datawhale’s data storage practices?\n",
|
||
"6. What is Datawhale’s data retention policy?\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"for TRIAL in range(NUM_TRIALS):\n",
|
||
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
|
||
" all_tokens = []\n",
|
||
" out_last = 0\n",
|
||
" out, state = init_out.clone(), init_state.clone()\n",
|
||
" for i in range(LENGTH_PER_TRIAL):\n",
|
||
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
|
||
" all_tokens += [token]\n",
|
||
" try:\n",
|
||
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
|
||
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
|
||
" print(tmp, end=\"\", flush=True)\n",
|
||
" out_last = i + 1\n",
|
||
" except:\n",
|
||
" pass\n",
|
||
" out, state = model.forward(token, state) \n",
|
||
"print('\\n')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "172c33e0-6d5b-4143-b85d-86777a2f5739",
|
||
"metadata": {},
|
||
"source": [
|
||
"v6和v5相比,感觉更喜欢使用emoj了哈哈"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "kewei-ai",
|
||
"language": "python",
|
||
"name": "kewei-ai"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.5"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|