Whisper
This document shows how to build and run a Whisper model in TensorRT-LLM on a single GPU.
Overview
The TensorRT-LLM Whisper example code is located in examples/whisper. There are three main entry points:
- convert_checkpoint.py to convert weights from the OpenAI Whisper format to the TensorRT-LLM format.
- trtllm-build to build the TensorRT engine(s) needed to run the Whisper model.
- run.py to run inference on a single WAV file or a HuggingFace dataset (LibriSpeech test-clean).
Support Matrix
- FP16
- INT8 (Weight Only Quant)
- INT4 (Weight Only Quant)
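Weight-only quantization stores the weights in INT8 or INT4 while activations and compute stay in FP16; the kernels dequantize the weights on the fly inside the GEMMs. A minimal NumPy sketch of the per-channel symmetric INT8 idea (illustrative only, not the actual TensorRT-LLM kernels):
import numpy as np

w = np.random.randn(8, 16).astype(np.float16)  # a toy FP16 weight matrix

# One scale per output channel (row), chosen so the largest weight maps to 127.
scales = np.abs(w).max(axis=1, keepdims=True).astype(np.float32) / 127.0
w_int8 = np.clip(np.round(w.astype(np.float32) / scales), -127, 127).astype(np.int8)

# At runtime the GEMM dequantizes back to FP16 before multiplying.
w_dequant = (w_int8.astype(np.float32) * scales).astype(np.float16)
print("max abs error:", np.abs(w.astype(np.float32) - w_dequant.astype(np.float32)).max())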
Usage
The example takes Whisper PyTorch weights as input and builds the corresponding TensorRT engines.
Build TensorRT engine(s)
Prepare the Whisper checkpoint first by downloading the required assets and model weights:
wget --directory-prefix=assets https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken
wget --directory-prefix=assets https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz
wget --directory-prefix=assets https://raw.githubusercontent.com/yuekaizhang/Triton-ASR-Client/main/datasets/mini_en/wav/1221-135766-0002.wav
# take large-v3 model as an example
wget --directory-prefix=assets https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt
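The long hexadecimal segment in the checkpoint URL is the SHA-256 digest of the file, so you can verify the download before converting it. A small sketch, assuming the files landed in assets/ as above:
import hashlib

expected = "e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb"
sha256 = hashlib.sha256()
with open("assets/large-v3.pt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)
assert sha256.hexdigest() == expected, "large-v3.pt is corrupted; re-download it"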
TensorRT-LLM Whisper builds the TensorRT engine(s) from the PyTorch checkpoint.
# install requirements first
pip install -r requirements.txt
INFERENCE_PRECISION=float16
WEIGHT_ONLY_PRECISION=int8
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=8
checkpoint_dir=whisper_large_v3_weights_${WEIGHT_ONLY_PRECISION}
output_dir=whisper_large_v3_${WEIGHT_ONLY_PRECISION}
# Convert the large-v3 model weights into TensorRT-LLM format.
python3 convert_checkpoint.py \
--use_weight_only \
--weight_only_precision $WEIGHT_ONLY_PRECISION \
--output_dir $checkpoint_dir
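The converter writes separate encoder and decoder checkpoints, which the two trtllm-build invocations below consume. A quick sanity check, assuming the usual TensorRT-LLM checkpoint layout of a config.json plus per-rank weight files in each subdirectory:
import os

checkpoint_dir = "whisper_large_v3_weights_int8"  # $checkpoint_dir from above
for component in ("encoder", "decoder"):
    path = os.path.join(checkpoint_dir, component)
    assert os.path.isfile(os.path.join(path, "config.json")), f"missing config in {path}"
    print(component, "->", sorted(os.listdir(path)))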
# Build the large-v3 model using trtllm-build
trtllm-build --checkpoint_dir ${checkpoint_dir}/encoder \
--output_dir ${output_dir}/encoder \
--moe_plugin disable \
--max_batch_size ${MAX_BATCH_SIZE} \
--gemm_plugin disable \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--max_input_len 3000 --max_seq_len 3000
trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
--output_dir ${output_dir}/decoder \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${MAX_BATCH_SIZE} \
--max_seq_len 114 \
--max_input_len 14 \
--max_encoder_input_len 3000 \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION}
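None of these length limits are arbitrary: Whisper's mel front end emits 100 frames per second (16 kHz audio with a 160-sample hop), so a 30-second window yields the 3000 encoder frames used above, and the decoder limits presumably budget 14 tokens for the special-token prompt plus 100 generated tokens. A quick Python check:
# Whisper mel front end: 16 kHz audio, 160-sample STFT hop -> 100 frames/s.
SAMPLE_RATE = 16_000
HOP_LENGTH = 160
CHUNK_SECONDS = 30
print(CHUNK_SECONDS * SAMPLE_RATE // HOP_LENGTH)  # 3000 = --max_encoder_input_len
# Decoder budget (assumed split): 14 prompt tokens + 100 generated tokens.
print(14 + 100)  # 114 = --max_seq_len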
Run
Different runtimes are provided for Whisper models. Listed in order of serving performance and usability, we recommend:
- (NEW) Python binding of C++ runtime w/ Paged KV Cache and Inflight Batching (IFB)
- Python runtime w/ Static Batching
Please refer to the documentation for the details of paged KV cache and inflight batching.
Run C++ runtime
Note: to use the inflight batching and paged KV cache features in the C++ runtime, make sure you have set --paged_kv_cache enable and --remove_input_padding enable (enabled by default) in the trtllm-build command. If you use the Python runtime instead, it is recommended to disable these flags with --paged_kv_cache disable and --remove_input_padding disable to avoid unnecessary overhead.
# choose the engine you build [./whisper_large_v3, ./whisper_large_v3_int8]
output_dir=./whisper_large_v3
# decode a single audio file
# If the input file does not have a .wav extension, ffmpeg needs to be installed with the following command:
# apt-get update && apt-get install -y ffmpeg
# Inferencing via python binding of C++ runtime with inflight batching (IFB)
python3 run.py --name single_wav_test --engine_dir $output_dir --input_file assets/1221-135766-0002.wav
# decode a whole dataset
python3 run.py --engine_dir $output_dir --dataset hf-internal-testing/librispeech_asr_dummy --enable_warmup --name librispeech_dummy_large_v3
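If you want to see what run.py iterates over when decoding a dataset, you can load the same Hugging Face dataset directly. A small sketch (the dummy LibriSpeech dataset ships a "clean" config with a "validation" split):
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]
audio = sample["audio"]  # dict with the raw "array" and its "sampling_rate"
print(len(ds), audio["sampling_rate"], sample["text"][:60])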
For the pure C++ runtime, no example is provided yet. Please check the Executor API to implement your own end-to-end workflow. It is highly recommended to use more encapsulated solutions such as the Python bindings of the C++ runtime above or the Triton backend.
Run Python runtime
For the pure Python runtime, simply add the --use_py_session option.
Advanced Usage
--padding_strategy
OpenAI's official Whisper models accept WAV files of up to 30 seconds in length. For files shorter than 30 seconds, padding is required to reach the 30-second mark, which may not be efficient. Currently, three padding strategies are supported:
- max (default): Pads to 30 seconds.
- longest: Pads according to the longest duration in the current batch.
- nopad: No padding is applied. You will need to fine-tune the Whisper model to maintain accuracy. See examples.
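A sketch of what the three strategies do to a batch of mel features (illustrative NumPy only; large-v3 uses 128 mel bins and the front end emits 100 frames per second):
import numpy as np

MAX_FRAMES = 3000  # 30 s at 100 mel frames/s
# Hypothetical batch: a 5 s clip and an 11 s clip.
batch = [np.zeros((128, 500), np.float16), np.zeros((128, 1100), np.float16)]

def pad_to(mel, frames):
    return np.pad(mel, ((0, 0), (0, frames - mel.shape[1])))

padded_max = [pad_to(m, MAX_FRAMES) for m in batch]       # "max": always 3000 frames
longest = max(m.shape[1] for m in batch)
padded_longest = [pad_to(m, longest) for m in batch]      # "longest": 1100 frames here
padded_nopad = batch                                      # "nopad": lengths untouched
print([m.shape[1] for m in padded_max], [m.shape[1] for m in padded_longest])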
--text_prefix
You can modify the input prompt for the Whisper decoder. For example, use <|startoftranscript|><|en|><|zh|><|transcribe|><|notimestamps|> to perform code-switching ASR between Chinese and English.
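The prefix is an ordinary string of Whisper special tokens, so other combinations follow the same pattern (a sketch of common variants, not an exhaustive list):
# Default English transcription:
english = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# Chinese/English code-switching, as mentioned above:
code_switch = "<|startoftranscript|><|en|><|zh|><|transcribe|><|notimestamps|>"
# Translation into English instead of transcription:
translate = "<|startoftranscript|><|zh|><|translate|><|notimestamps|>"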
--compute_cer
Calculates the character error rate (CER) instead of the word error rate (WER) for languages such as Chinese and Japanese.
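Both metrics are the same edit-distance computation at different granularities: WER splits on whitespace, which is meaningless for unsegmented scripts, while CER compares character sequences. A self-contained sketch:
def edit_distance(ref, hyp):
    # Levenshtein distance with a single rolling row of the DP table.
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

def wer(ref, hyp):
    return edit_distance(ref.split(), hyp.split()) / len(ref.split())

def cer(ref, hyp):
    return edit_distance(ref, hyp) / len(ref)

print(wer("hello there world", "hello world"))  # 1 deletion / 3 words ~ 0.33
print(cer("你好世界", "你好地球"))  # 2 substitutions / 4 characters = 0.5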
--dataset, --dataset_name, and --dataset_split
These options let you select different audio datasets from Hugging Face for decoding.
Distil-Whisper
TensorRT-LLM also supports Distil-Whisper models. First convert their parameters and weights from the Hugging Face naming format to the OpenAI Whisper naming format by running the script distil_whisper/convert_from_distil_whisper.py as follows:
# take distil-medium.en as an example
# download the gpt2.tiktoken
wget --directory-prefix=assets https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken
# will download the model weights from huggingface and convert them to openai-whisper's pytorch format
# model is saved to ./assets/ by default
python3 distil_whisper/convert_from_distil_whisper.py --model_name distil-whisper/distil-medium.en --output_name distil-medium.en
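Under the hood the script is essentially a state-dict key renaming from Hugging Face module names to OpenAI Whisper module names, for example (assumed key pairs for illustration, not the script's actual mapping table):
# Illustrative Hugging Face -> OpenAI Whisper key pairs (assumed naming):
rename_examples = {
    "model.encoder.layers.0.self_attn.q_proj.weight": "encoder.blocks.0.attn.query.weight",
    "model.decoder.layers.0.fc1.weight": "decoder.blocks.0.mlp.0.weight",
}
for hf_key, openai_key in rename_examples.items():
    print(hf_key, "->", openai_key)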
# convert whisper model weights into TensorRT-LLM
INFERENCE_PRECISION=float16
WEIGHT_ONLY_PRECISION=int8
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=8
checkpoint_dir=distil_whisper_medium_en_weights_${WEIGHT_ONLY_PRECISION}
output_dir=distil_whisper_medium_en_${WEIGHT_ONLY_PRECISION}
python3 convert_checkpoint.py \
--use_weight_only \
--weight_only_precision $WEIGHT_ONLY_PRECISION \
--output_dir $checkpoint_dir \
--model_name distil-medium.en
Now, we can build and run the model like before:
trtllm-build --checkpoint_dir ${checkpoint_dir}/encoder \
--output_dir ${output_dir}/encoder \
--moe_plugin disable \
--max_batch_size ${MAX_BATCH_SIZE} \
--gemm_plugin disable \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--max_input_len 3000 --max_seq_len 3000
trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
--output_dir ${output_dir}/decoder \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${MAX_BATCH_SIZE} \
--max_seq_len 114 \
--max_input_len 14 \
--max_encoder_input_len 3000 \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION}
# use cpp runtime python bindings
python3 run.py --engine_dir $output_dir --dataset hf-internal-testing/librispeech_asr_dummy --name librispeech_dummy_${output_dir}
# use python runtime
python3 run.py --engine_dir $output_dir --dataset hf-internal-testing/librispeech_asr_dummy --name librispeech_dummy_${output_dir} --use_py_session
Acknowledgment
This implementation of TensorRT-LLM for Whisper has been adapted from the NVIDIA TensorRT-LLM Hackathon 2023 submission of Jinheng Wang, which can be found in the repository Eddie-Wang-Hackathon2023 on GitHub. We extend our gratitude to Jinheng for providing a foundation for the implementation.