#!/usr/bin/env python
# coding: utf-8

# Model download link:
# https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true
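
# If the checkpoint is not already on disk, a minimal sketch for fetching it
# with the standard library (the local target path here is an assumption;
# adjust it to match args.MODEL_NAME below):
import os, urllib.request
_CKPT_URL = 'https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true'
_CKPT_PATH = 'RWKV-4-Pile-430M-20220808-8066.pth'
if not os.path.exists(_CKPT_PATH):
    urllib.request.urlretrieve(_CKPT_URL, _CKPT_PATH)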

# In[1]:


########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types

import numpy as np
import torch
from torch.nn import functional as F
from tokenizers import Tokenizer

np.set_printoptions(precision=4, suppress=True, linewidth=200)

# In[2]:


tokenizer = Tokenizer.from_file("20B_tokenizer.json")

args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024

context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85
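
# Quick, illustrative sanity check of the tokenizer round trip (the sample
# string is arbitrary):
_demo_ids = tokenizer.encode("Hello RWKV").ids
print(_demo_ids, '->', tokenizer.decode(_demo_ids))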

########################################################################################################

# In[3]:


class RWKV_RNN(torch.jit.ScriptModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float())  # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float()  # convert to f32 type

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        for k in w.keys():  # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        # state[5*i+0] holds the previous token's input to this block's FFN
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # squared ReLU, from the Primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        # state[5*i+1] holds the previous token's input to this block's attention
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        # WKV recurrence: aa/bb are the running numerator/denominator and
        # pp is the running max exponent, kept separate for numerical stability
        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer): state[5*i+4] = -1e30  # -infinity

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
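
# The time_mixing update above keeps a running maximum exponent (pp / qq) so
# that only differences of exponents are ever passed to exp(). A tiny
# standalone illustration with made-up numbers:
_p, _w = torch.tensor(100.0), torch.tensor(105.0)
_q = torch.maximum(_p, _w)                     # shift by the larger exponent
print(torch.exp(_p - _q), torch.exp(_w - _q))  # both finite, while torch.exp(_w) alone overflows f32 to inf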

##########################################################################################################

# In[4]:


def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0  # drop tokens outside the top-p nucleus
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)  # numpy arrays have no .pow(); np.power fixes the original bug
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
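
# Illustrative top-p behaviour on a toy distribution (values are made up, not
# from the model): with top_p=0.8 the 0.05 tail token is zeroed before sampling.
_toy_logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
print(sample_logits(_toy_logits, temperature=1.0, top_p=0.8))  # never returns index 3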

########################################################################################################

# In[6]:


print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

# In[7]:


print('\nPreprocessing context (slow version; see v2/rwkv/model.py for a fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
    # feed the prompt one token at a time; the recurrent state accumulates the context
    init_out, init_state = model.forward(token, init_state)
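
# Illustrative peek at the prompt's next-token logits: the argmax is the
# model's greedy continuation (for inspection only; generation below samples
# with sample_logits instead).
print('greedy next token:', repr(tokenizer.decode([int(torch.argmax(init_out))])))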


# In[8]:


for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        tmp = tokenizer.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:  # only print once we have a valid utf-8 string
            print(tmp, end="", flush=True)
            out_last = i + 1
        out, state = model.forward(token, state)
    print('\n')