# Mamba

This document shows how to build and run a [Mamba](https://github.com/state-spaces/mamba) model in TensorRT-LLM on a single GPU.

## Overview

The TensorRT-LLM Mamba implementation can be found in [`tensorrt_llm/models/mamba/model.py`](../../tensorrt_llm/models/mamba/model.py). The TensorRT-LLM Mamba example code is located in [`examples/mamba`](./). There is one main file:

* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format.

In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation:

* [`../run.py`](../run.py) to run inference on an input text;
* [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset.

## Support Matrix

* FP16
* BF16

## Usage

The next two sections describe how to download the weights from [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) and convert them to the TensorRT-LLM format.

### 1. Download weights from HuggingFace Transformers

Please install the required packages first and set up `git-lfs`:

```bash
pip install -r requirements.txt
git clone --branch v1.1.1 https://github.com/Dao-AILab/causal-conv1d.git
git clone --branch v1.1.1 https://github.com/state-spaces/mamba.git
pip install ./causal-conv1d/ ./mamba/

# Set up git-lfs
git lfs install
```

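To check that the custom kernels installed correctly, you can try importing the two packages (a quick sanity check; `causal_conv1d` and `mamba_ssm` are the Python modules these two repositories install):

```bash
# Both imports should succeed without errors if the wheels built properly
python -c "import causal_conv1d, mamba_ssm; print('Mamba kernels OK')"
```
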
There are six HF checkpoints available. Use one of the following commands to fetch the checkpoint you are interested in.

```bash
# mamba-2.8b-slimpj
git clone https://huggingface.co/state-spaces/mamba-2.8b-slimpj ./mamba_model/mamba-2.8b-slimpj

# mamba-2.8b
git clone https://huggingface.co/state-spaces/mamba-2.8b ./mamba_model/mamba-2.8b

# mamba-1.4b
git clone https://huggingface.co/state-spaces/mamba-1.4b ./mamba_model/mamba-1.4b

# mamba-790m
git clone https://huggingface.co/state-spaces/mamba-790m ./mamba_model/mamba-790m

# mamba-370m
git clone https://huggingface.co/state-spaces/mamba-370m ./mamba_model/mamba-370m

# mamba-130m
git clone https://huggingface.co/state-spaces/mamba-130m ./mamba_model/mamba-130m
```
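
If you want all six checkpoints at once, the clones above can be collapsed into a loop (a convenience sketch running the same `git clone` commands):

```bash
# Fetch all six Mamba checkpoints into ./mamba_model/
for m in mamba-2.8b-slimpj mamba-2.8b mamba-1.4b mamba-790m mamba-370m mamba-130m; do
    git clone https://huggingface.co/state-spaces/${m} ./mamba_model/${m}
done
```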

Since the Mamba models use the tokenizer from the gpt-neox-20b model, use the following command to fetch the gpt-neox-20b checkpoint.

```bash
# gpt-neox-20b
git clone https://huggingface.co/EleutherAI/gpt-neox-20b ./mamba_model/gpt-neox-20b
```
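
Cloning gpt-neox-20b also downloads its full weights, which are not needed here. If you only want the tokenizer, you can skip the large LFS files (an optional shortcut, assuming the tokenizer files are stored as regular, non-LFS files, which is typical for `tokenizer.json`):

```bash
# Skip the multi-GB LFS weight files; only the small tokenizer files are needed
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/EleutherAI/gpt-neox-20b ./mamba_model/gpt-neox-20b
```
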
### 2. Convert weights from HF Transformers to TensorRT-LLM format

The [`convert_checkpoint.py`](./convert_checkpoint.py) script converts HF weights to TensorRT-LLM checkpoints.

```bash
# mamba-2.8b-slimpj
python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b-slimpj/ \
                             --dtype bfloat16 \
                             --output_dir ./mamba_model/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/

# mamba-2.8b
python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \
                             --dtype bfloat16 \
                             --output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/

# mamba-1.4b
python convert_checkpoint.py --model_dir ./mamba_model/mamba-1.4b/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/

# mamba-790m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-790m/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/

# mamba-370m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-370m/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/

# mamba-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/
```
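
After conversion you can inspect the output directory; assuming the standard TensorRT-LLM checkpoint layout, it should contain a `config.json` plus one weights file per rank (a single `rank0.safetensors` here, since these are single-GPU checkpoints):

```bash
# List the converted TensorRT-LLM checkpoint for mamba-2.8b
ls ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/
```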

### 3. Build TensorRT engine(s)

The `trtllm-build` command builds TensorRT-LLM engines from TensorRT-LLM checkpoints.

```bash
# mamba-2.8b-slimpj
trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin bfloat16 \
             --mamba_conv1d_plugin bfloat16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/

# mamba-2.8b
trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin bfloat16 \
             --mamba_conv1d_plugin bfloat16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/

# mamba-1.4b
trtllm-build --checkpoint_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/

# mamba-790m
trtllm-build --checkpoint_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/

# mamba-370m
trtllm-build --checkpoint_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/

# mamba-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
             --gpt_attention_plugin disable \
             --paged_kv_cache disable \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_output_len 100 \
             --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
```

Note that when building Mamba models, you need to disable `paged_kv_cache` because it only applies to transformer-based models. Mamba models use `paged_state` instead, which is enabled by default. If `paged_state` is disabled, the engine will be built with a contiguous state cache.
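
Once an engine has been built, you can also run inference on a single prompt with the shared [`../run.py`](../run.py) script. A minimal sketch using the mamba-2.8b engine built above (the prompt is arbitrary; adjust the paths to match the engine you built):

```bash
python ../run.py --max_output_len 100 \
                 --input_text "The Mamba architecture is" \
                 --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                 --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
```
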
### 4. Run summarization task with the TensorRT engine(s)

The following section describes how to run a TensorRT-LLM Mamba model to summarize the articles from the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. For each summary, the script can compute the [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation.

### Run

```bash
# mamba-2.8b-slimpj
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-2.8b-slimpj/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type bf16 \
                       --engine_dir ./mamba_model/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/

# mamba-2.8b
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-2.8b/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type bf16 \
                       --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/

# mamba-1.4b
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-1.4b/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/

# mamba-790m
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-790m/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/

# mamba-370m
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-370m/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/

# mamba-130m
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-130m/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
```