llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-compare/model_v5.py
2024-05-31 16:44:23 +08:00

352 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# coding: utf-8
# In[1]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
# In[2]:
import torch
# In[3]:
class RWKV_TOKENIZER():
    """Greedy longest-match byte tokenizer driven by a vocabulary file.

    Each vocabulary line has the form ``<idx> <python literal> <byte length>``,
    where the literal is a ``str`` or ``bytes`` token. The file must already be
    sorted so that, walking it in reverse, longer tokens sharing the same first
    two bytes are seen before shorter ones.
    """

    # table[b0][b1]: multi-byte tokens starting with bytes b0, b1, in match order.
    table: list[list[list[bytes]]]
    # good[b0]: set of second bytes b1 for which table[b0][b1] is non-empty.
    good: list[set[int]]
    # wlen[b0]: length of the longest token starting with byte b0.
    wlen: list[int]

    def __init__(self, file_name):
        """Load the vocabulary from ``file_name`` and build the lookup tables."""
        # Local import so this class stays self-contained.
        import ast

        self.idx2token = {}
        ordered_tokens = []  # must be already sorted
        # Context manager guarantees the file handle is closed.
        with open(file_name, "r", encoding="utf-8") as vocab_file:
            for line in vocab_file:
                first_space = line.index(' ')
                last_space = line.rindex(' ')
                idx = int(line[:first_space])
                # NOTE(security): the original used eval() here. The field is a
                # str/bytes literal, so literal_eval parses it without executing
                # arbitrary code from the vocabulary file.
                x = ast.literal_eval(line[first_space:last_space].strip())
                x = x.encode("utf-8") if isinstance(x, str) else x
                assert isinstance(x, bytes)
                assert len(x) == int(line[last_space:])
                ordered_tokens.append(x)
                self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # precompute some tables for fast matching
        self.table = [[[] for _ in range(256)] for _ in range(256)]
        self.good = [set() for _ in range(256)]
        self.wlen = [0 for _ in range(256)]

        # reverse order - match longer tokens first
        for s in reversed(ordered_tokens):
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Tokenize ``src`` greedily and return the list of token ids.

        Raises ``KeyError`` if a single byte that has to be emitted on its own
        is not present in the vocabulary.
        """
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            # Fallback: the single byte at position i.
            s: bytes = src[i : i + 1]

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    # Longest slice any token starting with s0 could cover.
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # First (i.e. preferred) token that prefixes the slice;
                    # keep the single-byte fallback when none matches.
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of ``tokens``."""
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        """Tokenize the UTF-8 encoding of ``src``."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode ``tokens`` to text; raises ``UnicodeDecodeError`` on invalid UTF-8."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token's repr immediately followed by its id, space separated."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                # Not valid UTF-8 on its own: show the raw bytes instead.
                pass
            print(f'{repr(s)}{i}', end=' ')
            # print(repr(s), i)
        print()
########################################################################################################
# In[4]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one index from logits ``out`` using top-p filtering.

    Args:
        out: 1-D torch tensor of logits (it is converted with ``.numpy()``).
        temperature: exponent ``1 / temperature`` is applied to the filtered
            probabilities when it differs from 1.0.
        top_p: cumulative-probability threshold; probabilities below the value
            at which the descending cumulative sum first exceeds ``top_p`` are
            zeroed.

    Returns:
        The sampled index as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    exceeds = cumulative_probs > top_p
    # If the threshold is never exceeded (e.g. top_p >= 1), argmax would return
    # 0 and silently keep only the most likely token; keep everything instead.
    cutoff = float(sorted_probs[np.argmax(exceeds)]) if exceeds.any() else 0.0
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here, which has no .pow() method.
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
########################################################################################################
# The model can be downloaded from these links:
# https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files
# https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth
# In[68]:
# Tokenizer built from the vocabulary file in the working directory.
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS
# Model settings; MODEL_NAME is given without the ".pth" suffix.
args = types.SimpleNamespace(
    MODEL_NAME='/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096',
    n_layer=24,
    n_embd=1024,
    vocab_size=65536,
)
# In[69]:
# N_LAYER="12"
# N_EMBD="768"
N_LAYER, N_EMBD = "24", "1024"
# In[70]:
# context = "\nElon Musk has"
# context = "\n我们发现"
# Prompt fed to the model before sampling starts.
context = "Q:Do you know datawhalechina?\nA:"
# Sampling settings.
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
# The second assignment overrides the first one.
LENGTH_PER_TRIAL = 4096
TEMPERATURE = 1.0
TOP_P = 0.7
# In[80]:
class RWKV_RNN(MyModule):
    """RWKV-5 model evaluated one token at a time with an explicit state.

    The state is a 2-D tensor with ``n_layer * (2 + head_size)`` rows and
    ``n_embd`` columns. For layer ``i`` the rows starting at
    ``(2 + head_size) * i`` hold, in order: the previous channel-mixing input,
    the previous time-mixing input, and ``head_size`` rows that are viewed as
    an ``(n_head, head_size, head_size)`` tensor by ``time_mixing``.
    """

    def __init__(self, args):
        """Load ``args.MODEL_NAME + '.pth'`` on CPU and expose it as ``self.w``.

        ``args`` must provide ``MODEL_NAME``, ``n_layer`` and ``n_embd``.
        """
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        # NOTE(security): torch.load unpickles the checkpoint; only load files
        # from a trusted source.
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if '.time_' in k: w[k] = w[k].squeeze()
            # The decay is stored as exp(-exp(w)) with a trailing unit axis.
            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        # Head geometry is derived from the checkpoint tensors themselves.
        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # Numeric path components index into a dict (self.w.blocks).
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Apply layer normalization over ``n_embd`` using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        """Channel-mixing step for layer ``i``.

        Interpolates ``x`` with the previous input stored in ``state``,
        overwrites that state row with ``x`` (in place), and returns
        ``sigmoid(rw @ xr) * (vw @ relu(kw @ xk) ** 2)``.
        """
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        """Time-mixing step for layer ``i``; updates ``state`` in place.

        Reads the previous input and the per-head matrix state of layer ``i``
        from ``state``, writes both back, and returns ``ow @ x`` where ``x`` is
        the group-normalized, gated per-head output.
        """
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        # Per-head (S x S) state stored in the S rows following the two mix rows.
        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v  # per-head outer product of key and value
        x = r @ (time_first * a + s)
        s = a + time_decay * s

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Run one token through all layers.

        Args:
            token: index into the embedding table ``self.w.emb.weight``.
            state: state tensor from the previous call, or ``None`` to start
                from an all-zero state.

        Returns:
            ``(logits, state)``; ``state`` is the same tensor object that was
            passed in (or the newly created one), updated in place.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                # print(i)
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
# In[81]:
# Prompt used for every trial below.
context = "Q:Do you know datawhalechina?\nA:"
# In[82]:
# Notebook cell echoes: these bare expressions have no effect in a script.
args.MODEL_NAME
# In[83]:
args.n_layer,args.n_embd
# In[84]:
# args.n_layer = 24
# args.n_embd = 1024
# In[85]:
# args.n_layer = 12
# args.n_embd = 768
# In[86]:
# args.MODEL_NAME='../models/rwkv-5-world-1b5'
# In[87]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
# In[88]:
# Passing None makes the model start from a fresh all-zero state.
init_state = None
# In[89]:
LENGTH_PER_TRIAL=1024
# In[90]:
# Feed the prompt one token at a time to obtain the starting logits and state.
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0  # index of the first generated token not printed yet
    # Clone so that every trial restarts from the post-prompt state
    # (model.forward updates the state tensor in place).
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_last = i + 1
        except (UnicodeError, KeyError):
            # Best effort: the pending tokens may end in the middle of a UTF-8
            # sequence (retry once more tokens arrive) or contain an id missing
            # from the vocabulary table. Unlike a bare except, this no longer
            # swallows KeyboardInterrupt.
            pass
        out, state = model.forward(token, state)
print('\n')
# Clearly the model was never trained on datawhale data, haha. Still, it is
# undeniably fast, and running it on the CPU uses very few resources.