add mamba

This commit is contained in:
kewei 2024-05-31 17:04:42 +08:00
parent 2bd03665dd
commit 1df8bae480
3 changed files with 739 additions and 0 deletions

View File

@ -0,0 +1,34 @@
## mamba-minimal
Simple, minimal implementation of Mamba in one file of PyTorch.
Featuring:
* Equivalent numerical output to the official implementation for both the forward and backward pass
* Simplified, readable, annotated code
Does NOT include:
* Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most implementations simple for readability.
* Proper parameter initialization (though this could be added without sacrificing readability)
## Demo
See [demo.ipynb](demo.ipynb) for examples of prompt completions.
```python
from model import Mamba
from transformers import AutoTokenizer
model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
generate(model, tokenizer, 'Mamba is the')
```
> Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)
150 meters... 🫢 scary!
## References
The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor).
The official implementation is here: https://github.com/state-spaces/mamba/tree/main

View File

@ -0,0 +1,365 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "531467a2-5160-4073-a990-0d81d574b014",
"metadata": {},
"source": [
"## (1) Load model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "621ebeea-475a-4917-af01-22b1ec948075",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: HF_ENDPOINT=https://hf-mirror.com\n"
]
}
],
"source": [
"%env HF_ENDPOINT=https://hf-mirror.com"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cac008ac-d314-4ebe-a3d4-f9454d7e3614",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting einops\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)\n",
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m395.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m-:--:--\u001b[0m\n",
"\u001b[?25hInstalling collected packages: einops\n",
"Successfully installed einops-0.8.0\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install einops"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "570f0993fa46465d8866e3190fde4a0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/200 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "76b85f2e1c84414290c77815e848a03c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model.bin: 0%| | 0.00/1.49G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4851f6fb0db0428c8e1ff0dc9d147155",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/156 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "27ca8edc7934436a9350a3ee95d5590a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1c27d4a2251841cfbdce227ef4069ad2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"merges.txt: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a3ceb33a66c64dc9a7bb3a8ebc20463e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "407e741d1bb544bdb9540ac77c71dcfe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/90.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from model import Mamba, ModelArgs\n",
"from transformers import AutoTokenizer\n",
"\n",
"# One of:\n",
"# 'state-spaces/mamba-2.8b-slimpj'\n",
"# 'state-spaces/mamba-2.8b'\n",
"# 'state-spaces/mamba-1.4b'\n",
"# 'state-spaces/mamba-790m'\n",
"# 'state-spaces/mamba-370m'\n",
"# 'state-spaces/mamba-130m'\n",
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
"\n",
"model = Mamba.from_pretrained(pretrained_model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
]
},
{
"cell_type": "markdown",
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
"metadata": {},
"source": [
"## (2) Generate Text"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def generate(model,\n",
" tokenizer,\n",
" prompt: str,\n",
" n_tokens_to_gen: int = 50,\n",
" sample: bool = True,\n",
" top_k: int = 40):\n",
" model.eval()\n",
" \n",
" input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
" \n",
" for token_n in range(n_tokens_to_gen):\n",
" with torch.no_grad():\n",
" indices_to_input = input_ids\n",
" next_token_logits = model(indices_to_input)[:, -1]\n",
" \n",
" probs = F.softmax(next_token_logits, dim=-1)\n",
" (batch, vocab_size) = probs.shape\n",
" \n",
" if top_k is not None:\n",
" (values, indices) = torch.topk(probs, k=top_k)\n",
" probs[probs < values[:, -1, None]] = 0\n",
" probs = probs / probs.sum(axis=1, keepdims=True)\n",
" \n",
" if sample:\n",
" next_indices = torch.multinomial(probs, num_samples=1)\n",
" else:\n",
" next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
" \n",
" input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
"\n",
" output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
" \n",
" return output_completions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ee877143-2042-4579-8042-a96db6200517",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mamba is the new bestselling game from the popular mobile developers. This addictive and immersive platformer is a fun and addictive mobile game created by Nintendo that requires players to take control of a large robot monster in an environment where you control the movement of its legs and\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'Mamba is the'))"
]
},
{
"cell_type": "raw",
"id": "2c7642e7-6702-4687-8c4c-95f858625c8d",
"metadata": {},
"source": [
"上述这一段在cpu上用了114s"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization helps students gain expertise to be able to lead data mining projects in companies. The organization has developed various applications to help a learning process of data mining and artificial intelligence. Moreover, data management and the use of Big Data can also be used as\n"
]
}
],
"source": [
"print(generate(model, tokenizer, '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The meaning of life is \n",
"just this: It is the best you can do.\n",
"\n",
"--K.J.\n",
"\n",
"And finally: How to handle your emotions. \n",
"\n",
"<|endoftext|>Q:\n",
"\n",
"Error creating an EntityManager instance in JavaEE 7\n",
"\n",
"This is\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'The meaning of life is '))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def reverse_string(text, result):\n",
" # find the position of the start of the string.\n",
" start = text.index(text[0:-1])\n",
" # find the position where the string begins changing.\n",
" end = text.index\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'def reverse_string('))"
]
},
{
"cell_type": "raw",
"id": "be40cccb-e0e3-40ca-ac87-90249f1bbe3f",
"metadata": {},
"source": [
"我感觉mamba速度还是赶rwkv差了不少340m参数的模型不及rwkv-1b6参数量的模型快"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,340 @@
"""
Mamba在一个PyTorch文件中的简单极简实现
建议在阅读代码之前或期间阅读以下内容
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu Tri Dao)
https://arxiv.org/abs/2312.00752
[2] The Annotated S4 (Sasha Rush Sidd Karamcheti)
https://srush.github.io/annotated-s4
术语表
b: 批次大小 (`B` Mamba 论文 [1] 算法2中)
l: 序列长度 (`L` [1] 算法2中)
d d_model: 隐藏维度
n d_state: 潜在状态维度 (`N` [1] 算法2中)
expand: 扩展因子 (`E` [1] 第3.4节中)
d_in d_inner: d * expand (`D` [1] 算法2中)
A, B, C, D: 状态空间参数 (参见任何状态空间表示公式)
(B, C是输入相关的(即选择性的这是Mamba的一个关键创新); A, D不是)
Δ delta: 输入相关的步长
dt_rank: Δ的秩 (参见[1] 第3.6 "Parameterization of ∆")
"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
@dataclass
class ModelArgs:
    """Configuration for a Mamba model.

    Glossary (see the Mamba paper [1]):
        d_model: hidden dimension (D)
        n_layer: number of residual blocks
        vocab_size: vocabulary size (padded up in __post_init__)
        d_state: latent state dimension (N)
        expand: expansion factor (E), so d_inner = expand * d_model
        dt_rank: rank of Δ, or 'auto' for ceil(d_model / 16)
        d_conv: kernel size of the depthwise causal convolution
        pad_vocab_size_multiple: vocab is padded to a multiple of this
        conv_bias: whether the conv layer has a bias
        bias: whether the linear projections have a bias
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    # Fix: the original annotated this `Union[int, str]` without importing
    # `typing.Union`; PEP 604 syntax is valid here thanks to
    # `from __future__ import annotations` at the top of the file.
    dt_rank: int | str = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        # Inner (expanded) dimension: d_in = E * D.
        self.d_inner = int(self.expand * self.d_model)

        # Paper default (Section 3.6): dt_rank = ceil(d_model / 16).
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round vocab size up to the next multiple of pad_vocab_size_multiple.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model: embedding -> n_layer residual blocks -> norm -> LM head."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # Tie the output projection to the embedding weights.
        # See "Weight Tying" paper.
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (see glossary at top for b, l, d_in, n)

        Returns:
            logits: shape (b, l, vocab_size)

        Official implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits

    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into a model.

        Args:
            pretrained_model_name: one of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: a Mamba model with the pretrained weights loaded
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            # Fix: the original used `json.load(open(...))`, leaking the file handle.
            with open(resolved_archive_file) as f:
                return json.load(f)

        def load_state_dict_hf(model_name, device=None, dtype=None):
            # NOTE(review): `device` and `dtype` are accepted but unused — weights are
            # always mapped to CPU. Kept for signature compatibility.
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True,
                              map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        # Checkpoint keys are prefixed with 'backbone.'; strip the prefix so the
        # keys match this module tree.
        new_state_dict = {key.replace('backbone.', ''): value
                          for key, value in state_dict.items()}
        model.load_state_dict(new_state_dict)

        return model
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Pre-norm residual wrapper: RMSNorm -> MambaBlock -> add input back."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (see glossary at top for b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

        Note: the official repo chains residual blocks as
            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
        where the first Add is a no-op — purely a performance trick letting
        them fuse Add->Norm. The more familiar, simpler ordering used here,
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
        is numerically equivalent.
        """
        residual = x
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + residual
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1]."""
        super().__init__()
        self.args = args

        # Projects d -> 2 * d_in: one half becomes the main path x, the other
        # the gating residual `res` (split in forward()).
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # Depthwise causal conv (groups == channels). padding = d_conv - 1 keeps
        # it causal; the output is truncated back to length l in forward().
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank up to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # A is stored in log space; ssm() recovers A = -exp(A_log), guaranteeing
        # A is negative (stable dynamics).
        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. Same as Figure 3 in Section 3.4 of the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (see glossary at top for b, l, d_in, n)

        Returns:
            output: shape (b, l, d)

        Official implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]  # drop the trailing causal padding, back to length l
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)  # gate the SSM output by the residual branch

        output = self.out_proj(y)

        return output

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (see glossary at top for b, l, d_in, n)

        Returns:
            output: shape (b, l, d_in)

        Official implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆, A, B, C, D — the state space parameters.
        #     A, D are input-independent (see Section 3.5.2 "Interpretation of A"
        #     in the Mamba paper [1] for why A isn't selective).
        #     ∆, B, C are input-dependent (a key difference between Mamba and
        #     linear time-invariant S4, and the reason Mamba is called a
        #     **selective** state space model).
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # analogous to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """Does the selective scan algorithm. See:
            - Section 2 (State Space Models) of the Mamba paper [1]
            - Algorithm 2 in Section 3.2 of the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formulation:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta used for discretization)
        depend on the input x(t).

        Args:
            u: shape (b, l, d_in)    (see glossary at top for b, l, d_in, n)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts of `selective_scan_ref`, so the
            functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize the continuous parameters (A, B)
        #  - A is discretized with zero-order hold (ZOH) (see Eq. 4 in Section 2 of [1])
        #  - B is discretized with a simplified Euler step instead of ZOH. From a
        #    discussion with the authors: "A is the more important term; simplifying
        #    B has little effect on performance."
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform the selective scan (see scan_SSM() in The Annotated S4 [2]).
        # Note: this is sequential, whereas the official implementation is a much
        # faster parallel scan that is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias,
    just a learned per-feature gain applied after RMS scaling."""

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Scale each vector by the reciprocal RMS over the last dimension
        # (eps guards against division by zero), then apply the learned gain.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight