Mamba
This document shows how to build and run a Mamba model in TensorRT-LLM on a single GPU.
Overview
The TensorRT-LLM Mamba implementation can be found in tensorrt_llm/models/mamba/model.py. The TensorRT-LLM Mamba example code is located in examples/mamba. There is one main file:
- convert_checkpoint.py to convert a checkpoint from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.
In addition, there are two shared files in the parent folder examples for inference and evaluation:
- ../run.py to run inference on an input text (an example invocation is shown at the end of this document);
- ../summarize.py to summarize the articles in the cnn_dailymail dataset.
Support Matrix
| Model Name | FP16 | BF16 |
|---|---|---|
| Mamba1 | Y | Y |
| Mamba2 | Y | Y |
- Mamba2: TensorRT-LLM currently supports only pure Mamba2 models; support for hybrid models will be added later.
Usage
The next two sections describe how to download the weights and convert them from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.
1. Download weights from HuggingFace Transformers
Please install the required packages and set up git-lfs first:
pip install -r requirements.txt
git lfs install
Different HF checkpoints are available. For Mamba1, TensorRT-LLM supports the Transformers-compatible models. Here are some examples of how to fetch the checkpoints.
# mamba-2.8b
git clone https://huggingface.co/state-spaces/mamba-2.8b-hf ./mamba_model/mamba-2.8b
# mamba-130m
git clone https://huggingface.co/state-spaces/mamba-130m-hf ./mamba_model/mamba-130m
# mamba2-2.7b
git clone https://huggingface.co/state-spaces/mamba2-2.7b ./mamba_model/mamba2-2.7b
# mamba2-130m
git clone https://huggingface.co/state-spaces/mamba2-130m ./mamba_model/mamba2-130m
# mamba-codestral-7B-v0.1
git clone https://huggingface.co/mistralai/mamba-codestral-7B-v0.1 ./mamba_model/mamba-codestral-7B-v0.1
Since the Mamba models use the tokenizer from the gpt-neox-20b model, and mamba-codestral-7B-v0.1 uses the same tokenizer as the mathstral-7B-v0.1 model, use the following commands to fetch the gpt-neox-20b and mathstral-7B-v0.1 checkpoints.
# gpt-neox-20b
git clone https://huggingface.co/EleutherAI/gpt-neox-20b ./mamba_model/gpt-neox-20b
# mathstral-7B-v0.1
git clone https://huggingface.co/mistralai/mathstral-7B-v0.1 ./mamba_model/mathstral-7B-v0.1
2. Convert weights from HF Transformers to TensorRT-LLM format
The convert_checkpoint.py script converts HF weights to TensorRT-LLM checkpoints.
# mamba-2.8b
python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \
--dtype bfloat16 \
--output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/
# mamba-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/
# mamba2-2.7b
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-2.7b/ \
--dtype float16 \
--output_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/
# mamba2-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \
--dtype float16 \
--output_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/
# mamba-codestral-7B-v0.1
python convert_checkpoint.py --model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
--dtype float16 \
--output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/1-gpu/
3. Build TensorRT engine(s)
The trtllm-build command builds TensorRT-LLM engines from TensorRT-LLM checkpoints.
# mamba-2.8b
trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
# mamba-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
# mamba2-2.7b
trtllm-build --checkpoint_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
# mamba2-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
# mamba-codestral-7B-v0.1
trtllm-build --checkpoint_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/
Note that when building Mamba models, you need to disable paged_kv_cache, since it only applies to transformer-based models. Mamba models use paged_state instead, which is enabled by default. If paged_state is disabled, the engine is built with a contiguous state cache.
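For example, to build an engine with the contiguous state cache instead of the paged one, the mamba2-130m build above can additionally disable paged_state. This is an illustrative sketch: it assumes the --paged_state option accepts enable/disable values like --paged_kv_cache, and the output path is hypothetical.
# mamba2-130m with a contiguous state cache (illustrative)
trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \
--paged_kv_cache disable \
--paged_state disable \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 924 \
--max_seq_len 1024 \
--output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu-contiguous/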
4. Run summarization task with the TensorRT engine(s)
The following section describes how to run a TensorRT-LLM Mamba model to summarize the articles from the
cnn_dailymail dataset. For each summary, the script can compute the
ROUGE scores and use the ROUGE-1 score to validate the implementation.
# mamba-2.8b
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-2.8b/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type bf16 \
--engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/
# mamba-130m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-130m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/
# mamba2-2.7b
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba2-2.7b/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/
# mamba2-130m
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba2-130m/ \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/
# mamba-codestral-7B-v0.1
python ../summarize.py --test_trt_llm \
--hf_model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
--tokenizer_dir ./mamba_model/mathstral-7B-v0.1/ \
--data_type fp16 \
--engine_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/
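The shared ../run.py script mentioned in the Overview can be used for a quick generation check on any of the engines built above. Below is a minimal sketch for the mamba-2.8b engine, assuming the standard flags of the shared example runner (--max_output_len, --input_text, --tokenizer_dir, --engine_dir); adjust paths to match your setup.
# mamba-2.8b (illustrative)
python ../run.py --max_output_len 100 \
--input_text "Mamba is a state-space model that" \
--tokenizer_dir ./mamba_model/gpt-neox-20b/ \
--engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/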