"""
|
||
Mamba在一个PyTorch文件中的简单、极简实现。
|
||
|
||
建议在阅读代码之前或期间阅读以下内容:
|
||
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu 和 Tri Dao)
|
||
https://arxiv.org/abs/2312.00752
|
||
[2] The Annotated S4 (Sasha Rush 和 Sidd Karamcheti)
|
||
https://srush.github.io/annotated-s4
|
||
|
||
术语表:
|
||
b: 批次大小 (`B` 在 Mamba 论文 [1] 算法2中)
|
||
l: 序列长度 (`L` 在 [1] 算法2中)
|
||
d 或 d_model: 隐藏维度
|
||
n 或 d_state: 潜在状态维度 (`N` 在 [1] 算法2中)
|
||
expand: 扩展因子 (`E` 在 [1] 第3.4节中)
|
||
d_in 或 d_inner: d * expand (`D` 在 [1] 算法2中)
|
||
A, B, C, D: 状态空间参数 (参见任何状态空间表示公式)
|
||
(B, C是输入相关的(即选择性的,这是Mamba的一个关键创新); A, D不是)
|
||
Δ 或 delta: 输入相关的步长
|
||
dt_rank: Δ的秩 (参见[1] 第3.6节 "Parameterization of ∆")
|
||
|
||
"""
|
||
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union
from einops import rearrange, repeat, einsum


@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
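
# A quick sketch of the derived fields (illustrative numbers; mamba-130m's published
# HF config is believed to use d_model=768, n_layer=24, vocab_size=50277):
#
#     args = ModelArgs(d_model=768, n_layer=24, vocab_size=50277)
#     args.d_inner    == 1536   # expand * d_model
#     args.dt_rank    == 48     # ceil(768 / 16)
#     args.vocab_size == 50280  # padded up to the next multiple of 8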


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper


    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (see glossary at top for definitions of b, l, d_in, n)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into the model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: a Mamba model with the weights loaded

        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            with open(resolved_archive_file) as f:
                return json.load(f)

        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model
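
    # Example usage (a sketch; requires the `transformers` package and network
    # access to the HuggingFace Hub on first call):
    #
    #     model = Mamba.from_pretrained('state-spaces/mamba-130m')
    #     logits = model(input_ids)  # input_ids: (b, l) long tensor of token ids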


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping a Mamba block with normalization and a residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)


    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (see glossary at top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

        Note: the official repo chains residual blocks that look like
            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
        where the first Add is a no-op. This is purely for performance reasons, as it
        lets them fuse the Add->Norm.

        We instead implement the more familiar, simpler, and numerically equivalent
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...

        """
        output = self.mixer(self.norm(x)) + x

        return output


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
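        # Added context: this is believed to match the real-valued "S4D-real"
        # initialization, A[:, n] = -(n + 1). ssm() recovers A as -exp(A_log),
        # which keeps A strictly negative (hence the SSM stable) during training.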
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)


    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 of the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (see glossary at top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output
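
    # Shape walkthrough of forward() under illustrative sizes d_model=64, expand=2
    # (so d_in=128), d_conv=4:
    #     in_proj:  (b, l, 64) -> (b, l, 256), split into x and res, each (b, l, 128)
    #     conv1d:   depthwise conv over the length axis; left padding of
    #               d_conv - 1 = 3 makes it causal, and [:, :, :l] trims the overhang
    #     ssm/gate: y = ssm(x) * silu(res), still (b, l, 128)
    #     out_proj: (b, l, 128) -> (b, l, 64)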

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (see glossary at top for definitions of b, l, d_in, n)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input-independent (see Section 3.5.2 "Interpretation of A" in the Mamba paper [1] for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time-invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
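        # For example, with illustrative mamba-130m-like sizes (d_model=768, so
        # d_in=1536 and dt_rank=48, with the default n=16): x_dbl is (b, l, 80),
        # split into delta (b, l, 48), B (b, l, 16), C (b, l, 16).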
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y


    def selective_scan(self, u, delta, A, B, C, D):
        """Does the selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (see glossary at top for definitions of b, l, d_in, n)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.

        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        #     - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        #     - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with the authors:
        #       "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
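
        # Spelled out (added context; in the notation of [1], Equation 4), ZOH gives
        #     A_bar = exp(Δ A)
        #     B_bar = (Δ A)^{-1} (exp(Δ A) - I) · Δ B
        # Since A is diagonal here, deltaA computes A_bar elementwise, while deltaB_u
        # folds the simplified Euler approximation B_bar ≈ Δ B together with the input u.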

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y


class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
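

if __name__ == '__main__':
    # Minimal smoke test with random weights (an illustrative sketch only; for real
    # text generation, load weights via Mamba.from_pretrained and pair the model
    # with the tokenizer the checkpoints were trained with).
    args = ModelArgs(d_model=64, n_layer=2, vocab_size=1000)
    model = Mamba(args)
    input_ids = torch.randint(0, args.vocab_size, (1, 16))  # (b=1, l=16)
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([1, 16, 1000])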