llms-from-scratch-cn/Model_Architecture_Discussions/rwkv-compare/readme.md
2024-05-31 16:44:23 +08:00

273 lines
11 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

### RWKV 模型版本比较报告
本文档旨在比较 RWKV 模型的六个不同版本v1 至 v6并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。
---
#### 版本概述
**RWKV v1**
- 初始版本,基础实现 RWKV 时间混合和通道混合模块。
- 主要特性:
- 使用时间混合Time-mix和通道混合Channel-mix模块。
- 采用标准的线性层和嵌入层初始化。
- 使用掩码来处理因果关系。
**RWKV v2**
- 增强版本,改进了时间混合和通道混合的实现。
- 主要改进:
- 优化了模型加载和状态管理。
- 增加了新的归一化方法。
- 提升了训练和推理效率。
**RWKV v3**
- 进一步优化的版本,主要集中在性能提升。
- 主要改进:
- 调整了层数和嵌入维度,提供更灵活的配置选项。
- 增加了预处理步骤,提高了推理效率。
**RWKV v4**
- 增加了对更大规模模型的支持,提升了模型复杂度。
- 主要改进:
- 支持24层和1024维嵌入。
- 增加了更多的参数调优选项。
**RWKV v5**
- 继续提升模型规模和复杂度,并优化了模型架构。
- 主要改进:
- 支持更高的嵌入维度2048
- 引入了新的时间混合和通道混合方法,提升了模型性能。
**RWKV v6**
- 最新版本,综合了前几个版本的改进,并引入了一些新特性。
- 主要改进:
- 增加了对更大词汇表65536的支持。
- 采用了改进的混合方法,提升了推理速度和准确性。
---
#### 详细比较
**1. 架构与实现**
- **时间混合Time-Mix和通道混合Channel-Mix**
- **v1**:基本实现,功能完备。
```python
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad(): # initial time_w curves for better convergence
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd)
```
- **v2**优化了时间混合和通道混合提升了计算效率
```python
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
k = self.key(x)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(x)) * kv
return rkv
```
- **v3**进一步优化并增加了灵活的配置选项
```python
class RWKV_ChannelMix(nn.Module):
def __init__(self, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
hidden_sz = 4 * n_embd
self.key = nn.Linear(n_embd, hidden_sz, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(hidden_sz, n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.square(torch.relu(k))
kv = self.value(k)
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv
```
- **v4**支持更大规模模型提升了时间混合和通道混合的处理能力
```python
class RWKV_RNN(torch.jit.ScriptModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
else: w[k] = w[k].float() # convert to f32 type
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
- **v5**引入了新的混合方法进一步提升了性能
```python
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_decay'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
- **v6**改进了混合方法提升了整体性能和效率
```python
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
w[k] = w[k].float() # convert to f32 type
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys():
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
```
**2. 模型规模**
- **层数和嵌入维度**
- **v1**标准配置适用于基础任务
- **v2**支持12层和768维嵌入
- **v3**提供12层和24层选项嵌入维度为768和1024
- **v4**支持24层和1024维嵌入
- **v5**嵌入维度增加至2048
- **v6**进一步增加模型复杂度支持更大词汇表
**3. 性能与效率**
- **推理速度和资源消耗**
- **v1**基础实现资源消耗适中
- **v2**优化后推理速度提升
- **v3**预处理步骤的增加提高了推理效率
- **v4**更大规模模型下的性能优化
- **v5**新的混合方法提升了推理速度和准确性
- **v6**综合改进推理速度和资源利用进一步优化
**4. 词汇表和上下文长度**
- **词汇表大小和上下文长度支持**
- **v1-v4**词汇表大小和上下文长度逐步增加
- **v5**支持更大上下文长度适应复杂任务
- **v6**支持最大65536的词汇表和更长的上下文长度
---
### 总结
RWKV 模型在每个版本中不断优化和提升从基础的 v1 到复杂且高效的 v6模型的性能和功能都有了显著的进步以下是每个版本的推荐使用场景
- **v1**适用于基础任务和初步研究
- **v2**适用于需要更高效率和优化的任务
- **v3**适用于需要灵活配置和更高性能的应用
- **v4**适用于大规模模型的训练和推理任务
- **v5**适用于需要高精度和高效推理的复杂任务
- **v6**适用于最前沿的研究和应用提供最高的性能和效率
每个版本在其特定的改进点上都为用户提供了更好的选择根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势