mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-04-25 08:58:17 +08:00
add mamba
This commit is contained in:
parent
2bd03665dd
commit
1df8bae480
34
Model_Architecture_Discussions/mamba/README.md
Normal file
34
Model_Architecture_Discussions/mamba/README.md
Normal file
@ -0,0 +1,34 @@
|
||||
## mamba-minimal
|
||||
|
||||
Simple, minimal implementation of Mamba in one file of PyTorch.
|
||||
|
||||
Featuring:
|
||||
* Equivalent numerical output as official implementation for both forward and backward pass
|
||||
* Simplified, readable, annotated code
|
||||
|
||||
Does NOT include:
|
||||
* Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most implementations simple for readability.
|
||||
* Proper parameter initialization (though this could be added without sacrificing readability)
|
||||
|
||||
## Demo
|
||||
|
||||
See [demo.ipynb](demo.ipynb) for examples of prompt completions.
|
||||
|
||||
```python
|
||||
from model import Mamba
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model = Mamba.from_pretrained('state-spaces/mamba-370m')
|
||||
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
|
||||
|
||||
generate(model, tokenizer, 'Mamba is the')
|
||||
```
|
||||
> Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)
|
||||
|
||||
150 meters... 🫢 scary!
|
||||
|
||||
## References
|
||||
|
||||
The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor).
|
||||
|
||||
The official implementation is here: https://github.com/state-spaces/mamba/tree/main
|
||||
365
Model_Architecture_Discussions/mamba/demo.ipynb
Normal file
365
Model_Architecture_Discussions/mamba/demo.ipynb
Normal file
@ -0,0 +1,365 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "531467a2-5160-4073-a990-0d81d574b014",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## (1) Load model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "621ebeea-475a-4917-af01-22b1ec948075",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"env: HF_ENDPOINT=https://hf-mirror.com\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%env HF_ENDPOINT=https://hf-mirror.com"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "cac008ac-d314-4ebe-a3d4-f9454d7e3614",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
|
||||
"Collecting einops\n",
|
||||
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)\n",
|
||||
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m395.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m-:--:--\u001b[0m\n",
|
||||
"\u001b[?25hInstalling collected packages: einops\n",
|
||||
"Successfully installed einops-0.8.0\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install einops"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "570f0993fa46465d8866e3190fde4a0e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0%| | 0.00/200 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "76b85f2e1c84414290c77815e848a03c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"pytorch_model.bin: 0%| | 0.00/1.49G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
|
||||
" return self.fget.__get__(instance, owner)()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4851f6fb0db0428c8e1ff0dc9d147155",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer_config.json: 0%| | 0.00/156 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "27ca8edc7934436a9350a3ee95d5590a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"vocab.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1c27d4a2251841cfbdce227ef4069ad2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"merges.txt: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a3ceb33a66c64dc9a7bb3a8ebc20463e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "407e741d1bb544bdb9540ac77c71dcfe",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"special_tokens_map.json: 0%| | 0.00/90.0 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from model import Mamba, ModelArgs\n",
|
||||
"from transformers import AutoTokenizer\n",
|
||||
"\n",
|
||||
"# One of:\n",
|
||||
"# 'state-spaces/mamba-2.8b-slimpj'\n",
|
||||
"# 'state-spaces/mamba-2.8b'\n",
|
||||
"# 'state-spaces/mamba-1.4b'\n",
|
||||
"# 'state-spaces/mamba-790m'\n",
|
||||
"# 'state-spaces/mamba-370m'\n",
|
||||
"# 'state-spaces/mamba-130m'\n",
|
||||
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
|
||||
"\n",
|
||||
"model = Mamba.from_pretrained(pretrained_model_name)\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## (2) Generate Text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def generate(model,\n",
|
||||
" tokenizer,\n",
|
||||
" prompt: str,\n",
|
||||
" n_tokens_to_gen: int = 50,\n",
|
||||
" sample: bool = True,\n",
|
||||
" top_k: int = 40):\n",
|
||||
" model.eval()\n",
|
||||
" \n",
|
||||
" input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
|
||||
" \n",
|
||||
" for token_n in range(n_tokens_to_gen):\n",
|
||||
" with torch.no_grad():\n",
|
||||
" indices_to_input = input_ids\n",
|
||||
" next_token_logits = model(indices_to_input)[:, -1]\n",
|
||||
" \n",
|
||||
" probs = F.softmax(next_token_logits, dim=-1)\n",
|
||||
" (batch, vocab_size) = probs.shape\n",
|
||||
" \n",
|
||||
" if top_k is not None:\n",
|
||||
" (values, indices) = torch.topk(probs, k=top_k)\n",
|
||||
" probs[probs < values[:, -1, None]] = 0\n",
|
||||
" probs = probs / probs.sum(axis=1, keepdims=True)\n",
|
||||
" \n",
|
||||
" if sample:\n",
|
||||
" next_indices = torch.multinomial(probs, num_samples=1)\n",
|
||||
" else:\n",
|
||||
" next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
|
||||
" \n",
|
||||
" input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
|
||||
"\n",
|
||||
" output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
|
||||
" \n",
|
||||
" return output_completions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "ee877143-2042-4579-8042-a96db6200517",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Mamba is the new bestselling game from the popular mobile developers. This addictive and immersive platformer is a fun and addictive mobile game created by Nintendo that requires players to take control of a large robot monster in an environment where you control the movement of its legs and\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(generate(model, tokenizer, 'Mamba is the'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "2c7642e7-6702-4687-8c4c-95f858625c8d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"上述这一段在cpu上用了114s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization helps students gain expertise to be able to lead data mining projects in companies. The organization has developed various applications to help a learning process of data mining and artificial intelligence. Moreover, data management and the use of Big Data can also be used as\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(generate(model, tokenizer, '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The meaning of life is \n",
|
||||
"just this: It is the best you can do.\n",
|
||||
"\n",
|
||||
"--K.J.\n",
|
||||
"\n",
|
||||
"And finally: How to handle your emotions. \n",
|
||||
"\n",
|
||||
"<|endoftext|>Q:\n",
|
||||
"\n",
|
||||
"Error creating an EntityManager instance in JavaEE 7\n",
|
||||
"\n",
|
||||
"This is\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(generate(model, tokenizer, 'The meaning of life is '))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"def reverse_string(text, result):\n",
|
||||
" # find the position of the start of the string.\n",
|
||||
" start = text.index(text[0:-1])\n",
|
||||
" # find the position where the string begins changing.\n",
|
||||
" end = text.index\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(generate(model, tokenizer, 'def reverse_string('))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "be40cccb-e0e3-40ca-ac87-90249f1bbe3f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"我感觉mamba速度还是赶rwkv差了不少,340m参数的模型,不及rwkv-1b6参数量的模型快"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
340
Model_Architecture_Discussions/mamba/model.py
Normal file
340
Model_Architecture_Discussions/mamba/model.py
Normal file
@ -0,0 +1,340 @@
|
||||
"""
|
||||
Mamba在一个PyTorch文件中的简单、极简实现。
|
||||
|
||||
建议在阅读代码之前或期间阅读以下内容:
|
||||
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu 和 Tri Dao)
|
||||
https://arxiv.org/abs/2312.00752
|
||||
[2] The Annotated S4 (Sasha Rush 和 Sidd Karamcheti)
|
||||
https://srush.github.io/annotated-s4
|
||||
|
||||
术语表:
|
||||
b: 批次大小 (`B` 在 Mamba 论文 [1] 算法2中)
|
||||
l: 序列长度 (`L` 在 [1] 算法2中)
|
||||
d 或 d_model: 隐藏维度
|
||||
n 或 d_state: 潜在状态维度 (`N` 在 [1] 算法2中)
|
||||
expand: 扩展因子 (`E` 在 [1] 第3.4节中)
|
||||
d_in 或 d_inner: d * expand (`D` 在 [1] 算法2中)
|
||||
A, B, C, D: 状态空间参数 (参见任何状态空间表示公式)
|
||||
(B, C是输入相关的(即选择性的,这是Mamba的一个关键创新); A, D不是)
|
||||
Δ 或 delta: 输入相关的步长
|
||||
dt_rank: Δ的秩 (参见[1] 第3.6节 "Parameterization of ∆")
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from einops import rearrange, repeat, einsum
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
d_model: int
|
||||
n_layer: int
|
||||
vocab_size: int
|
||||
d_state: int = 16
|
||||
expand: int = 2
|
||||
dt_rank: Union[int, str] = 'auto'
|
||||
d_conv: int = 4
|
||||
pad_vocab_size_multiple: int = 8
|
||||
conv_bias: bool = True
|
||||
bias: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
|
||||
if self.dt_rank == 'auto':
|
||||
self.dt_rank = math.ceil(self.d_model / 16)
|
||||
|
||||
if self.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||
self.vocab_size += (self.pad_vocab_size_multiple
|
||||
- self.vocab_size % self.pad_vocab_size_multiple)
|
||||
|
||||
|
||||
class Mamba(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
"""完整的Mamba模型。"""
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.embedding = nn.Embedding(args.vocab_size, args.d_model)
|
||||
self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
|
||||
self.norm_f = RMSNorm(args.d_model)
|
||||
|
||||
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
|
||||
self.lm_head.weight = self.embedding.weight # 将输出投影与嵌入权重绑定。
|
||||
# 参见 "Weight Tying" 论文
|
||||
|
||||
|
||||
def forward(self, input_ids):
|
||||
"""
|
||||
参数:
|
||||
input_ids (long tensor): 形状 (b, l) (参见顶部术语表中b, l, d_in, n的定义)
|
||||
|
||||
返回:
|
||||
logits: 形状 (b, l, vocab_size)
|
||||
|
||||
官方实现:
|
||||
class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
|
||||
|
||||
"""
|
||||
x = self.embedding(input_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
x = self.norm_f(x)
|
||||
logits = self.lm_head(x)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(pretrained_model_name: str):
|
||||
"""从HuggingFace加载预训练权重到模型中。
|
||||
|
||||
参数:
|
||||
pretrained_model_name: 以下之一
|
||||
* 'state-spaces/mamba-2.8b-slimpj'
|
||||
* 'state-spaces/mamba-2.8b'
|
||||
* 'state-spaces/mamba-1.4b'
|
||||
* 'state-spaces/mamba-790m'
|
||||
* 'state-spaces/mamba-370m'
|
||||
* 'state-spaces/mamba-130m'
|
||||
|
||||
返回:
|
||||
model: 加载了权重的Mamba模型
|
||||
|
||||
"""
|
||||
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
def load_config_hf(model_name):
|
||||
resolved_archive_file = cached_file(model_name, CONFIG_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
return json.load(open(resolved_archive_file))
|
||||
|
||||
|
||||
def load_state_dict_hf(model_name, device=None, dtype=None):
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
|
||||
|
||||
config_data = load_config_hf(pretrained_model_name)
|
||||
args = ModelArgs(
|
||||
d_model=config_data['d_model'],
|
||||
n_layer=config_data['n_layer'],
|
||||
vocab_size=config_data['vocab_size']
|
||||
)
|
||||
model = Mamba(args)
|
||||
|
||||
state_dict = load_state_dict_hf(pretrained_model_name)
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
new_key = key.replace('backbone.', '')
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
"""包装Mamba块的简单块,具有归一化和残差连接。"""
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.mixer = MambaBlock(args)
|
||||
self.norm = RMSNorm(args.d_model)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
参数:
|
||||
x: 形状 (b, l, d) (参见顶部术语表中b, l, d_in, n的定义)
|
||||
|
||||
返回:
|
||||
output: 形状 (b, l, d)
|
||||
|
||||
官方实现:
|
||||
Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
|
||||
|
||||
注意: 官方库链式残差块看起来像
|
||||
[Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
|
||||
其中第一个Add是无操作。这纯粹是为了性能原因,因为这
|
||||
允许他们融合Add->Norm。
|
||||
|
||||
我们相反实现的块更为熟悉,更简单,数值上等效
|
||||
[Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
|
||||
|
||||
"""
|
||||
output = self.mixer(self.norm(x)) + x
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
"""一个单一的Mamba块,如Mamba论文[1]第3.4节中的图3所述。"""
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
|
||||
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=args.d_inner,
|
||||
out_channels=args.d_inner,
|
||||
bias=args.conv_bias,
|
||||
kernel_size=args.d_conv,
|
||||
groups=args.d_inner,
|
||||
padding=args.d_conv - 1,
|
||||
)
|
||||
|
||||
# x_proj接收`x`并输出输入特定的Δ, B, C
|
||||
self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
|
||||
|
||||
# dt_proj将Δ从dt_rank投影到d_in
|
||||
self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)
|
||||
|
||||
A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
|
||||
self.A_log = nn.Parameter(torch.log(A))
|
||||
self.D = nn.Parameter(torch.ones(args.d_inner))
|
||||
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Mamba块前向。这与Mamba论文[1]第3.4节中的图3相同。
|
||||
|
||||
参数:
|
||||
x: 形状 (b, l, d) (参见顶部术语表中b, l, d_in, n的定义)
|
||||
|
||||
返回:
|
||||
output: 形状 (b, l, d)
|
||||
|
||||
官方实现:
|
||||
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
|
||||
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
|
||||
|
||||
"""
|
||||
(b, l, d) = x.shape
|
||||
|
||||
x_and_res = self.in_proj(x) # 形状 (b, l, 2 * d_in)
|
||||
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
|
||||
|
||||
x = rearrange(x, 'b l d_in -> b d_in l')
|
||||
x = self.conv1d(x)[:, :, :l]
|
||||
x = rearrange(x, 'b d_in l -> b l d_in')
|
||||
|
||||
x = F.silu(x)
|
||||
|
||||
y = self.ssm(x)
|
||||
|
||||
y = y * F.silu(res)
|
||||
|
||||
output = self.out_proj(y)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def ssm(self, x):
|
||||
"""运行SSM。参见:
|
||||
- Mamba论文[1]第3.2节中的算法2
|
||||
- The Annotated S4 [2] 中的run_SSM(A, B, C, u)
|
||||
|
||||
参数:
|
||||
x: 形状 (b, l, d_in) (参见顶部术语表中b, l, d_in, n的定义)
|
||||
|
||||
返回:
|
||||
output: 形状 (b, l, d_in)
|
||||
|
||||
官方实现:
|
||||
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
|
||||
|
||||
"""
|
||||
(d_in, n) = self.A_log.shape
|
||||
|
||||
# 计算 ∆ A B C D,状态空间参数。
|
||||
# A, D 与输入无关 (参见Mamba论文[1]第3.5.2节"Interpretation of A"中为何A不是选择性的)
|
||||
# ∆, B, C 与输入相关 (这是Mamba与线性时不变S4的一个关键区别,
|
||||
# 也是Mamba称为**选择性**状态空间的原因)
|
||||
|
||||
A = -torch.exp(self.A_log.float()) # 形状 (d_in, n)
|
||||
D = self.D.float()
|
||||
|
||||
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
|
||||
|
||||
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
|
||||
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
|
||||
|
||||
y = self.selective_scan(x, delta, A, B, C, D) # 这类似于The Annotated S4 [2]中的run_SSM(A, B, C, u)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def selective_scan(self, u, delta, A, B, C, D):
|
||||
"""执行选择性扫描算法。参见:
|
||||
- Mamba论文[1]中的第2节状态空间模型
|
||||
- Mamba论文[1]第3.2节中的算法2
|
||||
- The Annotated S4 [2]中的run_SSM(A, B, C, u)
|
||||
|
||||
这是经典的离散状态空间公式:
|
||||
x(t + 1) = Ax(t) + Bu(t)
|
||||
y(t) = Cx(t) + Du(t)
|
||||
除了B和C(以及用于离散化的步长delta)依赖于输入x(t)。
|
||||
|
||||
参数:
|
||||
u: 形状 (b, l, d_in) (参见顶部术语表中b, l, d_in, n的定义)
|
||||
delta: 形状 (b, l, d_in)
|
||||
A: 形状 (d_in, n)
|
||||
B: 形状 (b, l, n)
|
||||
C: 形状 (b, l, n)
|
||||
D: 形状 (d_in,)
|
||||
|
||||
返回:
|
||||
output: 形状 (b, l, d_in)
|
||||
|
||||
官方实现:
|
||||
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
|
||||
注意: 我将`selective_scan_ref`中的一些部分进行了重构,所以功能不完全匹配。
|
||||
|
||||
"""
|
||||
(b, l, d_in) = u.shape
|
||||
n = A.shape[1]
|
||||
|
||||
# 离散化连续参数 (A, B)
|
||||
# - A 使用零阶保持(ZOH)离散化 (参见Mamba论文[1]第2节方程4)
|
||||
# - B 使用简化的欧拉离散化而不是ZOH。从与作者的讨论中:
|
||||
# "A是更重要的项,简化B对性能影响不大"
|
||||
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
|
||||
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
|
||||
|
||||
# 执行选择性扫描 (参见The Annotated S4 [2]中的scan_SSM())
|
||||
# 注意,以下是顺序的,而官方实现是一个更快的并行扫描,此外还考虑了硬件 (如FlashAttention)。
|
||||
x = torch.zeros((b, d_in, n), device=deltaA.device)
|
||||
ys = []
|
||||
for i in range(l):
|
||||
x = deltaA[:, i] * x + deltaB_u[:, i]
|
||||
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=1) # 形状 (b, l, d_in)
|
||||
|
||||
y = y + u * D
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(d_model))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
||||
|
||||
return output
|
||||
Loading…
Reference in New Issue
Block a user