mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-05-01 11:58:17 +08:00
273 lines
11 KiB
Markdown
273 lines
11 KiB
Markdown
### RWKV 模型版本比较报告
|
||
|
||
本文档旨在比较 RWKV 模型的六个不同版本(v1 至 v6),并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。
|
||
|
||
---
|
||
|
||
#### 版本概述
|
||
|
||
**RWKV v1**
|
||
- 初始版本,基础实现 RWKV 时间混合和通道混合模块。
|
||
- 主要特性:
|
||
- 使用时间混合(Time-mix)和通道混合(Channel-mix)模块。
|
||
- 采用标准的线性层和嵌入层初始化。
|
||
- 使用掩码来处理因果关系。
|
||
|
||
**RWKV v2**
|
||
- 增强版本,改进了时间混合和通道混合的实现。
|
||
- 主要改进:
|
||
- 优化了模型加载和状态管理。
|
||
- 增加了新的归一化方法。
|
||
- 提升了训练和推理效率。
|
||
|
||
**RWKV v3**
|
||
- 进一步优化的版本,主要集中在性能提升。
|
||
- 主要改进:
|
||
- 调整了层数和嵌入维度,提供更灵活的配置选项。
|
||
- 增加了预处理步骤,提高了推理效率。
|
||
|
||
**RWKV v4**
|
||
- 增加了对更大规模模型的支持,提升了模型复杂度。
|
||
- 主要改进:
|
||
- 支持24层和1024维嵌入。
|
||
- 增加了更多的参数调优选项。
|
||
|
||
**RWKV v5**
|
||
- 继续提升模型规模和复杂度,并优化了模型架构。
|
||
- 主要改进:
|
||
- 支持更高的嵌入维度(2048)。
|
||
- 引入了新的时间混合和通道混合方法,提升了模型性能。
|
||
|
||
**RWKV v6**
|
||
- 最新版本,综合了前几个版本的改进,并引入了一些新特性。
|
||
- 主要改进:
|
||
- 增加了对更大词汇表(65536)的支持。
|
||
- 采用了改进的混合方法,提升了推理速度和准确性。
|
||
|
||
---
|
||
|
||
#### 详细比较
|
||
|
||
**1. 架构与实现**
|
||
|
||
- **时间混合(Time-Mix)和通道混合(Channel-Mix)**:
|
||
- **v1**:基本实现,功能完备。
|
||
```python
|
||
class RWKV_TimeMix(nn.Module):
|
||
def __init__(self, config, layer_id):
|
||
super().__init__()
|
||
assert config.n_attn % config.n_head == 0
|
||
self.layer_id = layer_id
|
||
self.ctx_len = config.ctx_len
|
||
self.n_head = config.n_head
|
||
self.head_size = config.n_attn // config.n_head
|
||
|
||
with torch.no_grad(): # initial time_w curves for better convergence
|
||
ww = torch.ones(config.n_head, config.ctx_len)
|
||
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
|
||
for h in range(config.n_head):
|
||
if h < config.n_head - 1:
|
||
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
|
||
else:
|
||
decay_speed = 0.0
|
||
ww[h] = torch.exp(curve * decay_speed)
|
||
self.time_w = nn.Parameter(ww)
|
||
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
|
||
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
|
||
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
|
||
|
||
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
|
||
|
||
self.key = nn.Linear(config.n_embd, config.n_attn)
|
||
self.value = nn.Linear(config.n_embd, config.n_attn)
|
||
self.receptance = nn.Linear(config.n_embd, config.n_attn)
|
||
|
||
self.output = nn.Linear(config.n_attn, config.n_embd)
|
||
```
|
||
- **v2**:优化了时间混合和通道混合,提升了计算效率。
|
||
```python
|
||
class RWKV_ChannelMix(nn.Module):
|
||
def __init__(self, layer_id):
|
||
super().__init__()
|
||
self.layer_id = layer_id
|
||
|
||
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
|
||
self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
|
||
|
||
hidden_sz = 4 * n_embd
|
||
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
|
||
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
|
||
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
|
||
|
||
def forward(self, x):
|
||
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
|
||
k = self.key(x)
|
||
k = torch.square(torch.relu(k))
|
||
kv = self.value(k)
|
||
rkv = torch.sigmoid(self.receptance(x)) * kv
|
||
return rkv
|
||
```
|
||
- **v3**:进一步优化,并增加了灵活的配置选项。
|
||
```python
|
||
class RWKV_ChannelMix(nn.Module):
|
||
def __init__(self, layer_id):
|
||
super().__init__()
|
||
self.layer_id = layer_id
|
||
|
||
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
|
||
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
|
||
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
|
||
|
||
hidden_sz = 4 * n_embd
|
||
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
|
||
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
|
||
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
|
||
|
||
def forward(self, x):
|
||
xx = self.time_shift(x)
|
||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||
k = self.key(xk)
|
||
k = torch.square(torch.relu(k))
|
||
kv = self.value(k)
|
||
rkv = torch.sigmoid(self.receptance(xr)) * kv
|
||
return rkv
|
||
```
|
||
- **v4**:支持更大规模模型,提升了时间混合和通道混合的处理能力。
|
||
```python
|
||
class RWKV_RNN(torch.jit.ScriptModule):
|
||
def __init__(self, args):
|
||
super().__init__()
|
||
self.args = args
|
||
self.eval() # set torch to inference mode
|
||
|
||
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
|
||
for k in w.keys():
|
||
if '.time_' in k: w[k] = w[k].squeeze()
|
||
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
|
||
else: w[k] = w[k].float() # convert to f32 type
|
||
self.w = types.SimpleNamespace() # set self.w from w
|
||
self.w.blocks = {}
|
||
for k in w.keys():
|
||
parts = k.split('.')
|
||
last = parts.pop()
|
||
here = self.w
|
||
for p in parts:
|
||
if p.isdigit():
|
||
p = int(p)
|
||
if p not in here: here[p] = types.SimpleNamespace()
|
||
here = here[p]
|
||
else:
|
||
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
|
||
here = getattr(here, p)
|
||
setattr(here, last, w[k])
|
||
```
|
||
- **v5**:引入了新的混合方法,进一步提升了性能。
|
||
```python
|
||
class RWKV_RNN(MyModule):
|
||
def __init__(self, args):
|
||
super().__init__()
|
||
self.args = args
|
||
self.eval() # set torch to inference mode
|
||
|
||
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
|
||
for k in w.keys():
|
||
w[k] = w[k].float() # convert to f32 type
|
||
if '.time_' in k: w[k] = w[k].squeeze()
|
||
if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
|
||
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
|
||
|
||
self.n_head = w['blocks.0.att.time_decay'].shape[0]
|
||
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
|
||
|
||
self.w = types.SimpleNamespace() # set self.w from w
|
||
self.w.blocks = {}
|
||
for k in w.keys():
|
||
parts = k.split('.')
|
||
last = parts.pop()
|
||
here = self.w
|
||
for p in parts:
|
||
if p.isdigit():
|
||
p = int(p)
|
||
if p not in here: here[p] = types.SimpleNamespace()
|
||
here = here[p]
|
||
else:
|
||
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
|
||
here = getattr(here, p)
|
||
setattr(here, last, w[k])
|
||
```
|
||
|
||
|
||
- **v6**:改进了混合方法,提升了整体性能和效率。
|
||
```python
|
||
class RWKV_RNN(MyModule):
|
||
def __init__(self, args):
|
||
super().__init__()
|
||
self.args = args
|
||
self.eval() # set torch to inference mode
|
||
|
||
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
|
||
for k in w.keys():
|
||
w[k] = w[k].float() # convert to f32 type
|
||
if '.time_' in k: w[k] = w[k].squeeze()
|
||
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
|
||
self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
|
||
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
|
||
self.w = types.SimpleNamespace() # set self.w from w
|
||
self.w.blocks = {}
|
||
for k in w.keys():
|
||
parts = k.split('.')
|
||
last = parts.pop()
|
||
here = self.w
|
||
for p in parts:
|
||
if p.isdigit():
|
||
p = int(p)
|
||
if p not in here: here[p] = types.SimpleNamespace()
|
||
here = here[p]
|
||
else:
|
||
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
|
||
here = getattr(here, p)
|
||
setattr(here, last, w[k])
|
||
```
|
||
|
||
**2. 模型规模**
|
||
|
||
- **层数和嵌入维度**:
|
||
- **v1**:标准配置,适用于基础任务。
|
||
- **v2**:支持12层和768维嵌入。
|
||
- **v3**:提供12层和24层选项,嵌入维度为768和1024。
|
||
- **v4**:支持24层和1024维嵌入。
|
||
- **v5**:嵌入维度增加至2048。
|
||
- **v6**:进一步增加模型复杂度,支持更大词汇表。
|
||
|
||
**3. 性能与效率**
|
||
|
||
- **推理速度和资源消耗**:
|
||
- **v1**:基础实现,资源消耗适中。
|
||
- **v2**:优化后,推理速度提升。
|
||
- **v3**:预处理步骤的增加,提高了推理效率。
|
||
- **v4**:更大规模模型下的性能优化。
|
||
- **v5**:新的混合方法提升了推理速度和准确性。
|
||
- **v6**:综合改进,推理速度和资源利用进一步优化。
|
||
|
||
**4. 词汇表和上下文长度**
|
||
|
||
- **词汇表大小和上下文长度支持**:
|
||
- **v1-v4**:词汇表大小和上下文长度逐步增加。
|
||
- **v5**:支持更大上下文长度,适应复杂任务。
|
||
- **v6**:支持最大65536的词汇表和更长的上下文长度。
|
||
|
||
---
|
||
|
||
### 总结
|
||
|
||
RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6,模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景:
|
||
|
||
- **v1**:适用于基础任务和初步研究。
|
||
- **v2**:适用于需要更高效率和优化的任务。
|
||
- **v3**:适用于需要灵活配置和更高性能的应用。
|
||
- **v4**:适用于大规模模型的训练和推理任务。
|
||
- **v5**:适用于需要高精度和高效推理的复杂任务。
|
||
- **v6**:适用于最前沿的研究和应用,提供最高的性能和效率。
|
||
|
||
每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。 |