# llms-from-scratch-cn/Model_Architecture_Discussions/mamba/model.py

"""
Mamba在一个PyTorch文件中的简单、极简实现。
建议在阅读代码之前或期间阅读以下内容:
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu 和 Tri Dao)
https://arxiv.org/abs/2312.00752
[2] The Annotated S4 (Sasha Rush 和 Sidd Karamcheti)
https://srush.github.io/annotated-s4
术语表:
b: 批次大小 (`B` 在 Mamba 论文 [1] 算法2中)
l: 序列长度 (`L` 在 [1] 算法2中)
d 或 d_model: 隐藏维度
n 或 d_state: 潜在状态维度 (`N` 在 [1] 算法2中)
expand: 扩展因子 (`E` 在 [1] 第3.4节中)
d_in 或 d_inner: d * expand (`D` 在 [1] 算法2中)
A, B, C, D: 状态空间参数 (参见任何状态空间表示公式)
(B, C是输入相关的(即选择性的这是Mamba的一个关键创新); A, D不是)
Δ 或 delta: 输入相关的步长
dt_rank: Δ的秩 (参见[1] 第3.6节 "Parameterization of ∆")
"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union  # for the `Union[int, str]` annotation in ModelArgs
from einops import rearrange, repeat, einsum


@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
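

# Worked example (illustrative numbers, not part of the original file): the mamba-130m
# configuration ModelArgs(d_model=768, n_layer=24, vocab_size=50277) derives
# d_inner = 2 * 768 = 1536, dt_rank = ceil(768 / 16) = 48, and pads vocab_size up to
# 50280, the next multiple of 8.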


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See the "Weight Tying" paper.

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (see the glossary at the top for definitions of b, l, d_in, n)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits
    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into the model.

        Args:
            pretrained_model_name: one of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: a Mamba model with the pretrained weights loaded
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))

        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')  # strip the 'backbone.' prefix used by the official checkpoints
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model
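

# Usage sketch (assumes network access to the HuggingFace Hub; any of the model names
# listed in from_pretrained works):
#
#     model = Mamba.from_pretrained('state-spaces/mamba-130m')
#     input_ids = torch.randint(0, 1000, (1, 16))  # (b, l)
#     logits = model(input_ids)                    # (1, 16, vocab_size)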


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping a Mamba block with normalization and a residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (see the glossary at the top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons, as it
            allows them to fuse the Add->Norm.

            We instead implement the blocks in a more familiar, simpler, and numerically
            equivalent way:
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
        """
        output = self.mixer(self.norm(x)) + x

        return output


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
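
    # Note (added comment): the A initialization above gives every one of the d_inner
    # channels A = -[1, 2, ..., d_state] (stored as A_log and recovered as -exp(A_log)
    # in ssm()), i.e. a real-valued S4D-style initialization.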

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 of the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (see the glossary at the top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]  # causal depthwise conv; trim the right padding back to length l
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)  # gate the SSM output with the residual branch

        output = self.out_proj(y)

        return output

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (see the glossary at the top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆, A, B, C, D, the state space parameters.
        #     A and D are input-independent (see Section 3.5.2 "Interpretation of A" in the
        #     Mamba paper [1] for why A isn't selective).
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the
        #     linear time-invariant S4, and is why Mamba is called a **selective** state space).

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """Does the selective scan algorithm. See:
            - Section 2 "State Space Models" in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are
        dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (see the glossary at the top for definitions of b, l, d_in, n)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref`, so the functionality
            doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize the continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2
        #   Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From
        #   a discussion with the authors: "A is the more important term, and the
        #   performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform the selective scan (see scan_SSM() in The Annotated S4 [2]).
        # Note that the loop below is sequential, while the official implementation is a
        # much faster parallel scan that is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y


class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output