# llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-compare/model_v4.py
#!/usr/bin/env python
# coding: utf-8
# Model download link: https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true
# In[1]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer
# In[2]:
tokenizer = Tokenizer.from_file("20B_tokenizer.json")
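# Note: 20B_tokenizer.json is the GPT-NeoX tokenizer file shipped with
# BlinkDL's RWKV-4 "Pile" demos; this script assumes it sits in the working
# directory next to this file.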
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024
context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85
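# Generation settings: NUM_TRIALS independent completions of LENGTH_PER_TRIAL
# tokens each, sampled with temperature / top-p (nucleus) filtering as
# implemented by sample_logits below.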
########################################################################################################
# In[3]:
class RWKV_RNN(torch.jit.ScriptModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float() # convert to f32 type
        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # squared ReLU, from the Primer paper
        return r * (vw @ k)
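    # A rough sketch of what channel_mixing computes, written as equations for
    # readability (names mirror the arguments above; state[5*i+0] holds the
    # previous token's x):
    #   xk  = mix_k * x_t + (1 - mix_k) * x_{t-1}
    #   xr  = mix_r * x_t + (1 - mix_r) * x_{t-1}
    #   out = sigmoid(Wr @ xr) * (Wv @ relu(Wk @ xk)**2)
    # i.e. a gated two-layer FFN whose input is a learned blend of the current
    # and previous hidden state.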
    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv
        aa = state[5*i+2] # running numerator, scaled by e^{-pp}
        bb = state[5*i+3] # running denominator, scaled by e^{-pp}
        pp = state[5*i+4] # largest exponent seen so far, for numerical stability
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)
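    # The block above is a numerically stable evaluation of the RWKV-v4 "WKV"
    # recurrence. In the notation of the RWKV paper (u = time_first, and
    # time_decay = -e^w as precomputed in __init__):
    #   wkv_t = (sum_{s<t} e^{-(t-1-s)w + k_s} * v_s + e^{u + k_t} * v_t)
    #           / (sum_{s<t} e^{-(t-1-s)w + k_s}      + e^{u + k_t})
    # aa/bb hold the running numerator/denominator scaled by e^{-pp}, where pp
    # tracks the largest exponent seen so far, so every torch.exp() argument
    # stays <= 0 and cannot overflow (the same trick as streaming log-sum-exp).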
    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
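# Layout of `state` (one n_embd-wide row per entry, 5 rows per layer):
#   state[5*i+0]  previous x for channel mixing
#   state[5*i+1]  previous x for time mixing
#   state[5*i+2]  WKV numerator accumulator (aa)
#   state[5*i+3]  WKV denominator accumulator (bb)
#   state[5*i+4]  running max exponent (pp), initialized to -1e30 ~ -inf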
##########################################################################################################
# In[4]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0 # top-p (nucleus) filtering: zero out the tail
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature) # probs is a numpy array, so np.power, not .pow()
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
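# A quick illustrative check of sample_logits (the logits below are made up,
# not from the model): with a sharply peaked distribution and top_p = 0.85,
# the nucleus cutoff keeps only the top token, so sampling returns index 0.
#   _demo_logits = torch.tensor([8.0, 1.0, 0.5, 0.1])
#   print(sample_logits(_demo_logits, temperature=1.0, top_p=0.85))  # -> 0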
########################################################################################################
# In[6]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
# In[7]:
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
    init_out, init_state = model.forward(token, init_state)
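# init_out now holds the logits for the token that follows `context`. Each
# trial below clones init_out/init_state so it can restart from the same
# prefix without re-encoding it.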
# In[8]:
for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        tmp = tokenizer.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp: # only print when we have a valid UTF-8 string
            print(tmp, end="", flush=True)
            out_last = i + 1
        out, state = model.forward(token, state)
    print('\n')
# In[ ]: