#!/usr/bin/env python
# coding: utf-8
"""CPU inference demo for an RWKV-5 "World" checkpoint (exported from a notebook).

The file defines a byte-level tokenizer (``RWKV_TOKENIZER``), a top-p sampler
(``sample_logits``) and a recurrent single-token model (``RWKV_RNN``).  When
executed as a script it loads the vocabulary and checkpoint configured below
and prints ``NUM_TRIALS`` sampled continuations of ``context``.
"""

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import ast
import types

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

np.set_printoptions(precision=4, suppress=True, linewidth=200)

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method


class RWKV_TOKENIZER():
    """Greedy byte-level tokenizer driven by a vocabulary file.

    Each vocabulary line has the form ``<idx> <python str/bytes literal> <byte length>``.
    Tokens of two or more bytes are bucketed by their first two bytes so that
    ``encodeBytes`` only has to test a handful of candidates per position.
    """

    table: list[list[list[bytes]]]  # table[b0][b1] -> candidate tokens starting with bytes b0, b1
    good: list[set[int]]            # good[b0] -> second bytes for which table[b0][b1] is non-empty
    wlen: list[int]                 # wlen[b0] -> length of the longest token starting with b0

    def __init__(self, file_name):
        """Load the vocabulary from ``file_name`` and build the lookup tables.

        Raises:
            AssertionError: if an entry is neither ``str`` nor ``bytes`` or its
                byte length does not match the length recorded on the line.
            ValueError / SyntaxError: if a line is malformed.
        """
        self.idx2token = {}
        sorted_tokens = []  # must be already sorted (file order is trusted)
        with open(file_name, "r", encoding="utf-8") as vocab_file:
            for line in vocab_file:
                first_space = line.index(' ')
                last_space = line.rindex(' ')
                idx = int(line[:first_space])
                # literal_eval only accepts literals, so a tampered vocab file
                # cannot execute code (the token column is a str/bytes repr).
                token = ast.literal_eval(line[first_space:last_space].strip())
                if isinstance(token, str):
                    token = token.encode("utf-8")
                assert isinstance(token, bytes), f"token {idx} is not str/bytes"
                assert len(token) == int(line[last_space:]), f"token {idx} has wrong length"
                sorted_tokens.append(token)
                self.idx2token[idx] = token

        self.token2idx = {token: int(idx) for idx, token in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for _ in range(256)] for _ in range(256)]
        self.good = [set() for _ in range(256)]
        self.wlen = [0 for _ in range(256)]

        # reverse order - match longer tokens first
        for token in reversed(sorted_tokens):
            if len(token) >= 2:
                b0 = token[0]
                b1 = token[1]
                self.table[b0][b1].append(token)
                self.wlen[b0] = max(self.wlen[b0], len(token))
                self.good[b0].add(b1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Tokenize ``src`` greedily and return the list of token ids.

        At every position the first candidate in ``table[b0][b1]`` that is a
        prefix of the remaining input wins; otherwise the single byte is used.

        Raises:
            KeyError: if a single-byte fallback is not present in the vocabulary.
        """
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i:i + 1]  # single-byte fallback
            if i < src_len - 1:
                s0: int = src[i]
                s1: int = src[i + 1]
                if s1 in self.good[s0]:
                    # No token starting with s0 is longer than wlen[s0].
                    window: bytes = src[i:i + self.wlen[s0]]
                    s = next(filter(window.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)
        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of ``tokens`` (KeyError on unknown ids)."""
        return b''.join(self.idx2token[i] for i in tokens)

    def encode(self, src: str):
        """Encode a text string (as UTF-8) into token ids."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to text; raises UnicodeDecodeError on invalid UTF-8."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token's repr followed by its id, all on one line."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass  # not valid UTF-8 on its own: show the raw bytes instead
            print(f'{repr(s)}{i}', end=' ')
        print()


########################################################################################################


def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one token id from the logits ``out`` using top-p filtering.

    Args:
        out: 1-D tensor of logits (must be convertible with ``.numpy()``).
        temperature: probabilities are raised to ``1 / temperature`` after
            filtering; ``1.0`` leaves them untouched.
        top_p: cumulative-probability threshold; probabilities smaller than the
            one at which the sorted cumulative sum first exceeds ``top_p`` are
            zeroed.

    Returns:
        The sampled index as a Python ``int`` (drawn with ``np.random.choice``).
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    exceeds = cumulative_probs > top_p
    # If the threshold is never exceeded (e.g. top_p >= 1) argmax would return 0
    # and wrongly keep only the top token; keep everything instead.
    cutoff = float(sorted_probs[np.argmax(exceeds)]) if exceeds.any() else 0.0
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here, which has no .pow() method.
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    return int(np.random.choice(a=len(probs), p=probs))


########################################################################################################
# Configuration
#
# The model can be downloaded from:
#   https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files
#   https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth
########################################################################################################

# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'  # no ".pth" suffix here
args.n_layer = 24
args.n_embd = 1024
args.vocab_size = 65536

# Alternative configuration previously used in this notebook: 12 layers / 768 embedding width.
N_LAYER = "24"
N_EMBD = "1024"

# Other prompts tried: "\nElon Musk has", or a Chinese prompt meaning "We found".
context = "Q:Do you know datawhalechina?\nA:"
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 1024  # earlier notebook cells used 100 and 4096; 1024 was the value in effect
TEMPERATURE = 1.0
TOP_P = 0.7


class RWKV_RNN(MyModule):
    """RWKV-5 evaluated one token at a time with an explicit recurrent state.

    The state is a ``(n_layer * (2 + head_size), n_embd)`` tensor.  For layer
    ``i`` with ``S = head_size`` the rows are:

    * ``(2+S)*i + 0``: previous input of channel mixing,
    * ``(2+S)*i + 1``: previous input of time mixing,
    * ``(2+S)*i + 2 .. (2+S)*(i+1) - 1``: ``S`` rows viewed as ``(n_head, S, S)``.
    """

    def __init__(self, args):
        """Load ``args.MODEL_NAME + '.pth'`` on CPU and expose it as ``self.w``.

        ``args`` must provide ``MODEL_NAME``, ``n_layer`` and ``n_embd``.
        """
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float()  # convert to f32 type
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                # stored parameter -> per-step decay factor exp(-exp(w))
                w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            if '.time_faaaa' in k:
                w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric components index the ``blocks`` dict
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Apply layer normalisation over ``n_embd`` using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    # Channel mixing for layer ``i``: mixes ``x`` with the previous input stored
    # in ``state`` (which is then overwritten with ``x``) and returns
    # sigmoid(rw @ xr) * (vw @ relu(kw @ xk)^2).
    @MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
        return r * (vw @ k)

    # Time mixing for layer ``i``: updates the layer's (H, S, S) recurrent
    # matrix in ``state`` in place and returns the projected output ``ow @ x``.
    @MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v  # (H, S, 1) @ (H, 1, S) -> (H, S, S)
        x = r @ (time_first * a + s)
        s = a + time_decay * s

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g  # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Run one token through the model.

        Args:
            token: index into the embedding matrix.
            state: recurrent state tensor, or ``None`` to start from zeros.
                A given tensor is updated in place.

        Returns:
            ``(logits, state)`` where ``logits`` is a float tensor produced by
            the output head and ``state`` is the updated recurrent state.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state


# Everything that touches the file system or runs the model is guarded so that
# importing this module only defines the classes and constants above.
if __name__ == "__main__":
    tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")

    print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
    model = RWKV_RNN(args)

    print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
    init_state = None
    for token in tokenizer.encode(context):
        init_out, init_state = model.forward(token, init_state)

    for TRIAL in range(NUM_TRIALS):
        print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
        all_tokens = []
        out_last = 0
        # forward() mutates the state in place, so every trial works on a copy.
        out, state = init_out.clone(), init_state.clone()
        for i in range(LENGTH_PER_TRIAL):
            token = sample_logits(out, TEMPERATURE, TOP_P)
            all_tokens += [token]
            try:
                tmp = tokenizer.decode(all_tokens[out_last:])
                if '\ufffd' not in tmp:  # only print when we have a valid utf-8 string
                    print(tmp, end="", flush=True)
                    out_last = i + 1
            except (UnicodeDecodeError, KeyError):
                # Incomplete UTF-8 sequence, or an id missing from the
                # vocabulary: keep the tokens buffered and try again later.
                pass
            out, state = model.forward(token, state)
    print('\n')

    # Observation from the original run: the model evidently was not trained on
    # datawhale data. It is nevertheless quite fast, and running it on CPU uses
    # very few resources.