diff --git a/Model_Architecture_Discussions/mamba/README.md b/Model_Architecture_Discussions/mamba/README.md new file mode 100644 index 0000000..4e5e616 --- /dev/null +++ b/Model_Architecture_Discussions/mamba/README.md @@ -0,0 +1,34 @@ +## mamba-minimal + +Simple, minimal implementation of Mamba in one file of PyTorch. + +Featuring: +* Equivalent numerical output as official implementation for both forward and backward pass +* Simplified, readable, annotated code + +Does NOT include: +* Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most implementations simple for readability. +* Proper parameter initialization (though this could be added without sacrificing readability) + +## Demo + +See [demo.ipynb](demo.ipynb) for examples of prompt completions. + +```python +from model import Mamba +from transformers import AutoTokenizer + +model = Mamba.from_pretrained('state-spaces/mamba-370m') +tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') + +generate(model, tokenizer, 'Mamba is the') +``` +> Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite) + +150 meters... 🫢 scary! + +## References + +The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor). + +The official implementation is here: https://github.com/state-spaces/mamba/tree/main diff --git a/Model_Architecture_Discussions/mamba/demo.ipynb b/Model_Architecture_Discussions/mamba/demo.ipynb new file mode 100644 index 0000000..4c46c1a --- /dev/null +++ b/Model_Architecture_Discussions/mamba/demo.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "531467a2-5160-4073-a990-0d81d574b014", + "metadata": {}, + "source": [ + "## (1) Load model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "621ebeea-475a-4917-af01-22b1ec948075", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: HF_ENDPOINT=https://hf-mirror.com\n" + ] + } + ], + "source": [ + "%env HF_ENDPOINT=https://hf-mirror.com" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "cac008ac-d314-4ebe-a3d4-f9454d7e3614", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", + "Collecting einops\n", + " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m395.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m-:--:--\u001b[0m\n", + "\u001b[?25hInstalling collected packages: einops\n", + "Successfully installed einops-0.8.0\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install einops" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d9337043-4e7a-4b20-9d89-6c6257245334", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "570f0993fa46465d8866e3190fde4a0e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/200 [00:00Q:\n", + "\n", + "Error creating an EntityManager instance in JavaEE 7\n", + "\n", + "This is\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'The meaning of life is '))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2b189e6e-6a96-4770-88cf-7c5de22cb321", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def reverse_string(text, result):\n", + " # find the position of the start of the string.\n", + " start = text.index(text[0:-1])\n", + " # find the position where the string begins changing.\n", + " end = text.index\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'def reverse_string('))" + ] + }, + { + "cell_type": "raw", + "id": "be40cccb-e0e3-40ca-ac87-90249f1bbe3f", + "metadata": {}, + "source": [ + "我感觉mamba速度还是赶rwkv差了不少,340m参数的模型,不及rwkv-1b6参数量的模型快" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Model_Architecture_Discussions/mamba/model.py b/Model_Architecture_Discussions/mamba/model.py new file mode 100644 index 0000000..0369e82 --- /dev/null +++ b/Model_Architecture_Discussions/mamba/model.py @@ -0,0 +1,340 @@ +""" +Mamba在一个PyTorch文件中的简单、极简实现。 + +建议在阅读代码之前或期间阅读以下内容: + [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu 和 Tri Dao) + https://arxiv.org/abs/2312.00752 + [2] The Annotated S4 (Sasha Rush 和 Sidd Karamcheti) + https://srush.github.io/annotated-s4 + +术语表: + b: 批次大小 (`B` 在 Mamba 论文 [1] 算法2中) + l: 序列长度 (`L` 在 [1] 算法2中) + d 或 d_model: 隐藏维度 + n 或 d_state: 潜在状态维度 (`N` 在 [1] 算法2中) + expand: 扩展因子 (`E` 在 [1] 第3.4节中) + d_in 或 d_inner: d * expand (`D` 在 [1] 算法2中) + A, B, C, D: 状态空间参数 (参见任何状态空间表示公式) + (B, C是输入相关的(即选择性的,这是Mamba的一个关键创新); A, D不是) + Δ 或 delta: 输入相关的步长 + dt_rank: Δ的秩 (参见[1] 第3.6节 "Parameterization of ∆") + +""" +from __future__ import annotations +import math +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from einops import rearrange, repeat, einsum + + +@dataclass +class ModelArgs: + d_model: int + n_layer: int + vocab_size: int + d_state: int = 16 + expand: int = 2 + dt_rank: Union[int, str] = 'auto' + d_conv: int = 4 + pad_vocab_size_multiple: int = 8 + conv_bias: bool = True + bias: bool = False + + def __post_init__(self): + self.d_inner = int(self.expand * self.d_model) + + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) + + if self.vocab_size % self.pad_vocab_size_multiple != 0: + self.vocab_size += (self.pad_vocab_size_multiple + - self.vocab_size % self.pad_vocab_size_multiple) + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + """完整的Mamba模型。""" + super().__init__() + self.args = args + + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)]) + self.norm_f = RMSNorm(args.d_model) + + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + self.lm_head.weight = self.embedding.weight # 将输出投影与嵌入权重绑定。 + # 参见 "Weight Tying" 论文 + + + def forward(self, input_ids): + """ + 参数: + input_ids (long tensor): 形状 (b, l) (参见顶部术语表中b, l, d_in, n的定义) + + 返回: + logits: 形状 (b, l, vocab_size) + + 官方实现: + class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173 + + """ + x = self.embedding(input_ids) + + for layer in self.layers: + x = layer(x) + + x = self.norm_f(x) + logits = self.lm_head(x) + + return logits + + + @staticmethod + def from_pretrained(pretrained_model_name: str): + """从HuggingFace加载预训练权重到模型中。 + + 参数: + pretrained_model_name: 以下之一 + * 'state-spaces/mamba-2.8b-slimpj' + * 'state-spaces/mamba-2.8b' + * 'state-spaces/mamba-1.4b' + * 'state-spaces/mamba-790m' + * 'state-spaces/mamba-370m' + * 'state-spaces/mamba-130m' + + 返回: + model: 加载了权重的Mamba模型 + + """ + from transformers.utils import WEIGHTS_NAME, CONFIG_NAME + from transformers.utils.hub import cached_file + + def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, + _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + + + def load_state_dict_hf(model_name, device=None, dtype=None): + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, + _raise_exceptions_for_missing_entries=False) + return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) + + config_data = load_config_hf(pretrained_model_name) + args = ModelArgs( + d_model=config_data['d_model'], + n_layer=config_data['n_layer'], + vocab_size=config_data['vocab_size'] + ) + model = Mamba(args) + + state_dict = load_state_dict_hf(pretrained_model_name) + new_state_dict = {} + for key in state_dict: + new_key = key.replace('backbone.', '') + new_state_dict[new_key] = state_dict[key] + model.load_state_dict(new_state_dict) + + return model + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + """包装Mamba块的简单块,具有归一化和残差连接。""" + super().__init__() + self.args = args + self.mixer = MambaBlock(args) + self.norm = RMSNorm(args.d_model) + + + def forward(self, x): + """ + 参数: + x: 形状 (b, l, d) (参见顶部术语表中b, l, d_in, n的定义) + + 返回: + output: 形状 (b, l, d) + + 官方实现: + Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297 + + 注意: 官方库链式残差块看起来像 + [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ... + 其中第一个Add是无操作。这纯粹是为了性能原因,因为这 + 允许他们融合Add->Norm。 + + 我们相反实现的块更为熟悉,更简单,数值上等效 + [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> .... + + """ + output = self.mixer(self.norm(x)) + x + + return output + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + """一个单一的Mamba块,如Mamba论文[1]第3.4节中的图3所述。""" + super().__init__() + self.args = args + + self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias) + + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.conv_bias, + kernel_size=args.d_conv, + groups=args.d_inner, + padding=args.d_conv - 1, + ) + + # x_proj接收`x`并输出输入特定的Δ, B, C + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False) + + # dt_proj将Δ从dt_rank投影到d_in + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(args.d_inner)) + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias) + + + def forward(self, x): + """Mamba块前向。这与Mamba论文[1]第3.4节中的图3相同。 + + 参数: + x: 形状 (b, l, d) (参见顶部术语表中b, l, d_in, n的定义) + + 返回: + output: 形状 (b, l, d) + + 官方实现: + class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 + mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 + + """ + (b, l, d) = x.shape + + x_and_res = self.in_proj(x) # 形状 (b, l, 2 * d_in) + (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1) + + x = rearrange(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, 'b d_in l -> b l d_in') + + x = F.silu(x) + + y = self.ssm(x) + + y = y * F.silu(res) + + output = self.out_proj(y) + + return output + + + def ssm(self, x): + """运行SSM。参见: + - Mamba论文[1]第3.2节中的算法2 + - The Annotated S4 [2] 中的run_SSM(A, B, C, u) + + 参数: + x: 形状 (b, l, d_in) (参见顶部术语表中b, l, d_in, n的定义) + + 返回: + output: 形状 (b, l, d_in) + + 官方实现: + mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 + + """ + (d_in, n) = self.A_log.shape + + # 计算 ∆ A B C D,状态空间参数。 + # A, D 与输入无关 (参见Mamba论文[1]第3.5.2节"Interpretation of A"中为何A不是选择性的) + # ∆, B, C 与输入相关 (这是Mamba与线性时不变S4的一个关键区别, + # 也是Mamba称为**选择性**状态空间的原因) + + A = -torch.exp(self.A_log.float()) # 形状 (d_in, n) + D = self.D.float() + + x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) + + (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) + delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) + + y = self.selective_scan(x, delta, A, B, C, D) # 这类似于The Annotated S4 [2]中的run_SSM(A, B, C, u) + + return y + + + def selective_scan(self, u, delta, A, B, C, D): + """执行选择性扫描算法。参见: + - Mamba论文[1]中的第2节状态空间模型 + - Mamba论文[1]第3.2节中的算法2 + - The Annotated S4 [2]中的run_SSM(A, B, C, u) + + 这是经典的离散状态空间公式: + x(t + 1) = Ax(t) + Bu(t) + y(t) = Cx(t) + Du(t) + 除了B和C(以及用于离散化的步长delta)依赖于输入x(t)。 + + 参数: + u: 形状 (b, l, d_in) (参见顶部术语表中b, l, d_in, n的定义) + delta: 形状 (b, l, d_in) + A: 形状 (d_in, n) + B: 形状 (b, l, n) + C: 形状 (b, l, n) + D: 形状 (d_in,) + + 返回: + output: 形状 (b, l, d_in) + + 官方实现: + selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 + 注意: 我将`selective_scan_ref`中的一些部分进行了重构,所以功能不完全匹配。 + + """ + (b, l, d_in) = u.shape + n = A.shape[1] + + # 离散化连续参数 (A, B) + # - A 使用零阶保持(ZOH)离散化 (参见Mamba论文[1]第2节方程4) + # - B 使用简化的欧拉离散化而不是ZOH。从与作者的讨论中: + # "A是更重要的项,简化B对性能影响不大" + deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) + deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') + + # 执行选择性扫描 (参见The Annotated S4 [2]中的scan_SSM()) + # 注意,以下是顺序的,而官方实现是一个更快的并行扫描,此外还考虑了硬件 (如FlashAttention)。 + x = torch.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, i] * x + deltaB_u[:, i] + y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = torch.stack(ys, dim=1) # 形状 (b, l, d_in) + + y = y + u * D + + return y + + +class RMSNorm(nn.Module): + def __init__(self, + d_model: int, + eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + + + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + + return output