# RecurrentGemma

This document shows how to build and run a [RecurrentGemma](https://github.com/google-deepmind/recurrentgemma) model in TensorRT-LLM.

## Overview

The TensorRT-LLM RecurrentGemma implementation can be found in [`tensorrt_llm/models/recurrentgemma/model.py`](../../tensorrt_llm/models/recurrentgemma/model.py). The TensorRT-LLM RecurrentGemma example code is located in [`examples/recurrentgemma`](./). There is one main file:

* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert a checkpoint from one of the supported formats (JAX, Torch, Keras, or Hugging Face) into the TensorRT-LLM format.

In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation:

* [`../run.py`](../run.py) to run inference on an input text;
* [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset.

## Support Matrix

* FP16/BF16
* Checkpoint type: JAX, Torch, Keras, Hugging Face (HF)
* Python runtime (pass `--use_py_session` to the example scripts)

## Usage

### 1. Install requirements and download weights

First, install the required packages and set up `git-lfs`:

```bash
pip install -r requirements.txt
git clone https://github.com/google-deepmind/recurrentgemma.git
pip install ./recurrentgemma/[full]
pip install "transformers>=4.40.0"

# Set up git-lfs
git lfs install
```

Then use one of the following commands to fetch the checkpoint you are interested in. The models are publicly available, but they are gated on Hugging Face, so you need to log in with your Hugging Face account before you can clone them.
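
One possible login flow is sketched below; it assumes you already have a Hugging Face access token, and that `huggingface-cli` was installed along with the packages above:

```bash
# Log in once; paste your access token when prompted and allow it to be
# stored as a git credential so the gated clones below succeed.
huggingface-cli login
```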
```bash
# recurrentgemma-2b
git clone https://huggingface.co/google/recurrentgemma-2b ./recurrentgemma_model/recurrentgemma-2b

# recurrentgemma-2b-it
git clone https://huggingface.co/google/recurrentgemma-2b-it ./recurrentgemma_model/recurrentgemma-2b-it

# recurrentgemma-2b-flax
git clone https://huggingface.co/google/recurrentgemma-2b-flax ./recurrentgemma_model/recurrentgemma-2b-flax

# recurrentgemma-2b-it-flax
git clone https://huggingface.co/google/recurrentgemma-2b-it-flax ./recurrentgemma_model/recurrentgemma-2b-it-flax
```

### 2. Convert weights to the TensorRT-LLM format

The [`convert_checkpoint.py`](./convert_checkpoint.py) script converts the downloaded weights into TensorRT-LLM checkpoints. The `--ckpt_type` flag selects the source format: `hf` for the Hugging Face checkpoints and `jax` for the Flax ones.

```bash
# recurrentgemma-2b
CKPT_2B_PATH=./recurrentgemma_model/recurrentgemma-2b
UNIFIED_CKPT_2B_PATH=./recurrentgemma_model/recurrentgemma-2b/trt_ckpt/fp16/1-gpu/
python convert_checkpoint.py --model_dir ${CKPT_2B_PATH} \
                             --ckpt_type hf \
                             --dtype float16 \
                             --output_dir ${UNIFIED_CKPT_2B_PATH}

# recurrentgemma-2b-it
CKPT_2B_IT_PATH=./recurrentgemma_model/recurrentgemma-2b-it
UNIFIED_CKPT_2B_IT_PATH=./recurrentgemma_model/recurrentgemma-2b-it/trt_ckpt/bf16/1-gpu/
python convert_checkpoint.py --model_dir ${CKPT_2B_IT_PATH} \
                             --ckpt_type hf \
                             --dtype bfloat16 \
                             --output_dir ${UNIFIED_CKPT_2B_IT_PATH}

# recurrentgemma-2b-flax
CKPT_2B_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-flax/2b
UNIFIED_CKPT_2B_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-flax/trt_ckpt/fp16/1-gpu/
python convert_checkpoint.py --model_dir ${CKPT_2B_FLAX_PATH} \
                             --ckpt_type jax \
                             --dtype float16 \
                             --output_dir ${UNIFIED_CKPT_2B_FLAX_PATH}

# recurrentgemma-2b-it-flax
CKPT_2B_IT_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-it-flax/2b-it
UNIFIED_CKPT_2B_IT_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-it-flax/trt_ckpt/bf16/1-gpu/
python convert_checkpoint.py --model_dir ${CKPT_2B_IT_FLAX_PATH} \
                             --ckpt_type jax \
                             --dtype bfloat16 \
                             --output_dir ${UNIFIED_CKPT_2B_IT_FLAX_PATH}
```
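
To sanity-check a conversion before building an engine, you can inspect the generated checkpoint directory; assuming the usual unified-checkpoint layout, it contains a `config.json` describing the model next to per-rank weight files:

```bash
# List the converted checkpoint and view its model configuration.
ls ${UNIFIED_CKPT_2B_PATH}
cat ${UNIFIED_CKPT_2B_PATH}/config.json
```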

### 3. Build TensorRT engine(s)

After generating the TensorRT-LLM checkpoints, use the `trtllm-build` command to build TensorRT-LLM engines from them.

```bash
# recurrentgemma-2b
ENGINE_2B_PATH=./recurrentgemma_model/recurrentgemma-2b/trt_engines/fp16/1-gpu/
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_2B_PATH} \
             --gpt_attention_plugin float16 \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 3000 \
             --max_output_len 100 \
             --output_dir ${ENGINE_2B_PATH}

# recurrentgemma-2b-it
ENGINE_2B_IT_PATH=./recurrentgemma_model/recurrentgemma-2b-it/trt_engines/bf16/1-gpu/
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_2B_IT_PATH} \
             --gpt_attention_plugin bfloat16 \
             --gemm_plugin bfloat16 \
             --mamba_conv1d_plugin bfloat16 \
             --max_batch_size 8 \
             --max_input_len 3000 \
             --max_output_len 100 \
             --output_dir ${ENGINE_2B_IT_PATH}

# recurrentgemma-2b-flax
ENGINE_2B_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-flax/trt_engines/fp16/1-gpu/
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_2B_FLAX_PATH} \
             --gpt_attention_plugin float16 \
             --gemm_plugin float16 \
             --mamba_conv1d_plugin float16 \
             --max_batch_size 8 \
             --max_input_len 3000 \
             --max_output_len 100 \
             --output_dir ${ENGINE_2B_FLAX_PATH}

# recurrentgemma-2b-it-flax
ENGINE_2B_IT_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-it-flax/trt_engines/bf16/1-gpu/
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_2B_IT_FLAX_PATH} \
             --gpt_attention_plugin bfloat16 \
             --gemm_plugin bfloat16 \
             --mamba_conv1d_plugin bfloat16 \
             --max_batch_size 8 \
             --max_input_len 3000 \
             --max_output_len 100 \
             --output_dir ${ENGINE_2B_IT_FLAX_PATH}
```
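
Before moving on to inference, you can optionally verify that each build succeeded; assuming the usual engine-directory layout, the output folder holds the serialized engine alongside its build configuration:

```bash
# A single-GPU build is expected to produce a config.json and one rank0 engine file.
ls ${ENGINE_2B_PATH}
```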

### 4. Run inference with the TensorRT engine(s)

We provide three example scripts: `run.py`, `summarize.py`, and `mmlu.py`. `run.py` simply runs inference on an input text and prints the output.

`summarize.py` runs summarization on the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset, evaluates the model with [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores, and uses the `ROUGE-1` score to validate the implementation.

`mmlu.py` runs the MMLU benchmark to evaluate the model's accuracy.

Note that the MMLU dataset needs to be downloaded first and that the MMLU evaluation takes longer to run.

* `run.py`

```bash
# recurrentgemma-2b
TOKENIZER_DIR_2B_PATH=./recurrentgemma_model/recurrentgemma-2b
python3 ../run.py --max_output_len=100 \
                  --use_py_session \
                  --tokenizer_dir ${TOKENIZER_DIR_2B_PATH} \
                  --engine_dir ${ENGINE_2B_PATH}

# recurrentgemma-2b-it
TOKENIZER_DIR_2B_IT_PATH=./recurrentgemma_model/recurrentgemma-2b-it
python3 ../run.py --max_output_len=100 \
                  --use_py_session \
                  --tokenizer_dir ${TOKENIZER_DIR_2B_IT_PATH} \
                  --engine_dir ${ENGINE_2B_IT_PATH}

# recurrentgemma-2b-flax
VOCAB_FILE_2B_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-flax/tokenizer.model
python3 ../run.py --max_output_len=100 \
                  --use_py_session \
                  --vocab_file ${VOCAB_FILE_2B_FLAX_PATH} \
                  --engine_dir ${ENGINE_2B_FLAX_PATH}

# recurrentgemma-2b-it-flax
VOCAB_FILE_2B_IT_FLAX_PATH=./recurrentgemma_model/recurrentgemma-2b-it-flax/tokenizer.model
python3 ../run.py --max_output_len=100 \
                  --use_py_session \
                  --vocab_file ${VOCAB_FILE_2B_IT_FLAX_PATH} \
                  --engine_dir ${ENGINE_2B_IT_FLAX_PATH}
```

* `summarize.py`

```bash
# recurrentgemma-2b
python3 ../summarize.py --test_trt_llm \
                        --use_py_session \
                        --engine_dir ${ENGINE_2B_PATH} \
                        --batch_size 8 \
                        --tokenizer_dir ${TOKENIZER_DIR_2B_PATH}

# recurrentgemma-2b-it
python3 ../summarize.py --test_trt_llm \
                        --use_py_session \
                        --engine_dir ${ENGINE_2B_IT_PATH} \
                        --batch_size 8 \
                        --tokenizer_dir ${TOKENIZER_DIR_2B_IT_PATH}

# recurrentgemma-2b-flax
python3 ../summarize.py --test_trt_llm \
                        --use_py_session \
                        --engine_dir ${ENGINE_2B_FLAX_PATH} \
                        --batch_size 8 \
                        --vocab_file ${VOCAB_FILE_2B_FLAX_PATH}

# recurrentgemma-2b-it-flax
python3 ../summarize.py --test_trt_llm \
                        --use_py_session \
                        --engine_dir ${ENGINE_2B_IT_FLAX_PATH} \
                        --batch_size 8 \
                        --vocab_file ${VOCAB_FILE_2B_IT_FLAX_PATH}
```

* `mmlu.py`

Download the dataset first:

```bash
mkdir data
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -O data/mmlu.tar
tar -xf data/mmlu.tar -C data
mv data/data data/mmlu
```

Evaluate on the MMLU dataset:

```bash
# recurrentgemma-2b
python3 ../mmlu.py --test_trt_llm \
                   --tokenizer_dir ${TOKENIZER_DIR_2B_PATH} \
                   --engine_dir ${ENGINE_2B_PATH}

# recurrentgemma-2b-it
python3 ../mmlu.py --test_trt_llm \
                   --tokenizer_dir ${TOKENIZER_DIR_2B_IT_PATH} \
                   --engine_dir ${ENGINE_2B_IT_PATH}

# recurrentgemma-2b-flax
python3 ../mmlu.py --test_trt_llm \
                   --vocab_file ${VOCAB_FILE_2B_FLAX_PATH} \
                   --engine_dir ${ENGINE_2B_FLAX_PATH}

# recurrentgemma-2b-it-flax
python3 ../mmlu.py --test_trt_llm \
                   --vocab_file ${VOCAB_FILE_2B_IT_FLAX_PATH} \
                   --engine_dir ${ENGINE_2B_IT_FLAX_PATH}
```