### RWKV 模型版本比较报告 本文档旨在比较 RWKV 模型的六个不同版本(v1 至 v6),并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。 --- #### 版本概述 **RWKV v1** - 初始版本,基础实现 RWKV 时间混合和通道混合模块。 - 主要特性: - 使用时间混合(Time-mix)和通道混合(Channel-mix)模块。 - 采用标准的线性层和嵌入层初始化。 - 使用掩码来处理因果关系。 **RWKV v2** - 增强版本,改进了时间混合和通道混合的实现。 - 主要改进: - 优化了模型加载和状态管理。 - 增加了新的归一化方法。 - 提升了训练和推理效率。 **RWKV v3** - 进一步优化的版本,主要集中在性能提升。 - 主要改进: - 调整了层数和嵌入维度,提供更灵活的配置选项。 - 增加了预处理步骤,提高了推理效率。 **RWKV v4** - 增加了对更大规模模型的支持,提升了模型复杂度。 - 主要改进: - 支持24层和1024维嵌入。 - 增加了更多的参数调优选项。 **RWKV v5** - 继续提升模型规模和复杂度,并优化了模型架构。 - 主要改进: - 支持更高的嵌入维度(2048)。 - 引入了新的时间混合和通道混合方法,提升了模型性能。 **RWKV v6** - 最新版本,综合了前几个版本的改进,并引入了一些新特性。 - 主要改进: - 增加了对更大词汇表(65536)的支持。 - 采用了改进的混合方法,提升了推理速度和准确性。 --- #### 详细比较 **1. 架构与实现** - **时间混合(Time-Mix)和通道混合(Channel-Mix)**: - **v1**:基本实现,功能完备。 ```python class RWKV_TimeMix(nn.Module): def __init__(self, config, layer_id): super().__init__() assert config.n_attn % config.n_head == 0 self.layer_id = layer_id self.ctx_len = config.ctx_len self.n_head = config.n_head self.head_size = config.n_attn // config.n_head with torch.no_grad(): # initial time_w curves for better convergence ww = torch.ones(config.n_head, config.ctx_len) curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance for h in range(config.n_head): if h < config.n_head - 1: decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1)) else: decay_speed = 0.0 ww[h] = torch.exp(curve * decay_speed) self.time_w = nn.Parameter(ww) self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len)) self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1)) self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) self.receptance = nn.Linear(config.n_embd, config.n_attn) self.output = nn.Linear(config.n_attn, config.n_embd) ``` - **v2**:优化了时间混合和通道混合,提升了计算效率。 ```python class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd)) hidden_sz = 4 * n_embd self.key = nn.Linear(n_embd, hidden_sz, bias=False) self.receptance = nn.Linear(n_embd, n_embd, bias=False) self.value = nn.Linear(hidden_sz, n_embd, bias=False) def forward(self, x): x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix) k = self.key(x) k = torch.square(torch.relu(k)) kv = self.value(k) rkv = torch.sigmoid(self.receptance(x)) * kv return rkv ``` - **v3**:进一步优化,并增加了灵活的配置选项。 ```python class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd)) self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd)) hidden_sz = 4 * n_embd self.key = nn.Linear(n_embd, hidden_sz, bias=False) self.receptance = nn.Linear(n_embd, n_embd, bias=False) self.value = nn.Linear(hidden_sz, n_embd, bias=False) def forward(self, x): xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) k = self.key(xk) k = torch.square(torch.relu(k)) kv = self.value(k) rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv ``` - **v4**:支持更大规模模型,提升了时间混合和通道混合的处理能力。 ```python class RWKV_RNN(torch.jit.ScriptModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): if '.time_' in k: w[k] = w[k].squeeze() if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x} else: w[k] = w[k].float() # convert to f32 type self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k]) ``` - **v5**:引入了新的混合方法,进一步提升了性能。 ```python class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): w[k] = w[k].float() # convert to f32 type if '.time_' in k: w[k] = w[k].squeeze() if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1) if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) self.n_head = w['blocks.0.att.time_decay'].shape[0] self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k]) ``` - **v6**:改进了混合方法,提升了整体性能和效率。 ```python class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() self.args = args self.eval() # set torch to inference mode w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') for k in w.keys(): w[k] = w[k].float() # convert to f32 type if '.time_' in k: w[k] = w[k].squeeze() if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) self.n_head = w['blocks.0.att.time_faaaa'].shape[0] self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {} for k in w.keys(): parts = k.split('.') last = parts.pop() here = self.w for p in parts: if p.isdigit(): p = int(p) if p not in here: here[p] = types.SimpleNamespace() here = here[p] else: if not hasattr(here, p): setattr(here, p, types.SimpleNamespace()) here = getattr(here, p) setattr(here, last, w[k]) ``` **2. 模型规模** - **层数和嵌入维度**: - **v1**:标准配置,适用于基础任务。 - **v2**:支持12层和768维嵌入。 - **v3**:提供12层和24层选项,嵌入维度为768和1024。 - **v4**:支持24层和1024维嵌入。 - **v5**:嵌入维度增加至2048。 - **v6**:进一步增加模型复杂度,支持更大词汇表。 **3. 性能与效率** - **推理速度和资源消耗**: - **v1**:基础实现,资源消耗适中。 - **v2**:优化后,推理速度提升。 - **v3**:预处理步骤的增加,提高了推理效率。 - **v4**:更大规模模型下的性能优化。 - **v5**:新的混合方法提升了推理速度和准确性。 - **v6**:综合改进,推理速度和资源利用进一步优化。 **4. 词汇表和上下文长度** - **词汇表大小和上下文长度支持**: - **v1-v4**:词汇表大小和上下文长度逐步增加。 - **v5**:支持更大上下文长度,适应复杂任务。 - **v6**:支持最大65536的词汇表和更长的上下文长度。 --- ### 总结 RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6,模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景: - **v1**:适用于基础任务和初步研究。 - **v2**:适用于需要更高效率和优化的任务。 - **v3**:适用于需要灵活配置和更高性能的应用。 - **v4**:适用于大规模模型的训练和推理任务。 - **v5**:适用于需要高精度和高效推理的复杂任务。 - **v6**:适用于最前沿的研究和应用,提供最高的性能和效率。 每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。