This commit is contained in:
kewei 2024-05-31 16:44:23 +08:00
parent b377e1a60d
commit 2bd03665dd
23 changed files with 438582 additions and 0 deletions

View File

@ -0,0 +1,517 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0 # positive: gain for orthogonal, negative: std for normal
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = config.rwkv_emb_scale
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = config.rwkv_emb_scale
if hasattr(m, 'scale_init'):
scale = m.scale_init
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if gain == 0:
nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
# if config.rwkv_tiny_attn > 0:
# self.tiny_att = RWKV_TinyAttn(config)
self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
# if hasattr(self, 'tiny_att'):
# tiny_att = self.tiny_att(x, self.mask)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
kv = (k * v).view(B, T, self.n_head, self.head_size)
wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
# if hasattr(self, 'tiny_att'):
# rwkv += tiny_att
return rwkv * self.time_gamma[:T, :]
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu
rwkv = torch.sigmoid(r) * wkv
return rwkv
class RWKV_TinyAttn(nn.Module): # extra tiny attention
def __init__(self, config):
super().__init__()
self.d_attn = config.rwkv_tiny_attn
self.n_head = config.rwkv_tiny_head
self.head_size = self.d_attn // self.n_head
self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
self.out = nn.Linear(self.d_attn, config.n_embd)
def forward(self, x, mask):
B, T, C = x.size()
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim = -1)
if self.n_head > 1:
q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
qk = qk.masked_fill(mask == 0, float('-inf'))
qk = F.softmax(qk, dim = -1)
qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
if self.n_head > 1:
qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
return self.out(qkv)
########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_len=None):
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
return self.cos_cached, self.sin_cached
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), -1)
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
class MHA_rotary(nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x)
return x
class GeGLU(torch.nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 3 * config.n_ffn
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
y = self.weight(F.gelu(k) * v)
return y
########################################################################################################
# MHA_pro: with more tricks
########################################################################################################
class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
att = att * w # time-weighting
att = self.head_mix(att) # talking heads
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x) * self.time_gamma[:T, :]
return x
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class RMSNorm(nn.Module):
def __init__(self, d):
super().__init__()
self.dd = d ** (-1. / 2)
self.weight = nn.Parameter(torch.ones(d))
def forward(self, x):
norm_x = x.norm(2, dim=-1, keepdim=True)
x_normed = x / (norm_x * self.dd + 1e-12)
return self.weight * x_normed
class FixedNorm(nn.Module):
def __init__(self, d):
super().__init__()
self.dd = d ** (-1. / 2)
def forward(self, x):
norm_x = x.norm(2, dim=-1, keepdim=True)
x_normed = x / (norm_x * self.dd + 1e-12)
return x_normed
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k,v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if config.model_type == 'RWKV':
# self.ln1 = FixedNorm(config.n_embd)
# self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
elif config.model_type == 'MHA_rotary':
self.attn = MHA_rotary(config, layer_id)
self.mlp = GeGLU(config, layer_id)
elif config.model_type == 'MHA_shift':
self.attn = MHA_rotary(config, layer_id, time_shift=True)
self.mlp = GeGLU(config, layer_id, time_shift=True)
elif config.model_type == 'MHA_pro':
self.attn = MHA_pro(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, 256)
self.head_q.scale_init = 0.01
self.head_k = nn.Linear(config.n_embd, 256)
self.head_k.scale_init = 0.01
self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
if self.config.model_type == 'RWKV':
RWKV_Init(self, config)
else:
self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (nn.Linear, )
blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn)
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.tok_emb(idx)
x = self.blocks(x)
x = self.ln_f(x)
q = self.head_q(x)[:,:T,:]
k = self.head_k(x)[:,:T,:]
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

View File

@ -0,0 +1,258 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
# RUN_DEVICE = 'cpu' # cpu cuda
# ctx_len = 768
# n_layer = 12
# n_embd = 768
RUN_DEVICE = 'cpu'
ctx_len = 768
n_layer = 24
n_embd = 1024
MODEL_NAME = '/data1/ckw/20220615-10803'
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'
print(f'\n* running on {RUN_DEVICE}')
################################################################################################################
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1,1,n_embd))
self.key = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd, n_embd, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.output = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x).transpose(-1, -2)
v = self.value(x).transpose(-1, -2)
r = self.receptance(x)
k = torch.clamp(k, max=60)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + 1e-9
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
x = self.ln1(x)
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME=MODEL_NAME):
super().__init__()
print('\nloading RWKV-GPT', MODEL_NAME)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
x = self.head(x)
return x
################################################################################################################
time_buf = {}
class RWKV_RNN():
def __init__(self, MODEL_NAME=MODEL_NAME):
print('\nloading RWKV-RNN', MODEL_NAME)
self.ctx_len = ctx_len
self.n_layer = n_layer
self.n_embd = n_embd
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
def LN(self, xx, w):
return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.square(torch.relu(w.key.weight @ x))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.exp(torch.clamp(w.key.weight @ x, max=60))
v = w.value.weight @ x
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + 1e-9)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(n_layer):
x = self.LN(x, w.blocks[i].ln1)
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
x = self.LN(x, w.blocks[i].ln2)
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
x = x.tolist()
return x
################################################################################################################

View File

@ -0,0 +1,273 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
RUN_DEVICE = 'cpu' # cpu cuda
ctx_len = 768
n_layer = 12
n_embd = 768
# n_layer = 24
# n_embd = 1024
# ---> download RWKV-3 169M model from https://huggingface.co/BlinkDL/rwkv-3-pile-169m/tree/main
# MODEL_NAME = '/data1/ckw/RWKV-3-Pile-430M-20220817-10602'
MODEL_NAME = '/data1/ckw/RWKV-3-Pile-20220720-10704'
K_EPS = 1e-8
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'
print(f'\n* running on {RUN_DEVICE}')
################################################################################################################
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,n_embd))
self.key = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd, n_embd, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.output = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(n_embd)
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME=MODEL_NAME):
super().__init__()
print('\nloading RWKV-GPT', MODEL_NAME)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
x = self.head(x)
return x
################################################################################################################
time_buf = {}
class RWKV_RNN():
def __init__(self, MODEL_NAME=MODEL_NAME):
print('\nloading RWKV-RNN', MODEL_NAME)
self.ctx_len = ctx_len
self.n_layer = n_layer
self.n_embd = n_embd
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
def LN(self, xx, w):
return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=60))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
x = self.LN(x, w.blocks[0].ln0)
for i in range(n_layer):
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
x = x.tolist()
return x
################################################################################################################

View File

@ -0,0 +1,189 @@
#!/usr/bin/env python
# coding: utf-8
模型下载链接https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true
# In[1]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer
# In[2]:
tokenizer = Tokenizer.from_file("20B_tokenizer.json")
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024
context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85
########################################################################################################
# In[3]:
class RWKV_RNN(torch.jit.ScriptModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
else: w[k] = w[k].float() # convert to f32 type
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
@torch.jit.script_method
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
return r * (vw @ k)
@torch.jit.script_method
def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + k
qq = torch.maximum(pp, ww)
e1 = torch.exp(pp - qq)
e2 = torch.exp(ww - qq)
a = e1 * aa + e2 * v
b = e1 * bb + e2
wkv = a / b
ww = pp + time_decay
qq = torch.maximum(ww, k)
e1 = torch.exp(ww - qq)
e2 = torch.exp(k - qq)
state[5*i+2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2
state[5*i+4] = qq
return ow @ (r * wkv)
def forward(self, token, state):
with torch.no_grad():
if state == None:
state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
x = self.w.emb.weight[token]
x = self.layer_norm(x, self.w.blocks[0].ln0)
for i in range(self.args.n_layer):
att = self.w.blocks[i].att
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
ffn = self.w.blocks[i].ffn
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
ffn.time_mix_k, ffn.time_mix_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
return x.float(), state
##########################################################################################################
# In[4]:
def sample_logits(out, temperature=1.0, top_p=0.8):
probs = F.softmax(out, dim=-1).numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
########################################################################################################
# In[6]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
# In[7]:
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
init_out, init_state = model.forward(token, init_state)
# In[8]:
for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
out_last = 0
out, state = init_out.clone(), init_state.clone()
for i in range(LENGTH_PER_TRIAL):
token = sample_logits(out, TEMPERATURE, TOP_P)
all_tokens += [token]
tmp = tokenizer.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
print(tmp, end="", flush=True)
out_last = i + 1
out, state = model.forward(token, state)
print('\n')
# In[ ]:

View File

@ -0,0 +1,351 @@
#!/usr/bin/env python
# coding: utf-8
# In[1]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
# In[2]:
import torch
# In[3]:
class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
wlen: list[int]
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)
# precompute some tables for fast matching
self.table = [[[] for j in range(256)] for i in range(256)]
self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)]
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i]
if len(s) >= 2:
s0 = int(s[0])
s1 = int(s[1])
self.table[s0][s1] += [s]
self.wlen[s0] = max(self.wlen[s0], len(s))
self.good[s0].add(s1)
def encodeBytes(self, src: bytes) -> list[int]:
src_len: int = len(src)
tokens: list[int] = []
i: int = 0
while i < src_len:
s: bytes = src[i : i + 1]
if i < src_len - 1:
s1: int = int(src[i + 1])
s0: int = int(src[i])
if s1 in self.good[s0]:
sss: bytes = src[i : i + self.wlen[s0]]
try:
s = next(filter(sss.startswith, self.table[s0][s1]))
except:
pass
tokens.append(self.token2idx[s])
i += len(s)
return tokens
def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode('utf-8')
except:
pass
print(f'{repr(s)}{i}', end=' ')
# print(repr(s), i)
print()
########################################################################################################
# In[4]:
def sample_logits(out, temperature=1.0, top_p=0.8):
probs = F.softmax(out, dim=-1).numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
########################################################################################################
可以从这个链接下载模型
https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files
https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth
# In[68]:
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096' #这里不用有后缀.pth
args.n_layer = 24
args.n_embd = 1024
args.vocab_size = 65536
# In[69]:
# N_LAYER="12"
# N_EMBD="768"
N_LAYER="24"
N_EMBD="1024"
# In[70]:
# context = "\nElon Musk has"
# context = "\n我们发现"
context = "Q:Do you know datawhalechina?\nA:"
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
LENGTH_PER_TRIAL = 4096
TEMPERATURE = 1.0
TOP_P = 0.7
# In[80]:
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_decay'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
@MyFunction
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
i0 = (2+self.head_size)*i+0
xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
state[i0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
return r * (vw @ k)
@MyFunction
def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
H = self.n_head
S = self.head_size
i1 = (2+S)*i+1
xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
state[i1] = x
r = (rw @ xr).view(H, 1, S)
k = (kw @ xk).view(H, S, 1)
v = (vw @ xv).view(H, 1, S)
g = F.silu(gw @ xg)
s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)
x = torch.zeros(H, S)
a = k @ v
x = r @ (time_first * a + s)
s = a + time_decay * s
state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
x = x.flatten()
x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
return ow @ x
def forward(self, token, state):
with torch.no_grad():
if state == None:
state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)
x = self.w.emb.weight[token]
x = self.layer_norm(x, self.w.blocks[0].ln0)
for i in range(self.args.n_layer):
# print(i)
att = self.w.blocks[i].att
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
att.ln_x.weight, att.ln_x.bias)
ffn = self.w.blocks[i].ffn
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
ffn.time_mix_k, ffn.time_mix_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
return x.float(), state
# In[81]:
context = "Q:Do you know datawhalechina?\nA:"
# In[82]:
args.MODEL_NAME
# In[83]:
args.n_layer,args.n_embd
# In[84]:
# args.n_layer = 24
# args.n_embd = 1024
# In[85]:
# args.n_layer = 12
# args.n_embd = 768
# In[86]:
# args.MODEL_NAME='../models/rwkv-5-world-1b5'
# In[87]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
# In[88]:
init_state = None
# In[89]:
LENGTH_PER_TRIAL=1024
# In[90]:
for token in tokenizer.encode(context):
init_out, init_state = model.forward(token, init_state)
for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
out_last = 0
out, state = init_out.clone(), init_state.clone()
for i in range(LENGTH_PER_TRIAL):
token = sample_logits(out, TEMPERATURE, TOP_P)
all_tokens += [token]
try:
tmp = tokenizer.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
print(tmp, end="", flush=True)
out_last = i + 1
except:
pass
out, state = model.forward(token, state)
print('\n')
# 显然datawhale这个数据没有训练过哈哈。不过速度是蛮快的这个没得说在cpu上跑资源消耗也很小。

View File

@ -0,0 +1,291 @@
#!/usr/bin/env python
# coding: utf-8
# In[5]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
# In[6]:
class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
wlen: list[int]
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)
# precompute some tables for fast matching
self.table = [[[] for j in range(256)] for i in range(256)]
self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)]
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i]
if len(s) >= 2:
s0 = int(s[0])
s1 = int(s[1])
self.table[s0][s1] += [s]
self.wlen[s0] = max(self.wlen[s0], len(s))
self.good[s0].add(s1)
def encodeBytes(self, src: bytes) -> list[int]:
src_len: int = len(src)
tokens: list[int] = []
i: int = 0
while i < src_len:
s: bytes = src[i : i + 1]
if i < src_len - 1:
s1: int = int(src[i + 1])
s0: int = int(src[i])
if s1 in self.good[s0]:
sss: bytes = src[i : i + self.wlen[s0]]
try:
s = next(filter(sss.startswith, self.table[s0][s1]))
except:
pass
tokens.append(self.token2idx[s])
i += len(s)
return tokens
def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode('utf-8')
except:
pass
print(f'{repr(s)}{i}', end=' ')
# print(repr(s), i)
print()
########################################################################################################
# In[7]:
def sample_logits(out, temperature=1.0, top_p=0.8):
probs = F.softmax(out, dim=-1).numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
########################################################################################################
模型下载地址https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth
# In[13]:
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'
args.n_layer = 24
args.n_embd = 2048
args.vocab_size = 65536
context = "\nDatawhale is "
# context = "\n我们发现"
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.7
# In[14]:
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
@MyFunction
def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):
i0 = (2+self.head_size)*i+0
sx = state[i0] - x
xk = x + sx * time_maa_k
xr = x + sx * time_maa_r
state[i0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
return r * (vw @ k)
@MyFunction
def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
H = self.n_head
S = self.head_size
i1 = (2+S)*i+1
sx = state[i1] - x
state[i1] = x
xxx = x + sx * x_maa
xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
xxx = torch.bmm(xxx, tm_w2).view(5, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
xw = x + sx * (w_maa + mw)
xk = x + sx * (k_maa + mk)
xv = x + sx * (v_maa + mv)
xr = x + sx * (r_maa + mr)
xg = x + sx * (g_maa + mg)
w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
w = torch.exp(-torch.exp(w.float()))
r = (rw @ xr).view(H, 1, S)
k = (kw @ xk).view(H, S, 1)
v = (vw @ xv).view(H, 1, S)
g = F.silu(gw @ xg)
s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)
x = torch.zeros(H, S)
a = k @ v
x = r @ (time_first * a + s)
s = a + w * s
state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
x = x.flatten()
x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
return ow @ x
def forward(self, token, state):
with torch.no_grad():
if state == None:
state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)
x = self.w.emb.weight[token]
x = self.layer_norm(x, self.w.blocks[0].ln0)
for i in range(self.args.n_layer):
att = self.w.blocks[i].att
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
att.ln_x.weight, att.ln_x.bias)
ffn = self.w.blocks[i].ffn
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
ffn.time_maa_k, ffn.time_maa_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
return x.float(), state
# In[15]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
# In[16]:
for token in tokenizer.encode(context):
init_out, init_state = model.forward(token, init_state)
# In[17]:
for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
out_last = 0
out, state = init_out.clone(), init_state.clone()
for i in range(LENGTH_PER_TRIAL):
token = sample_logits(out, TEMPERATURE, TOP_P)
all_tokens += [token]
try:
tmp = tokenizer.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
print(tmp, end="", flush=True)
out_last = i + 1
except:
pass
out, state = model.forward(token, state)
print('\n')
# v6和v5相比感觉更喜欢使用emoj了哈哈

View File

@ -0,0 +1,273 @@
### RWKV 模型版本比较报告
本文档旨在比较 RWKV 模型的六个不同版本v1 至 v6并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。
---
#### 版本概述
**RWKV v1**
- 初始版本,基础实现 RWKV 时间混合和通道混合模块。
- 主要特性:
- 使用时间混合Time-mix和通道混合Channel-mix模块。
- 采用标准的线性层和嵌入层初始化。
- 使用掩码来处理因果关系。
**RWKV v2**
- 增强版本,改进了时间混合和通道混合的实现。
- 主要改进:
- 优化了模型加载和状态管理。
- 增加了新的归一化方法。
- 提升了训练和推理效率。
**RWKV v3**
- 进一步优化的版本,主要集中在性能提升。
- 主要改进:
- 调整了层数和嵌入维度,提供更灵活的配置选项。
- 增加了预处理步骤,提高了推理效率。
**RWKV v4**
- 增加了对更大规模模型的支持,提升了模型复杂度。
- 主要改进:
- 支持24层和1024维嵌入。
- 增加了更多的参数调优选项。
**RWKV v5**
- 继续提升模型规模和复杂度,并优化了模型架构。
- 主要改进:
- 支持更高的嵌入维度2048
- 引入了新的时间混合和通道混合方法,提升了模型性能。
**RWKV v6**
- 最新版本,综合了前几个版本的改进,并引入了一些新特性。
- 主要改进:
- 增加了对更大词汇表65536的支持。
- 采用了改进的混合方法,提升了推理速度和准确性。
---
#### 详细比较
**1. 架构与实现**
- **时间混合Time-Mix和通道混合Channel-Mix**
- **v1**:基本实现,功能完备。
```python
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd)
```
- **v2**:优化了时间混合和通道混合,提升了计算效率。
```python
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
```
- **v3**:进一步优化,并增加了灵活的配置选项。
```python
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
- **v4**:支持更大规模模型,提升了时间混合和通道混合的处理能力。
```python
class RWKV_RNN(torch.jit.ScriptModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
else: w[k] = w[k].float() # convert to f32 type
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
- **v5**:引入了新的混合方法,进一步提升了性能。
```python
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_decay'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
- **v6**:改进了混合方法,提升了整体性能和效率。
```python
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
**2. 模型规模**
- **层数和嵌入维度**
- **v1**:标准配置,适用于基础任务。
- **v2**支持12层和768维嵌入。
- **v3**提供12层和24层选项嵌入维度为768和1024。
- **v4**支持24层和1024维嵌入。
- **v5**嵌入维度增加至2048。
- **v6**:进一步增加模型复杂度,支持更大词汇表。
**3. 性能与效率**
- **推理速度和资源消耗**
- **v1**:基础实现,资源消耗适中。
- **v2**:优化后,推理速度提升。
- **v3**:预处理步骤的增加,提高了推理效率。
- **v4**:更大规模模型下的性能优化。
- **v5**:新的混合方法提升了推理速度和准确性。
- **v6**:综合改进,推理速度和资源利用进一步优化。
**4. 词汇表和上下文长度**
- **词汇表大小和上下文长度支持**
- **v1-v4**:词汇表大小和上下文长度逐步增加。
- **v5**:支持更大上下文长度,适应复杂任务。
- **v6**支持最大65536的词汇表和更长的上下文长度。
---
### 总结
RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景
- **v1**:适用于基础任务和初步研究。
- **v2**:适用于需要更高效率和优化的任务。
- **v3**:适用于需要灵活配置和更高性能的应用。
- **v4**:适用于大规模模型的训练和推理任务。
- **v5**:适用于需要高精度和高效推理的复杂任务。
- **v6**:适用于最前沿的研究和应用,提供最高的性能和效率。
每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。

View File

@ -0,0 +1,517 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters(): # find the name of the weight
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0 # positive: gain for orthogonal, negative: std for normal
scale = 1.0 # extra scale for gain
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
scale = config.rwkv_emb_scale
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
scale = config.rwkv_emb_scale
if hasattr(m, 'scale_init'):
scale = m.scale_init
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if gain == 0:
nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
# print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
# if config.rwkv_tiny_attn > 0:
# self.tiny_att = RWKV_TinyAttn(config)
self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
# if hasattr(self, 'tiny_att'):
# tiny_att = self.tiny_att(x, self.mask)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
kv = (k * v).view(B, T, self.n_head, self.head_size)
wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
# if hasattr(self, 'tiny_att'):
# rwkv += tiny_att
return rwkv * self.time_gamma[:T, :]
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu
rwkv = torch.sigmoid(r) * wkv
return rwkv
class RWKV_TinyAttn(nn.Module): # extra tiny attention
def __init__(self, config):
super().__init__()
self.d_attn = config.rwkv_tiny_attn
self.n_head = config.rwkv_tiny_head
self.head_size = self.d_attn // self.n_head
self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
self.out = nn.Linear(self.d_attn, config.n_embd)
def forward(self, x, mask):
B, T, C = x.size()
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim = -1)
if self.n_head > 1:
q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
qk = qk.masked_fill(mask == 0, float('-inf'))
qk = F.softmax(qk, dim = -1)
qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
if self.n_head > 1:
qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
return self.out(qkv)
########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_len=None):
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
return self.cos_cached, self.sin_cached
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), -1)
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
class MHA_rotary(nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x)
return x
class GeGLU(torch.nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 3 * config.n_ffn
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
y = self.weight(F.gelu(k) * v)
return y
########################################################################################################
# MHA_pro: with more tricks
########################################################################################################
class MHA_pro(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
att = F.softmax(att, dim = -1) # softmax
att = att * w # time-weighting
att = self.head_mix(att) # talking heads
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
x = self.output(x) * self.time_gamma[:T, :]
return x
########################################################################################################
# The GPT Model with our blocks
########################################################################################################
class RMSNorm(nn.Module):
def __init__(self, d):
super().__init__()
self.dd = d ** (-1. / 2)
self.weight = nn.Parameter(torch.ones(d))
def forward(self, x):
norm_x = x.norm(2, dim=-1, keepdim=True)
x_normed = x / (norm_x * self.dd + 1e-12)
return self.weight * x_normed
class FixedNorm(nn.Module):
def __init__(self, d):
super().__init__()
self.dd = d ** (-1. / 2)
def forward(self, x):
norm_x = x.norm(2, dim=-1, keepdim=True)
x_normed = x / (norm_x * self.dd + 1e-12)
return x_normed
########################################################################################################
class GPTConfig:
def __init__(self, vocab_size, ctx_len, **kwargs):
self.vocab_size = vocab_size
self.ctx_len = ctx_len
for k,v in kwargs.items():
setattr(self, k, v)
class Block(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if config.model_type == 'RWKV':
# self.ln1 = FixedNorm(config.n_embd)
# self.ln2 = FixedNorm(config.n_embd)
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
elif config.model_type == 'MHA_rotary':
self.attn = MHA_rotary(config, layer_id)
self.mlp = GeGLU(config, layer_id)
elif config.model_type == 'MHA_shift':
self.attn = MHA_rotary(config, layer_id, time_shift=True)
self.mlp = GeGLU(config, layer_id, time_shift=True)
elif config.model_type == 'MHA_pro':
self.attn = MHA_pro(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.head_q = nn.Linear(config.n_embd, 256)
self.head_q.scale_init = 0.01
self.head_k = nn.Linear(config.n_embd, 256)
self.head_k.scale_init = 0.01
self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.ctx_len = config.ctx_len
if self.config.model_type == 'RWKV':
RWKV_Init(self, config)
else:
self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_ctx_len(self):
return self.ctx_len
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.01)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def configure_optimizers(self, train_config):
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (nn.Linear, )
blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn)
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
return optimizer
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.tok_emb(idx)
x = self.blocks(x)
x = self.ln_f(x)
q = self.head_q(x)[:,:T,:]
k = self.head_k(x)[:,:T,:]
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
x = self.head(x) + c
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss

View File

@ -0,0 +1,235 @@
### RWKV v1 源码详细解读
由于时间久远没有找到对应的预训练脚本对于v1主要做源码分析v2-v6均可以找到载入模型实现的脚本。本文档详细解析 RWKV v1 版本的核心代码,具体包括初始化、时间混合模块、通道混合模块和多头注意力机制。代码使用 PyTorch 实现,具有良好的可读性和扩展性。
---
#### 模型初始化
首先,我们定义了一个用于初始化线性层和嵌入层的函数 `RWKV_Init`
```python
def RWKV_Init(module, config):
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters():
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd:
scale = config.rwkv_emb_scale
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd:
scale = config.rwkv_emb_scale
if hasattr(m, 'scale_init'):
scale = m.scale_init
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if gain == 0:
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
```
该函数遍历模块中的所有线性层和嵌入层,并根据特定条件初始化其权重。对于线性层,若偏置存在,则将其初始化为零;对于嵌入层,计算权重的增益和缩放因子。根据不同的条件使用不同的初始化方法,例如正交初始化或正态初始化。
---
#### 时间混合模块
`RWKV_TimeMix` 类实现了时间混合机制。
```python
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad():
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)])
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:]
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30, min=-60)
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
kv = (k * v).view(B, T, self.n_head, self.head_size)
wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
return rwkv * self.time_gamma[:T, :]
```
该类实现了时间混合机制,通过时间权重矩阵 `time_w` 对输入进行变换。`time_w` 矩阵根据头数和上下文长度初始化,随后对输入进行时间维度上的变换和混合。`key`、`value` 和 `receptance` 三个线性层分别生成键、值和接收信号,并通过 sigmoid 函数计算输出。
---
#### 通道混合模块
`RWKV_ChannelMix` 类实现了通道混合机制。
```python
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 5 * config.n_ffn // 2
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
wkv = self.weight(F.mish(k) * v)
rwkv = torch.sigmoid(r) * wkv
return rwkv
```
该类实现了通道混合机制,通过 `key`、`value` 和 `receptance` 三个线性层对输入进行变换。`key` 和 `value` 层生成的张量经过 `mish` 激活函数变换后再通过 `weight` 层进行加权变换,最后与 `receptance` 层生成的接收信号相乘得到最终输出。
---
#### 多头注意力机制
`MHA_rotary` 类实现了多头注意力机制,并引入了旋转位置编码。
```python
class MHA_rotary(nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2
)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
x = x.transpose(1, 2).contiguous().view(B, T, -1)
x = self.output(x)
return x
```
该类实现了多头注意力机制,并引入了旋转位置编码以增强模型的位置信息表示。通过 `query`、`key` 和 `value` 三个线性层生成查询、键和值,再通过旋转编码和自注意力机制计算最终输出。
---
#### 总结
RWKV v1 模型实现了基础的时间混合、通道混合和多头注意力机制。通过对输入进行多维度的变换和混合,实现了对复杂特征的提取和表示。上述代码段展示了模型的核心组件和主要计算过程,为后续版本的优化和改进提供了坚实的基础。

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,258 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
# RUN_DEVICE = 'cpu' # cpu cuda
# ctx_len = 768
# n_layer = 12
# n_embd = 768
RUN_DEVICE = 'cpu'
ctx_len = 768
n_layer = 24
n_embd = 1024
MODEL_NAME = '/data1/ckw/20220615-10803'
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'
print(f'\n* running on {RUN_DEVICE}')
################################################################################################################
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1,1,n_embd))
self.key = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd, n_embd, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.output = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x).transpose(-1, -2)
v = self.value(x).transpose(-1, -2)
r = self.receptance(x)
k = torch.clamp(k, max=60)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + 1e-9
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
x = self.ln1(x)
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME=MODEL_NAME):
super().__init__()
print('\nloading RWKV-GPT', MODEL_NAME)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
x = self.head(x)
return x
################################################################################################################
time_buf = {}
class RWKV_RNN():
def __init__(self, MODEL_NAME=MODEL_NAME):
print('\nloading RWKV-RNN', MODEL_NAME)
self.ctx_len = ctx_len
self.n_layer = n_layer
self.n_embd = n_embd
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
def LN(self, xx, w):
return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.square(torch.relu(w.key.weight @ x))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ x)
k = torch.exp(torch.clamp(w.key.weight @ x, max=60))
v = w.value.weight @ x
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + 1e-9)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(n_layer):
x = self.LN(x, w.blocks[i].ln1)
x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
x = self.LN(x, w.blocks[i].ln2)
x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
x = x.tolist()
return x
################################################################################################################

View File

@ -0,0 +1,575 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c6cb345b-d2a1-408f-8e91-e562f5cd032a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"* running on cpu\n"
]
}
],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"import types\n",
"import copy\n",
"import torch\n",
"from torch.nn import functional as F\n",
"\n",
"from model import RWKV_RNN"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "70ff5b46-9362-47e9-a1b4-26cac9698c69",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"******************************************************************************\n",
"* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.\n",
"* It is NOT indicative of the final performance (which requires 300B tokens).\n",
"******************************************************************************\n"
]
}
],
"source": [
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"\n",
"print('''\n",
"******************************************************************************\n",
"* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.\n",
"* It is NOT indicative of the final performance (which requires 300B tokens).\n",
"******************************************************************************''')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "795e2228-cc19-4e57-bcb7-5027365519c5",
"metadata": {},
"outputs": [],
"source": [
"# Edit model.py to set CPU / CUDA mode. Runs on CPU by default.\n",
"\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7\n",
"\n",
"DEBUG_DEBUG = False\n",
"LENGTH_OF_EACH = 333\n",
"NUM_TRIALS = 100\n",
"\n",
"context = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n",
"\n",
"##############################################################################################################"
]
},
{
"cell_type": "raw",
"id": "08295fac-a5ce-478f-8bd9-fc191aae1af0",
"metadata": {},
"source": [
"模型下载地址https://hf-mirror.com/BlinkDL/rwkv-2-pile-430m/resolve/main/20220615-10803.pth?download=true"
]
},
{
"cell_type": "raw",
"id": "f41497aa-4bff-4a91-8e1e-d811b3077170",
"metadata": {},
"source": [
"请在model.py中修改模型路径\n",
"例如我的路径是/data1/ckw/20220615-10803\n",
"如果想使用cuda加速请参考https://github.com/BlinkDL/RWKV-v2-RNN-Pile"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "329c7001-102a-4d84-88ba-2b97a68c9557",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"loading RWKV-RNN /data1/ckw/20220615-10803\n"
]
}
],
"source": [
"model = RWKV_RNN()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "326e1bc0-eaa2-4f4e-94d9-60fe99d8768f",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=None):\n",
" probs = F.softmax(torch.tensor(out), dim=-1)\n",
" sorted_probs, _ = torch.sort(probs, descending=True)\n",
"\n",
" cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
"\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
"\n",
" return torch.multinomial(probs, num_samples=1)[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "392a71df-ac9e-4349-a60f-8988466e5812",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. We train artificial intelligence and robotics experts from around the world to build innovative artificial intelligence systems for education. We also develop software and systems to improve the quality of data for AI and robotics. DataWhalechina provides learning solutions for AI and robotics, software development and data analysis. We help AI developers to create new applications and models, using machine learning to understand the human brain and to improve its performance.\n",
"\n",
"Our goal is to build AI systems for educational purposes. We are an international AI organization, with over 20 years of experience in training AI specialists. Our expertise is based on the following:\n",
"\n",
"Advantages\n",
"\n",
"Maintain high standards of quality\n",
"\n",
"Create effective learning opportunities\n",
"\n",
"Ensure the reliability of data\n",
"\n",
"Help people in training\n",
"\n",
"Reduce costs\n",
"\n",
"Expand the reach of the organization\n",
"\n",
"Were proud to be the first organization to create an AI that focuses on data quality.\n",
"\n",
"What is a data quality assessment?\n",
"\n",
"Data quality assessment is a process of conducting quality assessment of data. Data quality assessment uses techniques that examine the process of data collection and analysis, to assess the quality of the data collected.\n",
"\n",
"If the process of data quality assessment is not complete, the data will be deleted from the system.\n",
"\n",
"If the process of data quality assessment is not complete, the system will be subject to further analysis and processing.\n",
"\n",
"The data quality assessment will also include an analysis of the system to determine whether the data quality assessment has been completed.\n",
"\n",
"How does data quality assessment work?\n",
"\n",
"When an organization works on a data quality assessment, it\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization started as an experimental project. They have been working on different projects and one of them is datawhale.io.\n",
"\n",
"At the time of writing, DataWhale is being used by schools and institutions to help students understand data science and to train their algorithms.\n",
"\n",
"DataWhale has many projects and they have been active since 2018. There are over 10 projects and thousands of datawhale.io users on their network.\n",
"\n",
"DataWhale.io has been working on different projects. They started with the open source project ZopSketch, that helps create a smart camera and its still in development. They have also been working on AutoLink, which allows users to automatically upload a link to their own photo on social media.\n",
"\n",
"They have been working on Tekken, which is a competitive action game and they have developed the program MonsterJigsaw. They have also been working on the UBIO game.\n",
"\n",
"DataWhale is helping people to improve their data science skills, especially by providing tools that can help them with their own data.\n",
"\n",
"DataWhale has also been working on data science projects that help data scientists and data scientists. DataWhale has been working on various data science projects. They have created tools to help companies, organizations, and businesses use data.\n",
"\n",
"DataWhale.io is a data science company that has created data science projects.\n",
"\n",
"DataWhale.io has worked on various projects, including data science and data science. They have created data science projects that help people improve their data science skills.\n",
"\n",
"DataWhale has been working\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The group trains students in the language, skill and design of artificial intelligence and research in AI. It was created in 2018 and is currently in operation.\n",
"\n",
"International Students\n",
"\n",
"Learn about our international students and international students that are currently enrolled at DataWhalechina.\n",
"\n",
"Tangxia Online\n",
"\n",
"Tangxia Online is a student portal for all Chinese students in Shanghai. It enables them to see information about Shanghai, access information about current courses and news about Shanghai.\n",
"\n",
"China Business China (CBBC)\n",
"\n",
"CBBC is a portal that hosts global business courses, online courses and information about China in Shanghai.\n",
"\n",
"As the global host of the international online program CBBC, DataWhalechina has established a network of Chinese students, alumni, educators and business students. The portal offers a number of valuable courses and information.\n",
"\n",
"About DataWhalechina\n",
"\n",
"DataWhalechina is a Chinese portal that hosts data and information about international students. The portal provides information about current courses and information about China in Shanghai.\n",
"\n",
"DataWhalechina.com is a portal that allows users to view courses, course options, materials and information about Shanghai.\n",
"\n",
"DataWhalechina.com has a strong focus on providing students with online resources and information about Shanghai, China.\n",
"\n",
"For more information, visit datawhalechina.com.\n",
"\n",
"China Data China\n",
"\n",
"China Data China is a portal that provides users with online information and resources about Shanghai, China. The portal provides users with information about Shanghai, China.\n",
"\n",
"DataWhalechina is a portal that allows\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. They have trained over 2,000 people to date, and are now one of the top 10 data mining companies in the world.\n",
"\n",
"This means they can offer a variety of training programs to their users. These include data mining and artificial intelligence course. They have trained people in artificial intelligence and machine learning. They have also trained an AI-powered library of services. This means they can help any company that wants to take advantage of AI-powered data mining services.\n",
"\n",
"Founded in 2017, DataWhalechina is a machine learning and artificial intelligence company based in Shanghai, China. They have trained some of the worlds top 5 AI-powered data mining companies. Their services include machine learning, artificial intelligence, and data mining.\n",
"\n",
"DataWhalechina has been working in the AI industry for many years. They have worked on many projects and technologies, including AI-powered platforms.\n",
"\n",
"DataWhalechina is a member of the Information and Technology Sector (IT) Council.\n",
"\n",
"The company also works with startups to create AI-powered services that will help their clients.\n",
"\n",
"These services include machine learning and artificial intelligence training. They also have many consultants that are trained on the companys platform.\n",
"\n",
"The company also has a data scientist, an artificial intelligence researcher, and a data science engineer.\n",
"\n",
"DataWhalechina has been in the AI industry for more than 20 years. They have an office in Shanghai, China.\n",
"\n",
"The company offers many different services to help their clients learn AI. They have courses and workshops.\n",
"\n",
"The company has worked with startups that have worked\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. We have a strong commitment to developing learning tools for everyone. DataWhalechinas AI and ML platform offers users a free course on ML algorithms and data mining for anyone with basic programming skills.\n",
"\n",
"We will start with a small introduction to Machine Learning, we will start with building the first model and the first algorithm, followed by a few more.\n",
"\n",
"Data mining\n",
"\n",
"Machine learning is a process of classifying data into classes based on the principles of statistical inference and machine learning. The purpose of machine learning is to classify data into different categories.\n",
"\n",
"Data mining is the science of extracting useful information from large amounts of data.\n",
"\n",
"You can think of data mining as a way to find patterns in data. For example, if youre looking for the average temperature in a certain area, you can try to classify data based on temperature.\n",
"\n",
"Data mining is also a technique that uses machine learning to predict the future behaviour of an object.\n",
"\n",
"Machine learning can also be used to predict the future behaviour of a system. This is called a model-based learning algorithm.\n",
"\n",
"Data mining is a relatively new discipline, and many people havent even heard of it. However, we do have some basic concepts of data mining.\n",
"\n",
"Data mining is a technique for creating models to predict future events based on data. Data mining is also known as “machine learning” or “artificial intelligence”.\n",
"\n",
"Machine learning can be used to create models for some tasks. For example, an algorithm can be trained to predict a salesmans behaviour based on past sales.\n",
"\n",
"If you want to make a decision\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. DataWhalechina works with universities and the government to provide AI training, including computer-assisted learning, for students in areas like health care, sports, and science.\n",
"\n",
"DataWhalechina was founded in 2015.\n",
"\n",
"DataWhalechina provides AI research, teaching, and software development for the United States, Europe, and China. It was founded by China University of Political Science and Law, University of Washington, and University of San Diego.\n",
"\n",
"Since 2017, DataWhalechina has operated the Web Science Institute. DataWhalechina is also involved in the field of artificial intelligence.\n",
"\n",
"List of data scientists\n",
"DataWhalechina has made numerous contributions to the field of AI, including its founding of the DataScience Initiative, a non-profit organization that conducts research and analysis of artificial intelligence. The company is currently focused on improving the quality of AI research.\n",
"\n",
"DataWhalechina has conducted various studies and experiments in the field of artificial intelligence, including one in which data scientists learned how to use machine learning algorithms to predict human emotions.\n",
"\n",
"Datawhalechina has been involved in various projects related to artificial intelligence, such as its partnership with Leverage Ventures, an AI startup that focuses on artificial intelligence research.\n",
"\n",
"DataWhalechina and Leverage Ventures have collaborated on numerous data science projects, including a project that builds a machine learning model that detects patterns in YouTube videos.\n",
"\n",
"DataWhalechina was also involved in helping Leverage Ventures to launch their research into machine learning.\n",
"\n",
"DataWhalechina was\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization also is involved in many other non-profit organizations such as the Alumni Association of the University of California, Los Angeles, and the United States Army.\n",
"\n",
"History \n",
"DataWhalechina was founded in May 2018.\n",
"\n",
"Awards \n",
"The organization has been awarded the following awards:\n",
" 2018: Business Week\n",
" 2018: Google News\n",
" 2018: MediaWorld\n",
"\n",
"References\n",
"\n",
"External links \n",
" Official website\n",
"\n",
"Category:2018 establishments in China\n",
"Category:Companies based in Shanghai\n",
"Category:Companies established in 2018\n",
"Category:Companies based in Guangdong\n",
"Category:Internet search engines\n",
"Category:Information and communication technology companies\n",
"Category:International organizations based in China\n",
"Category:Internet search engines\n",
"Category:Organizations established in 2018\n",
"Category:Information technology organizations based in Europe\n",
"Category:Non-profit organizations based in Shanghai\n",
"Category:Organizations that support open access\n",
"Category:Artificial intelligence organizations\n",
"Category:Companies based in Shanghai\n",
"Category:2019 mergers and acquisitions\n",
"Category:Chinese brands\n",
"Category:Technology companies of China\n",
"Category:2019 mergers and acquisitions\n",
"Category:Technology companies established in 2018\n",
"Category:2019 mergers and acquisitions\n",
"Category:Technology companies established in 2018\n",
"Category:2018 mergers and acquisitions\n",
"Category:2019 mergers and acquisitions\n",
"Category:Istanbul Technical University\n",
"Category:Online learning companies\n",
"Category:Tech companies of China\n",
"Category:Technology companies established in 2018\n",
"Category:Technology companies established in 2019\n",
"Category:Technology companies of China\n",
"Category:2018 establishments in China\n",
"Category:Technology companies of China\n",
"Category:2018 mergers and\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Their mission is to help students become more effective learner. They focus on educating learners to master artificial intelligence (AI) and to improve their learning experience. They conduct research, help students learn and improve their skills, and bring their knowledge and understanding to the public.\n",
"\n",
"We are all in the same boat. We all know the situation, but no one is doing anything about it. There is no one that can make things better, so how do we fix this?\n",
"\n",
"In this article, we are going to share our perspective and ask you to help. We are all in the same boat, so how can we fix this?\n",
"\n",
"In this article, we will share our perspective and ask you to help us. We are all in the same boat. We all know the situation, but no one knows the answer.\n",
"\n",
"What is AI?\n",
"\n",
"An AI is an intelligent program that is capable of understanding the human condition, finding the truth, and making decisions.\n",
"\n",
"This term has been in use for the past 20 years and is considered a self-driving car.\n",
"\n",
"It is a machine that is capable of learning from experiences and taking actions, and making decisions, based on these actions.\n",
"\n",
"It is a self-driving car that is self-learning and learns from experiences. It is an intelligent machine that is able to learn from experiences, and make decisions, based on experiences.\n",
"\n",
"Why AI?\n",
"\n",
"AI is a method of learning that is being developed in a new way that is unprecedented. This is what makes AI such a promising field to work on.\n",
"\n",
"What is an AI?\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. We have a passion for AI and are always working on it. If you want to be an AI expert, join our membership and get access to a huge collection of free courses, resources, videos, tutorials, and more. Join today!\n",
"\n",
"https://twitter.com/DataWhalechina\n",
"\n",
"https://www.facebook.com/DataWhalechina\n",
"\n",
"https://www.instagram.com/DataWhalechina\n",
"\n",
"https://www.linkedin.com/company/data-whale\n",
"\n",
"Join our platform and learn the benefits of data.\n",
"\n",
"Get Data Whale\n",
"\n",
"Delve into the data world with the help of this professional engineer. This man can design the biggest robot, program the smartest AI in the world, and design the best AI in the world. Its easy to learn and fun to watch. You can also use this platform to create your own AI projects, and join our community to be an expert in AI.\n",
"\n",
"https://diveintodatawhale.com/\n",
"\n",
"https://www.linkedin.com/in/diveintodatawhale\n",
"\n",
"https://linkedin.com/in/diveintodatawhale\n",
"\n",
"www.linkedin.com/in/diveintodatawhale\n",
"\n",
"www.linkedin.com/in/davewhale\n",
"\n",
"Dive Into Data Whale\n",
"\n",
"This is the latest free data whitepaper. We use data from datawhalechina.com to build our AI and data-driven products.\n",
"\n",
"Data Whale\n",
"\n",
"\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Its mission is to help AI learners in the UK by providing them with free training.\n",
"\n",
"DataWhalechina was created by the Open Data Institute in the UK. It was launched on March 19, 2017, and it provides online courses for AI professionals to improve their skills in data science and data analytics. The program runs in the UK on an annual basis.\n",
"\n",
"The OReilly books \n",
"\n",
"In 2017, DataWhalechina launched a new book titled Data Whale: How the Big Data Industry is Changing the Way We Think and Act. It contains four parts, each part dealing with a different aspect of data science. The first part is a \"standard-answer\" chapter that explains how to make sense of the massive data available to you. The second part is a \"how-to\" guide that discusses the four parts, and the third part contains a chapter titled \"Big Data Analytics\". The fourth part of the book deals with analytics tools, like Amazon Web Services, the AWS cloud, and Apache Kafka.\n",
"\n",
"Other chapters include:\n",
"\n",
"A guide to artificial intelligence\n",
"A guide to data science\n",
"An introduction to machine learning\n",
"An introduction to data science\n",
"An introduction to data visualization\n",
"\n",
"The book has sections on machine learning, artificial intelligence, data science, artificial intelligence, and related fields. The book also includes a section on artificial intelligence basics, including data mining, machine learning, and machine learning in medicine.\n",
"\n",
"References\n",
"\n",
"External links \n",
" DataWhalechina\n",
"\n",
"Category:Software companies of the United Kingdom\n",
"Category:Software companies of the United States\n",
"Category:Data science\n",
"\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\n",
"\n",
"“The learning process will be more structured and easier,” said Swinoujci. “It will be a more enjoyable experience for students, which will in turn lead to better learning experiences.”\n",
"\n",
"In recent years, the industry has seen a significant uptick in the number of artificial intelligence (AI) training courses offered by companies. This includes instructional courses such as Artificial Intelligence in Banking and Finance, AI in Construction and Finance, AI in Energy, and AI in Hospitality and Leisure.\n",
"\n",
"AI is a topic that is well suited to the artificial intelligence field, as it offers an immersive learning experience that requires sophisticated knowledge. This is why artificial intelligence is increasingly being used in training courses.\n",
"\n",
"“AI in training courses is becoming increasingly popular in the industry,” said Swinoujci. “AI is transforming the way we teach artificial intelligence.”\n",
"\n",
"Swinoujci said that artificial intelligence in training courses will be “extremely challenging,” but he added that he is optimistic that the AI course will soon become a more accessible form of learning.\n",
"\n",
"“AI in training courses is a well-organized, structured, and challenging course that can be integrated into the education of future engineers,” he said. “I think it is crucial for the field of artificial intelligence to become more widely accessible, and that it can be used as a source of training.”\n",
"\n",
"Chinese investors are highly enthusiastic about artificial intelligence. This comes as a huge boon for the sector. The market is estimated to grow at an annual rate of over $11 billion in the next three years, and this will drive the industry\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. They specialize in developing AI and artificial intelligence in China.\n",
"\n",
"In this video, I was asked about the future of AI in China. In this video, I share my story about how I got my job and how I feel about AI today.\n",
"\n",
"The case for AI\n",
"\n",
"When you think about AI, you think about machines, but how does AI actually help humans? I have two main ideas about AI.\n",
"\n",
"1. AI allows humans to do amazing things\n",
"\n",
"AI can do incredible things. For example, artificial intelligence can do amazing things for machine learning. In the case of this video, I used a machine learning algorithm to detect suspicious behavior from millions of people. The algorithm was so good, that the machine was able to catch some very interesting information. This is what we have to take away from this story.\n",
"\n",
"2. AI allows us to see the world\n",
"\n",
"I was taught by a Chinese man who told me how AI could help us see the world in a new way. I learned a lot about the human mind from AI. I saw how machines can make decisions based on what people see. I also saw how machines can understand the natural world better than humans.\n",
"\n",
"I believe AI can make the world a better place. In the future, AI will become a fundamental part of how we think about the world.\n",
"\n",
"How AI can help you\n",
"\n",
"Here are some things you can do to help us make the world a better place:\n",
"\n",
"Make your computer smarter\n",
"\n",
"In the next lesson, I will show you how AI can make your computer smarter. I want you to use this lesson\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Were also building a team of Chinese artificial intelligence experts to help train AI researchers. We believe AI is the next frontier for education.\n",
"\n",
"Can you tell us about your organization?\n",
"\n",
"DataWh"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 31\u001b[0m\n\u001b[1;32m 29\u001b[0m out \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(init_state\u001b[38;5;241m.\u001b[39mout)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 31\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m out[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m999999999\u001b[39m \u001b[38;5;66;03m# disable <|endoftext|>\u001b[39;00m\n\u001b[1;32m 35\u001b[0m char \u001b[38;5;241m=\u001b[39m sample_logits(out, temperature\u001b[38;5;241m=\u001b[39mTEMPERATURE, top_p\u001b[38;5;241m=\u001b[39mTOP_P)\n",
"File \u001b[0;32m/data1/ckw/01大语言模型/rwkv-v2/model.py:249\u001b[0m, in \u001b[0;36mRWKV_RNN.run\u001b[0;34m(self, ctx)\u001b[0m\n\u001b[1;32m 247\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mSA(x, w\u001b[38;5;241m.\u001b[39mblocks[i]\u001b[38;5;241m.\u001b[39matt, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124matt.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 248\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLN(x, w\u001b[38;5;241m.\u001b[39mblocks[i]\u001b[38;5;241m.\u001b[39mln2)\n\u001b[0;32m--> 249\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mFF\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblocks\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mffn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mffn.\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mi\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLN(x, w\u001b[38;5;241m.\u001b[39mln_out)\n\u001b[1;32m 253\u001b[0m x \u001b[38;5;241m=\u001b[39m w\u001b[38;5;241m.\u001b[39mhead\u001b[38;5;241m.\u001b[39mweight \u001b[38;5;241m@\u001b[39m x\n",
"File \u001b[0;32m/data1/ckw/01大语言模型/rwkv-v2/model.py:213\u001b[0m, in \u001b[0;36mRWKV_RNN.FF\u001b[0;34m(self, xx, w, name)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mxx[name] \u001b[38;5;241m=\u001b[39m xx\n\u001b[1;32m 212\u001b[0m r \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msigmoid(w\u001b[38;5;241m.\u001b[39mreceptance\u001b[38;5;241m.\u001b[39mweight \u001b[38;5;241m@\u001b[39m x)\n\u001b[0;32m--> 213\u001b[0m k \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msquare(\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrelu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mw\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m@\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 214\u001b[0m kv \u001b[38;5;241m=\u001b[39m w\u001b[38;5;241m.\u001b[39mvalue\u001b[38;5;241m.\u001b[39mweight \u001b[38;5;241m@\u001b[39m k\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m r \u001b[38;5;241m*\u001b[39m kv\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):\n",
" ctx = [model.tokenizer.encode(context)][0]\n",
" src_len = len(ctx)\n",
" print(context, end='')\n",
"\n",
" model.clear()\n",
" if TRIAL == 0:\n",
" init_state = types.SimpleNamespace()\n",
" for i in range(src_len if DEBUG_DEBUG else src_len):\n",
" x = ctx[:i+1]\n",
" if i == src_len - 1:\n",
" init_state.out = model.run(x)\n",
" else:\n",
" model.run(x)\n",
" model.save(init_state)\n",
" else:\n",
" model.load(init_state)\n",
"\n",
" if DEBUG_DEBUG:\n",
" out = init_state.out\n",
" print('\\n', np.array(x), '==>', np.array(\n",
" out), np.max(out), np.min(out))\n",
"\n",
" for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):\n",
" x = ctx[:i+1]\n",
" x = x[-model.ctx_len:]\n",
"\n",
" if i == src_len:\n",
" out = copy.deepcopy(init_state.out)\n",
" else:\n",
" out = model.run(x)\n",
"\n",
" out[0] = -999999999 # disable <|endoftext|>\n",
"\n",
" char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)\n",
" char = char.item()\n",
" print(model.tokenizer.decode(char), end='', flush=True)\n",
"\n",
" ctx += [char]\n",
" print('\\n' + '-' * 70, end='')"
]
},
{
"cell_type": "raw",
"id": "2c54ec05-21ee-4f71-80d8-b028e1c60913",
"metadata": {},
"source": [
"这里我们以DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. \n",
"为开头写了100组因为太多我们就看几组效果就行。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,273 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
RUN_DEVICE = 'cpu' # cpu cuda
ctx_len = 768
n_layer = 12
n_embd = 768
# n_layer = 24
# n_embd = 1024
# ---> download RWKV-3 169M model from https://huggingface.co/BlinkDL/rwkv-3-pile-169m/tree/main
# MODEL_NAME = '/data1/ckw/RWKV-3-Pile-430M-20220817-10602'
MODEL_NAME = '/data1/ckw/RWKV-3-Pile-20220720-10704'
K_EPS = 1e-8
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'
print(f'\n* running on {RUN_DEVICE}')
################################################################################################################
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,n_embd))
self.key = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd, n_embd, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.output = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(n_embd)
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME=MODEL_NAME):
super().__init__()
print('\nloading RWKV-GPT', MODEL_NAME)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
x = self.head(x)
return x
################################################################################################################
time_buf = {}
class RWKV_RNN():
def __init__(self, MODEL_NAME=MODEL_NAME):
print('\nloading RWKV-RNN', MODEL_NAME)
self.ctx_len = ctx_len
self.n_layer = n_layer
self.n_embd = n_embd
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
def LN(self, xx, w):
return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=60))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
x = self.LN(x, w.blocks[0].ln0)
for i in range(n_layer):
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
x = x.tolist()
return x
################################################################################################################

View File

@ -0,0 +1,319 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn
RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
############################################################################################################
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
hidden_sz = 4 * RWKV_CFG.n_embd
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
class RWKV_TimeMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0)
self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=RWKV_K_CLAMP)
k = torch.exp(k)
kv = k * v
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)
w = w[:,-T:].unsqueeze(1)
wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv)
return rwkv
class Block(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(layer_id+1000)
else:
self.att = RWKV_TimeMix(layer_id)
self.ffn = RWKV_ChannelMix(layer_id)
def forward(self, x):
if self.layer_id == 0:
x = self.ln0(x)
if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class RWKV_GPT(nn.Module):
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
global RWKV_CFG
super().__init__()
RWKV_CFG.RUN_DEVICE = RUN_DEVICE
RWKV_CFG.model_type = model_type
RWKV_CFG.vocab_size = vocab_size
RWKV_CFG.n_layer = n_layer
RWKV_CFG.n_embd = n_embd
RWKV_CFG.ctx_len = ctx_len
print('\nloading RWKV-GPT', MODEL_NAME)
self.emb = nn.Embedding(vocab_size, n_embd)
self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
self.ln_out = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
if RWKV_HEAD_QK_DIM > 0:
self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_q.scale_init = 0
self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
self.head_k.scale_init = 0.1
self.register_buffer("copy_mask", torch.tril(
torch.ones(ctx_len, ctx_len)))
self.ctx_len = ctx_len
self.eval()
self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
self.eval()
def forward(self, idx):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
x = self.emb(idx)
x = self.blocks(x)
x = self.ln_out(x)
if RWKV_HEAD_QK_DIM > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
x = self.head(x) + c
else:
x = self.head(x)
return x
############################################################################################################
class RWKV_RNN():
def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
self.RUN_DEVICE = RUN_DEVICE
self.model_type = model_type
self.n_layer = n_layer
self.n_embd = n_embd
self.ctx_len = ctx_len
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
if '.time_decay' in x:
w[x] = torch.exp(-torch.exp(w[x]))
if '.time_first' in x:
w[x] = torch.exp(w[x])
if DEBUG_TIME and '.time_' in x:
print(x, w[x].squeeze().cpu().numpy())
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.clear()
def clear(self):
self.xx = {}
self.aa = {}
self.bb = {}
self.hk = None
def save(self, target):
target.xx = copy.deepcopy(self.xx)
target.aa = copy.deepcopy(self.aa)
target.bb = copy.deepcopy(self.bb)
target.hk = copy.deepcopy(self.hk)
def load(self, target):
self.xx = copy.deepcopy(target.xx)
self.aa = copy.deepcopy(target.aa)
self.bb = copy.deepcopy(target.bb)
self.hk = copy.deepcopy(target.hk)
def LN(self, xx, w):
return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
def FF(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.square(torch.relu(w.key.weight @ xk))
kv = w.value.weight @ k
return r * kv
def SA(self, xx, w, name):
if name not in self.xx:
self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
self.xx[name] = xx
r = torch.sigmoid(w.receptance.weight @ xr)
k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
v = w.value.weight @ xv
kv = k * v
a = self.aa[name] + w.time_first * kv
b = self.bb[name] + w.time_first * k
self.aa[name] = w.time_decay * self.aa[name] + kv
self.bb[name] = w.time_decay * self.bb[name] + k
rwkv = r * a / (b + RWKV_K_EPS)
return w.output.weight @ rwkv
def run(self, ctx):
w = self.w
x = w.emb.weight[ctx[-1]]
for i in range(self.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
if i == 0 and self.model_type == 'RWKV-ffnPre':
x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
else:
x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
x = self.LN(x, w.ln_out)
if RWKV_HEAD_QK_DIM > 0:
if self.hk == None:
self.hk = (w.head_k.weight @ x).unsqueeze(0)
else:
self.hk = torch.cat(
[self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
if self.hk.shape[0] > self.ctx_len:
self.hk = self.hk[-self.ctx_len:, :]
q = w.head_q.weight @ x
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
c = (self.hk @ q) / RWKV_HEAD_QK_DIM
for i in range(len(c)):
x[ctx[i]] += c[i]
else:
x = w.head.weight @ x
x = x.cpu().numpy().tolist()
return x

View File

@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "937c4760-7ea9-43ce-8e53-77ab5c92f3b7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"* running on cpu\n"
]
}
],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"import types\n",
"import copy\n",
"import torch\n",
"from torch.nn import functional as F\n",
"\n",
"from model import RWKV_RNN\n",
"\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)"
]
},
{
"cell_type": "raw",
"id": "e4e4f51e-b01b-45ac-8418-4ba32460e305",
"metadata": {},
"source": [
"模型下载地址本脚本请用如下169m参数的不要用这个链接https://hf-mirror.com/BlinkDL/rwkv-3-pile-430m/resolve/main/RWKV-3-Pile-430M-20220817-10602.pth?download=true\n",
"模型下载地址用这个https://hf-mirror.com/BlinkDL/rwkv-3-pile-169m/resolve/main/RWKV-3-Pile-20220720-10704.pth?download=true"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cd2655e8-ed8f-4d2d-9852-5101cb0bf70e",
"metadata": {},
"outputs": [],
"source": [
"# ---> Edit src/model.py to set MODEL_NAME and CPU / CUDA mode <---\n",
"\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7\n",
"\n",
"DEBUG_DEBUG = False\n",
"LENGTH_OF_EACH = 222\n",
"NUM_TRIALS = 100\n",
"\n",
"context = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n",
"\n",
"##############################################################################################################"
]
},
{
"cell_type": "raw",
"id": "97d34a8c-b8c6-4444-9c7c-61458b1663f9",
"metadata": {},
"source": [
"请在model.py中修改模型路径\n",
"例如我的路径是/data1/ckw/RWKV-3-Pile-430M-20220817-10602\n",
"如果想使用cuda加速请参考https://github.com/BlinkDL/RWKV-v2-RNN-Pile"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ebd826cd-633e-4c3d-aa15-daee44c438ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"loading RWKV-RNN /data1/ckw/RWKV-3-Pile-20220720-10704\n",
"\n",
"--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n",
"\n"
]
}
],
"source": [
"##############################################################################################################\n",
"\n",
"model = RWKV_RNN()\n",
"print('\\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\\n')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "60170c2d-be0b-43ef-9ed2-7cc5fbd1bab1",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=None):\n",
" probs = F.softmax(torch.tensor(out), dim=-1)\n",
" sorted_probs, _ = torch.sort(probs, descending=True)\n",
"\n",
" cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
"\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
"\n",
" return torch.multinomial(probs, num_samples=1)[0]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a7465358-f05a-420a-b3af-69304d8de9f4",
"metadata": {},
"outputs": [],
"source": [
"NUM_TRIALS=3"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fdeb0be0-da7d-4d2b-abf1-54b5909bbb99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The group has its own website which is the second largest data whalechina.com and it has been downloaded over 11 million times. It has a huge reach on social media and has over 2.5 million members. It was founded in September 2010 by co-founder and CEO Wang Tianjian.\n",
"\n",
"Anytime, when the topic is changing, people must change their minds and make a decision. Its no wonder that the internet has started to get heated in the past years. There is no other way to know that this information will not get leaked. It is an open secret that every member of the AI and Deep Learning community will benefit from this new social media platform.\n",
"\n",
"Although the data whalechina.com is not meant to be a home for AI researchers, it is still an important tool in the machine learning community and an important part of the Internets knowledge and knowledge. It is also a platform that has been used by a wide variety of people to learn about the topic.\n",
"\n",
"One of the most significant pieces of AI research was done by\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Datawhalechina helps you learn the concepts of artificial intelligence and analyze it. It offers a computer program to analyze a series of lectures and can analyze any situation to make a decision.\n",
"\n",
"Hongwei Online School offers computer science and engineering programs in a complex and energetic environment. It provides online courses, design software, online tutorials, and software development.\n",
"\n",
"Goethe University offers online classes in artificial intelligence, machine learning, and the latest trends in technology.\n",
"\n",
"About the project: Goethe University is the German think tank for technology, technology, technology, and society. We publish several books, several books, a web page, and a talk show. We make educational and scientific innovations in a collaborative fashion, but there is no way to control the movement of knowledge.\n",
"\n",
"We will use your feedback to improve the program, and improve our service.\n",
"\n",
"Goethe University is a pioneer in artificial intelligence in Germany. We help organizations achieve their goals through innovation, innovative methods, and positive collaboration.\n",
"\n",
"Here at Goethe University, we welcome new students from around the world\n",
"----------------------------------------------------------------------\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\n",
"\n",
"And here is how they did it:\n",
"\n",
"> This application is free and open-source and allows you to build and distribute the software, the system software and the platform. The software is free and open-source.\n",
"\n",
"> It has been developed by the authors and has been published in the USENIX Internet Journal.\n",
"\n",
"> The main purpose of this platform is to build software that will enable human learning and data sharing in many different fields, including AI, robotics, energy, the Internet of Things and more.\n",
"\n",
"There is also a section about open source software and the project that is open-source. This is the teams final article.\n",
"\n",
"## Future Work\n",
"\n",
"This project is for all of you who are learning and teaching, in particular, the data-scientists and researchers that are working in the field. The author, Joao, has written a lot of open source software. In this project, he plans to do the following:\n",
"\n",
"- Create a site to collect data on data science and robotics.\n",
"\n",
"- Publ\n",
"----------------------------------------------------------------------"
]
}
],
"source": [
"for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):\n",
" ctx = [model.tokenizer.encode(context)][0]\n",
" src_len = len(ctx)\n",
" print(context, end='')\n",
"\n",
" model.clear()\n",
" if TRIAL == 0: # build the RNN hidden state?\n",
" init_state = types.SimpleNamespace()\n",
" for i in range(src_len if DEBUG_DEBUG else src_len):\n",
" x = ctx[:i+1]\n",
" if i == src_len - 1:\n",
" init_state.out = model.run(x)\n",
" else:\n",
" model.run(x)\n",
" model.save(init_state)\n",
" else:\n",
" model.load(init_state)\n",
"\n",
" if DEBUG_DEBUG:\n",
" out = init_state.out\n",
" print('\\n', np.array(x), '==>', np.array(\n",
" out), np.max(out), np.min(out))\n",
"\n",
" for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):\n",
" x = ctx[:i+1]\n",
" x = x[-model.ctx_len:]\n",
"\n",
" if i == src_len:\n",
" out = copy.deepcopy(init_state.out) # load the RNN hidden state\n",
" else:\n",
" out = model.run(x) # run the RNN\n",
"\n",
" out[0] = -999999999 # disable <|endoftext|>\n",
"\n",
" char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)\n",
" char = char.item()\n",
" print(model.tokenizer.decode(char), end='', flush=True)\n",
"\n",
" ctx += [char]\n",
" print('\\n' + '-' * 70, end='')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c581c4a7-acfc-41b2-ace4-f6375e7e7ea6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,122 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self, data, ctx_len, epoch_length_fixed):
print('building token list...', end=' ')
unique = sorted(list(set(data)))
# print()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d tokens, %d unique.' % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data
def __len__(self):
return self.epoch_length_fixed
def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i:i+self.ctx_len+1]
dix = [self.stoi[s] for s in chunk]
x = torch.tensor(dix[:-1], dtype=torch.long,
device=torch.device('cuda'))
y = torch.tensor(dix[1:], dtype=torch.long,
device=torch.device('cuda'))
return x, y
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(torch.tensor(out), dim=-1)
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
sorted_probs, s_index = torch.sort(probs, descending=True)
# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
return torch.multinomial(probs, num_samples=1)[0]
def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,297 @@
{
"cells": [
{
"cell_type": "raw",
"id": "bcd88fb5-6a0f-4c4b-81fd-34be59ea7903",
"metadata": {},
"source": [
"模型下载链接https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b78b7ef-acc6-46cf-88c2-f90a2835e4b3",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"from torch.nn import functional as F\n",
"from tokenizers import Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "deacc22b-2896-4b77-b595-3284b0c13544",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer.from_file(\"20B_tokenizer.json\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'\n",
"args.n_layer = 24\n",
"args.n_embd = 1024\n",
"\n",
"context = \"\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.85\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0f1b2e2b-9f0d-4db3-b9d9-d43e3e2537ee",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(torch.jit.ScriptModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
" for k in w.keys():\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}\n",
" else: w[k] = w[k].float() # convert to f32 type\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @torch.jit.script_method\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
" state[5*i+0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @torch.jit.script_method\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
" xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
" state[5*i+1] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = kw @ xk\n",
" v = vw @ xv\n",
" \n",
" aa = state[5*i+2]\n",
" bb = state[5*i+3]\n",
" pp = state[5*i+4]\n",
" ww = time_first + k\n",
" qq = torch.maximum(pp, ww)\n",
" e1 = torch.exp(pp - qq)\n",
" e2 = torch.exp(ww - qq)\n",
" a = e1 * aa + e2 * v\n",
" b = e1 * bb + e2\n",
" wkv = a / b\n",
" ww = pp + time_decay\n",
" qq = torch.maximum(ww, k)\n",
" e1 = torch.exp(ww - qq)\n",
" e2 = torch.exp(k - qq)\n",
" state[5*i+2] = e1 * aa + e2 * v\n",
" state[5*i+3] = e1 * bb + e2\n",
" state[5*i+4] = qq\n",
" return ow @ (r * wkv)\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)\n",
" for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
" att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, \n",
" att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_mix_k, ffn.time_mix_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state\n",
"\n",
"##########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fdf027b6-7df9-4c0f-818e-013e7c49e3cd",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "298dbbde-6535-406b-bd43-f2d886799f8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-4-Pile-430M-20220808-8066 ...\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7d366a89-02cb-4b5e-95ef-52f6376d3607",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None\n",
"for token in tokenizer.encode(context).ids:\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5273e7a8-875e-4998-b98e-f81951a7af32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Founded in 2015 by AI graduate student Yawei Li, DataWhalechina aims to help people learn to think more naturally about data. Learn more.\n",
"\n",
"50% of U.S. high school graduates who take data science courses go on to pursue masters degrees, which cost about $7,000, according to The American Council for an Energy Efficient Economy. Learn more.\n",
"\n",
"More than 600 startups compete for the same creative talent awards in 2016. If there is one award\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Datawhalechina was established in 2013. The company was created to serve the needs of companies that seek to increase the utilization of machine learning technology in their environments. This aims to create a platform that will help increase the adoption of machine learning technology in organizations by creating better decision support tools.\n",
"\n",
"As of 2017, Datawhalechina's team of specialists are spread over the United States, Europe, Asia, Africa and Canada. Their strategy includes providing low-cost software solutions to the\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\n",
"The main objective of the organization is to provide diverse students with the information, skills, knowledge and ideas needed to tackle big challenges in their future. The success of the program has prompted the city government to give them more resources to bring more students in the program.\n",
"\n",
"Ethereum\n",
"\n",
"The Ethereum (ETH) blockchain, designed by XRP, is a decentralised ledger technology that enables Bitcoin (BTC) and other cryptocurrencies to be used as payment. It aims to be the largest\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d546bfd9-cf80-49bf-8c76-f3918d7d67e4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,546 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1fb76974-93ea-4b9c-81b1-55f826e7a361",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b1059eca-db4f-4c0b-ae3e-37af49ec7fa1",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1c8d8009-7ee7-4419-aacb-cdc45f287010",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "63a4e8ba-a291-4fdc-aef1-ebfca21840d4",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "raw",
"id": "cb8c7d5e-08cb-4780-b6d9-ab8bad1417d4",
"metadata": {},
"source": [
"可以从这个链接下载模型:\n",
"https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files\n",
"https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "94d7d6db-e89e-4209-ae72-6625ba85ef5b",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096' #这里不用有后缀.pth\n",
"args.n_layer = 24\n",
"args.n_embd = 1024\n",
"args.vocab_size = 65536"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "c8dcf39a-7838-454b-85fc-ec9bd75fa243",
"metadata": {},
"outputs": [],
"source": [
"# N_LAYER=\"12\"\n",
"# N_EMBD=\"768\"\n",
"N_LAYER=\"24\"\n",
"N_EMBD=\"1024\""
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "74d7c96a-6fbc-401c-8078-fefb1a6ec5c3",
"metadata": {},
"outputs": [],
"source": [
"# context = \"\\nElon Musk has\"\n",
"# context = \"\\n我们发现\"\n",
"context = \"Q:Do you know datawhalechina?\\nA:\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"LENGTH_PER_TRIAL = 4096\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "bd093a96-fdc5-460d-b39f-fe3735795b42",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_decay'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
" xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
" xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
" xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
" xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
" state[i1] = x\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + time_decay * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" # print(i)\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
" att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay, \n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_mix_k, ffn.time_mix_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "a330cd34-7ed0-4a6c-92a3-19797d34ee77",
"metadata": {},
"outputs": [],
"source": [
"context = \"Q:Do you know datawhalechina?\\nA:\""
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "ad824161-413d-460c-9ffe-9dbfb739f86b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"args.MODEL_NAME"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "f0e2f841-4cda-48d4-b055-7adf00f2fe73",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(24, 1024)"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"args.n_layer,args.n_embd"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "aba8a4d4-9a77-4191-a7ef-d5e6100ca3c1",
"metadata": {},
"outputs": [],
"source": [
"# args.n_layer = 24\n",
"# args.n_embd = 1024"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "dd44f7bc-e8d6-4242-beb5-89a866990751",
"metadata": {},
"outputs": [],
"source": [
"# args.n_layer = 12\n",
"# args.n_embd = 768"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "2a96d9dc-8b5e-40cc-bb36-24c9bdeac29e",
"metadata": {},
"outputs": [],
"source": [
"# args.MODEL_NAME='../models/rwkv-5-world-1b5'"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "b7d07606-31b4-4c21-9f89-554d89c2c866",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "ce42cfad-0274-4d5d-950d-fb89ff11ed2c",
"metadata": {},
"outputs": [],
"source": [
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "3e02a81c-1447-4936-a241-4d00ecf8e862",
"metadata": {},
"outputs": [],
"source": [
"LENGTH_PER_TRIAL=1024"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "4a00ea05-d6fd-4052-b13a-8107fb268420",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- Q:Do you know datawhalechina?\n",
"A:If you have your money, go to the blue whale website. It has information on whales and its feeding habits. The info is up-to-date and well-researched.\n",
"https://www.who.int/news-room/fact-sheets/detail/blue-whale\n",
"It's a nice fish. It's great for eating.\n",
"\n",
"--[ Trial 1 ]----------------- Q:Do you know datawhalechina?\n",
"A:A very old and very great one.\n",
"http://www.cnn.com/2008/WORLD/asia/10/17/china.beach.tourist.casino.cnn/index.html\n",
"\n",
"--[ Trial 2 ]----------------- Q:Do you know datawhalechina?\n",
"A:The datawhale china is a chinese based data analytics and decision making company, they work with many big companies in China. They work with large numbers of large companies in china. They can provide companies with data, analytics and knowledge about companies.\n",
"Q:How do you find datawhale china?\n",
"A:We use a variety of sources to find companies in china. We search for companies based on a variety of criteria. We look for companies with a specific industry or product. We also use a variety of data sources to find companies in china.\n",
"Q:What kind of data do you use to find companies in china?\n",
"A:We use a variety of data sources to find companies in china. We look for companies with a specific industry or product. We also use a variety of sources to find companies in china.\n",
"Q:How do you know if a company is in china?\n",
"A:We use a variety of sources to find companies in china. We look for companies based on a variety of criteria. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the advantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What is the purpose of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the advantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the advantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the advantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"Q:What are some of the disadvantages of using datawhale china?\n",
"A:We use datawhale china to find companies in china. We use a variety of sources to find companies in china. We also use a variety of sources to find companies in china.\n",
"\n"
]
}
],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)\n",
"\n",
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "bfb6dc7a-786f-427b-9aa9-b51f2771b0cc",
"metadata": {},
"source": [
"显然datawhale这个数据没有训练过哈哈。不过速度是蛮快的这个没得说在cpu上跑资源消耗也很小。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "f64be1c0-02a8-4ea9-ae05-85b66e803cac",
"metadata": {},
"outputs": [],
"source": [
"########################################################################################################\n",
"# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
"########################################################################################################\n",
"\n",
"import numpy as np\n",
"np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
"import types, torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"MyModule = torch.jit.ScriptModule\n",
"MyFunction = torch.jit.script_method"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1261d4e1-df4e-410b-a4fa-45452c4b6fb1",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_TOKENIZER():\n",
" table: list[list[list[bytes]]]\n",
" good: list[set[int]]\n",
" wlen: list[int]\n",
" def __init__(self, file_name):\n",
" self.idx2token = {}\n",
" sorted = [] # must be already sorted\n",
" lines = open(file_name, \"r\", encoding=\"utf-8\").readlines()\n",
" for l in lines:\n",
" idx = int(l[:l.index(' ')])\n",
" x = eval(l[l.index(' '):l.rindex(' ')])\n",
" x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
" assert isinstance(x, bytes)\n",
" assert len(x) == int(l[l.rindex(' '):])\n",
" sorted += [x]\n",
" self.idx2token[idx] = x\n",
"\n",
" self.token2idx = {}\n",
" for k, v in self.idx2token.items():\n",
" self.token2idx[v] = int(k)\n",
"\n",
" # precompute some tables for fast matching\n",
" self.table = [[[] for j in range(256)] for i in range(256)]\n",
" self.good = [set() for i in range(256)]\n",
" self.wlen = [0 for i in range(256)]\n",
"\n",
" for i in reversed(range(len(sorted))): # reverse order - match longer tokens first\n",
" s = sorted[i]\n",
" if len(s) >= 2:\n",
" s0 = int(s[0])\n",
" s1 = int(s[1])\n",
" self.table[s0][s1] += [s]\n",
" self.wlen[s0] = max(self.wlen[s0], len(s))\n",
" self.good[s0].add(s1)\n",
"\n",
" def encodeBytes(self, src: bytes) -> list[int]:\n",
" src_len: int = len(src)\n",
" tokens: list[int] = []\n",
" i: int = 0\n",
" while i < src_len:\n",
" s: bytes = src[i : i + 1]\n",
"\n",
" if i < src_len - 1:\n",
" s1: int = int(src[i + 1])\n",
" s0: int = int(src[i])\n",
" if s1 in self.good[s0]:\n",
" sss: bytes = src[i : i + self.wlen[s0]]\n",
" try:\n",
" s = next(filter(sss.startswith, self.table[s0][s1]))\n",
" except:\n",
" pass\n",
" tokens.append(self.token2idx[s])\n",
" i += len(s)\n",
"\n",
" return tokens\n",
"\n",
" def decodeBytes(self, tokens):\n",
" return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
"\n",
" def encode(self, src: str):\n",
" return self.encodeBytes(src.encode(\"utf-8\"))\n",
"\n",
" def decode(self, tokens):\n",
" return self.decodeBytes(tokens).decode('utf-8')\n",
"\n",
" def printTokens(self, tokens):\n",
" for i in tokens:\n",
" s = self.idx2token[i]\n",
" try:\n",
" s = s.decode('utf-8')\n",
" except:\n",
" pass\n",
" print(f'{repr(s)}{i}', end=' ')\n",
" # print(repr(s), i)\n",
" print()\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "725bc55e-7f3f-4c1c-9664-ad84bf68e943",
"metadata": {},
"outputs": [],
"source": [
"def sample_logits(out, temperature=1.0, top_p=0.8):\n",
" probs = F.softmax(out, dim=-1).numpy()\n",
" sorted_probs = np.sort(probs)[::-1]\n",
" cumulative_probs = np.cumsum(sorted_probs)\n",
" cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
" probs[probs < cutoff] = 0\n",
" if temperature != 1.0:\n",
" probs = probs.pow(1.0 / temperature)\n",
" probs = probs / np.sum(probs)\n",
" out = np.random.choice(a=len(probs), p=probs)\n",
" return out\n",
"\n",
"########################################################################################################"
]
},
{
"cell_type": "raw",
"id": "812fac97-a6b8-423c-831d-fe7397883437",
"metadata": {},
"source": [
"模型下载地址https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "434ff88e-b94e-4f8b-86a3-7fefca32cddb",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
"\n",
"args = types.SimpleNamespace()\n",
"args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'\n",
"args.n_layer = 24\n",
"args.n_embd = 2048\n",
"args.vocab_size = 65536\n",
"\n",
"context = \"\\nDatawhale is \"\n",
"# context = \"\\n我们发现\"\n",
"NUM_TRIALS = 3\n",
"LENGTH_PER_TRIAL = 100\n",
"TEMPERATURE = 1.0\n",
"TOP_P = 0.7"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3869854a-a4e3-4652-9698-b0d81bbbd645",
"metadata": {},
"outputs": [],
"source": [
"class RWKV_RNN(MyModule):\n",
" def __init__(self, args):\n",
" super().__init__()\n",
" self.args = args\n",
" self.eval() # set torch to inference mode\n",
" \n",
" w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
"\n",
" for k in w.keys():\n",
" w[k] = w[k].float() # convert to f32 type\n",
" if '.time_' in k: w[k] = w[k].squeeze()\n",
" if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
"\n",
" self.n_head = w['blocks.0.att.time_faaaa'].shape[0]\n",
" self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
" \n",
" self.w = types.SimpleNamespace() # set self.w from w\n",
" self.w.blocks = {}\n",
" for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
" parts = k.split('.')\n",
" last = parts.pop()\n",
" here = self.w\n",
" for p in parts:\n",
" if p.isdigit():\n",
" p = int(p)\n",
" if p not in here: here[p] = types.SimpleNamespace()\n",
" here = here[p]\n",
" else:\n",
" if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
" here = getattr(here, p)\n",
" setattr(here, last, w[k])\n",
"\n",
" def layer_norm(self, x, w):\n",
" return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
"\n",
" @MyFunction\n",
" def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):\n",
" i0 = (2+self.head_size)*i+0\n",
" sx = state[i0] - x\n",
" xk = x + sx * time_maa_k\n",
" xr = x + sx * time_maa_r\n",
" state[i0] = x\n",
" r = torch.sigmoid(rw @ xr)\n",
" k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
" return r * (vw @ k)\n",
"\n",
" @MyFunction\n",
" def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
" H = self.n_head\n",
" S = self.head_size\n",
"\n",
" i1 = (2+S)*i+1\n",
" sx = state[i1] - x\n",
" state[i1] = x\n",
" xxx = x + sx * x_maa\n",
" xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)\n",
" xxx = torch.bmm(xxx, tm_w2).view(5, -1)\n",
" mw, mk, mv, mr, mg = xxx.unbind(dim=0)\n",
"\n",
" xw = x + sx * (w_maa + mw)\n",
" xk = x + sx * (k_maa + mk)\n",
" xv = x + sx * (v_maa + mv)\n",
" xr = x + sx * (r_maa + mr)\n",
" xg = x + sx * (g_maa + mg)\n",
"\n",
" w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)\n",
" w = torch.exp(-torch.exp(w.float()))\n",
"\n",
" r = (rw @ xr).view(H, 1, S)\n",
" k = (kw @ xk).view(H, S, 1)\n",
" v = (vw @ xv).view(H, 1, S)\n",
" g = F.silu(gw @ xg)\n",
"\n",
" s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
"\n",
" x = torch.zeros(H, S)\n",
" a = k @ v\n",
" x = r @ (time_first * a + s)\n",
" s = a + w * s\n",
" \n",
" state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
" x = x.flatten()\n",
"\n",
" x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
" return ow @ x\n",
"\n",
" def forward(self, token, state):\n",
" with torch.no_grad():\n",
" if state == None:\n",
" state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
" \n",
" x = self.w.emb.weight[token]\n",
" x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
" for i in range(self.args.n_layer):\n",
" att = self.w.blocks[i].att\n",
" x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
" att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,\n",
" att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,\n",
" att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
" att.ln_x.weight, att.ln_x.bias)\n",
" ffn = self.w.blocks[i].ffn\n",
" x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
" ffn.time_maa_k, ffn.time_maa_r, \n",
" ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
" \n",
" x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
" return x.float(), state"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5235be83-a574-41f6-8546-bc415e2aeacf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...\n",
"\n",
"Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
]
}
],
"source": [
"print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
"model = RWKV_RNN(args)\n",
"\n",
"print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
"init_state = None"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "035ce374-c1c0-43b8-9ba9-37df297baae6",
"metadata": {},
"outputs": [],
"source": [
"for token in tokenizer.encode(context):\n",
" init_out, init_state = model.forward(token, init_state)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1e17d6ef-c02c-4d27-8cf6-9a262f75f77f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"--[ Trial 0 ]----------------- \n",
"Datawhale is 🤑💰. #datawhale #technology #programming #python #data #java #SQL #object #pythonprogramming #django #open source #software #learning #future #machinelearning #dataanalysis #datavisualization #AI #ML #DL #CS #business #career #learning #learningcurve #technology #productivity #development #science #sciencecommunication #sciencecommunicationtools #science #Science #tech #engineering #engineeringtools #engineeringtechniques #engineering\n",
"\n",
"--[ Trial 1 ]----------------- \n",
"Datawhale is 👏 also 👏 an 👏 online 👏 trading 👏 platform 👏 and 👏 as 👏 mentioned 👏 above 👏, 👏 the 👏 reason 👏 why 👏 the 👏 team 👏 doesn't 👏 provide 👏 the 👏 open 👏 source\n",
"\n",
"--[ Trial 2 ]----------------- \n",
"Datawhale is 👌. I have the following questions about its operations and use:\n",
"1. Why does Datawhale charge for its services? What is its value proposition?\n",
"2. What is the main use case of Datawhale?\n",
"3. How does Datawhale use customer data?\n",
"4. What are Datawhales data collection practices?\n",
"5. What are Datawhales data storage practices?\n",
"6. What is Datawhales data retention policy?\n",
"\n",
"\n"
]
}
],
"source": [
"for TRIAL in range(NUM_TRIALS):\n",
" print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
" all_tokens = []\n",
" out_last = 0\n",
" out, state = init_out.clone(), init_state.clone()\n",
" for i in range(LENGTH_PER_TRIAL):\n",
" token = sample_logits(out, TEMPERATURE, TOP_P)\n",
" all_tokens += [token]\n",
" try:\n",
" tmp = tokenizer.decode(all_tokens[out_last:])\n",
" if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
" print(tmp, end=\"\", flush=True)\n",
" out_last = i + 1\n",
" except:\n",
" pass\n",
" out, state = model.forward(token, state) \n",
"print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "172c33e0-6d5b-4143-b85d-86777a2f5739",
"metadata": {},
"source": [
"v6和v5相比感觉更喜欢使用emoj了哈哈"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "kewei-ai",
"language": "python",
"name": "kewei-ai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff