TensorRT-LLMs/examples/multimodal/README.md
2024-04-16 19:40:08 +08:00


<!-- omit from toc -->
# Multi-Modal
This document shows how to run multimodal pipelines with TensorRT-LLM, e.g. from image+text input modalities to text output.
Building the LLM part of a multimodal model takes one additional parameter, `--max_multimodal_len`, compared to LLM-only build commands. Under the hood, `max_multimodal_len` and `max_prompt_embedding_table_size` are effectively the same concept: embeddings (either multimodal feature embeddings or prompt tuning embeddings) that are prepended or concatenated to the LLM input embeddings. The multimodal features from the visual encoder, of shape `[batch_size, num_visual_features, visual_hidden_dim]`, are flattened to `[batch_size * num_visual_features, visual_hidden_dim]` and passed like a prompt embedding table.
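As a quick check of the arithmetic, the number of rows in the flattened table is the product of the first two dimensions. The sketch below uses a hypothetical helper, `multimodal_len` (not part of TensorRT-LLM), with the batch sizes and feature counts that appear in the build commands in this document.

```bash
# Hypothetical helper, not part of TensorRT-LLM: number of rows in the
# flattened [batch_size * num_visual_features, visual_hidden_dim] table.
multimodal_len() {
  # $1 = max_batch_size, $2 = num_visual_features
  echo $(( $1 * $2 ))
}

multimodal_len 8 32    # BLIP2 builds: prints 256
multimodal_len 1 576   # LLaVA/VILA builds: prints 576
```

The printed value is what the build commands pass as `--max_multimodal_len`.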
We first describe how to run each model on a single GPU. We then provide general guidelines on using tensor parallelism for the LLM part of the pipeline.
- [BLIP2-T5](#blip2-t5)
- [BLIP2-OPT](#blip2-opt)
- [LLaVA and VILA](#llava-and-vila)
- [Nougat](#nougat)
- [Enabling tensor parallelism for multi-GPU](#enabling-tensor-parallelism-for-multi-gpu)
## BLIP2-T5
1. Download Huggingface weights and convert the original checkpoint to TRT-LLM checkpoint format,
following the example in `examples/enc_dec/README.md`.
```bash
export MODEL_NAME="flan-t5-xl" # also flan-t5-xxl
git clone https://huggingface.co/google/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
python ../enc_dec/convert_checkpoint.py --model_type t5 \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
--tp_size 1 \
--pp_size 1 \
--weight_data_type float32 \
--dtype bfloat16 \
--max_multimodal_len 256 # 8 (max_batch_size) * 32 (num_visual_features)
```
2. Build TRT-LLM engine from TRT-LLM checkpoint
```bash
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/encoder \
--output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/encoder \
--paged_kv_cache disable \
--moe_plugin disable \
--enable_xqa disable \
--use_custom_all_reduce disable \
--gemm_plugin bfloat16 \
--bert_attention_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--remove_input_padding enable \
--context_fmha disable \
--max_beam_width 1 \
--max_batch_size 8 \
--max_output_len 100 \
--max_input_len 924 \
--max_multimodal_len 256 # 8 (max_batch_size) * 32 (num_visual_features)
# Same command for decoder but don't set --max_multimodal_len
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/decoder \
--output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/decoder \
--paged_kv_cache disable \
--moe_plugin disable \
--enable_xqa disable \
--use_custom_all_reduce disable \
--gemm_plugin bfloat16 \
--bert_attention_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--remove_input_padding enable \
--context_fmha disable \
--max_beam_width 1 \
--max_batch_size 8 \
--max_output_len 100 \
--max_encoder_input_len 924 \
--max_input_len 1
```
**NOTE**: `max_multimodal_len = max_batch_size * num_visual_features`, so if you change `max_batch_size`, `max_multimodal_len` **MUST** be changed accordingly.
The built T5 engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1`.
3. Build TensorRT engines for visual components
```bash
python build_visual_engine.py --model_type ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 8
```
The built engines are located in `./visual_engines/${MODEL_NAME}`.
To run the BLIP2 pipeline with batch size > 1, change the `--max_batch_size` argument of `build_visual_engine.py` accordingly.
4. Assemble everything into BLIP2 pipeline
```bash
python run.py \
--max_new_tokens 30 \
--input_text "Question: which city is this? Answer:" \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1
```
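Because `--llm_engine_dir` points at a directory that must contain both the `encoder` and `decoder` engines built in step 2, a small pre-flight check can catch a missing half before `run.py` starts. This is a minimal sketch; `check_enc_dec_engines` is a hypothetical helper, not part of TensorRT-LLM.

```bash
# Hypothetical helper, not part of TensorRT-LLM: verify that both halves of
# an encoder-decoder engine were built before launching run.py.
check_enc_dec_engines() {
  # $1 = engine root, e.g. tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1
  for part in encoder decoder; do
    if [ ! -d "$1/$part" ]; then
      echo "missing engine directory: $1/$part" >&2
      return 1
    fi
  done
  echo "found encoder and decoder engines under $1"
}

# Example guard (commented out so the sketch has no side effects):
# check_enc_dec_engines "tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1" || exit 1
```

The check only looks for the two subdirectories; it does not validate the engine files inside them.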
## BLIP2-OPT
The OPT pipeline needs a few minor changes from the T5 pipeline:
1. Convert Huggingface weights to TRT-LLM checkpoint format following `examples/opt/README.md`.
2. Use `trtllm-build` command to build TRT-LLM engine for OPT.
3. The full list of commands is as follows:
```bash
export MODEL_NAME="opt-2.7b" # also opt-6.7b
git clone https://huggingface.co/facebook/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
python ../opt/convert_checkpoint.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--dtype float16 \
--output_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
--gemm_plugin float16 \
--max_beam_width 1 \
--max_batch_size 8 \
--max_multimodal_len 256 \
--max_input_len 924 \
--max_output_len 100
python build_visual_engine.py --model_type ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME}
python run.py \
--max_new_tokens 30 \
--input_text "Question: which city is this? Answer:" \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu
```
4. INT8/INT4 weight-only quantization for OPT can be enabled with the following commands. The example uses `INT4`; `INT8` is the default precision for weight-only quantization:
```bash
python ../opt/convert_checkpoint.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--dtype float16 \
--output_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \
--use_weight_only \
--weight_only_precision int4
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \
--gemm_plugin float16 \
--max_beam_width 1 \
--max_batch_size 8 \
--max_multimodal_len 256 \
--max_input_len 924 \
--max_output_len 100
```
The built OPT engines are located in `trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`.
Use this directory as the `--llm_engine_dir` argument to `run.py`.
**NOTE:** INT8/INT4 option is not supported for BLIP2-T5, because quantization support has not been
added for encoder-decoder models yet.
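The quantized build in step 4 stops short of the run. Below is a minimal sketch of the follow-up, assuming the same `run.py` arguments as in step 3 with only `--llm_engine_dir` changed. `PRECISION`, `LLM_ENGINE_DIR`, and `RUN_CMD` are illustrative variable names, and the command is printed rather than executed.

```bash
MODEL_NAME="${MODEL_NAME:-opt-2.7b}"
PRECISION="int4_weightonly"   # illustrative variable; "fp16" selects the unquantized engine
LLM_ENGINE_DIR="trt_engines/${MODEL_NAME}/${PRECISION}/1-gpu"

# Same run.py invocation as in step 3, pointing at the quantized engine.
RUN_CMD="python run.py --max_new_tokens 30 \
--input_text \"Question: which city is this? Answer:\" \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir ${LLM_ENGINE_DIR}"

# Dry run: print the command so it can be inspected before it is executed.
echo "${RUN_CMD}"
```

The visual engine is unchanged by LLM quantization, so `--visual_engine_dir` stays the same as in the FP16 run.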
## LLaVA and VILA
[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options.
1. Download Huggingface model weights. These models have both visual and LLM components,
unlike the BLIP2 example, which downloads only the LLM components from Huggingface.
For LLaVA,
```bash
export MODEL_NAME="llava-1.5-7b-hf" # also llava-1.5-13b-hf
git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
```
For VILA, a few more steps are needed until it is added to the HF model zoo:
```bash
# clone original VILA repo
export VILA_PATH="tmp/hf_models/VILA"
git clone https://github.com/Efficient-Large-Model/VILA.git ${VILA_PATH}
# download VILA checkpoints
export MODEL_NAME="vila-7B" # also vila-2.7B, vila-13B
git clone https://huggingface.co/Efficient-Large-Model/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
# turn off delay_load to allow model component access
sed -i 's/delay_load=True/delay_load=False/g' ${VILA_PATH}/llava/model/llava_arch.py
# line manipulation to enable AWQ. otherwise need to replace HF's llama implementation
sed -i '/vision_tower = self.get_vision_tower()/a \ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)' ${VILA_PATH}/llava/model/llava_arch.py
sed -i 's/seqlens_in_batch=sorted_seqlens_in_batch/#seqlens_in_batch=sorted_seqlens_in_batch/g' ${VILA_PATH}/llava/model/language_model/llava_llama.py
```
2. Generate TRT-LLM engine for LLaMA following example in `examples/llama/README.md`
```bash
python ../llama/convert_checkpoint.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
--dtype float16
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
--gemm_plugin float16 \
--use_fused_mlp \
--max_batch_size 1 \
--max_input_len 2048 \
--max_output_len 512 \
--max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features)
```
Note: do not use `--use_fused_mlp` flag in quantization mode.
3. Build TensorRT engines for visual components, then run the pipeline
```bash
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava # for LLaVA
python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type vila --vila_path ${VILA_PATH} # for VILA
```
```bash
python run.py \
--max_new_tokens 30 \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \
--input_text "Question: which city is this? Answer:" # or "Please describe the traffic condition." for VILA
```
Note: use `--run_profiling` for performance measurement and `--check_accuracy` for an accuracy check.
4. (Optional) INT8/INT4 weight-only quantization for LLaMA can be enabled as follows. The example uses `INT4`; `INT8` is the default precision for weight-only quantization:
```bash
python ../llama/convert_checkpoint.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--dtype float16 \
--output_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \
--use_weight_only \
--weight_only_precision int4
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \
--gemm_plugin float16 \
--max_batch_size 1 \
--max_input_len 1024 \
--max_output_len 100 \
--max_multimodal_len 576
```
The built engines are located in `trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`.
Use this directory as the `--llm_engine_dir` argument to `run.py`.
5. (Optional) LLaVA/VILA can also be used with other quantization options supported by LLaMA, such as SmoothQuant and INT4 AWQ.
The instructions in the LLaMA [README](../llama/README.md) for enabling SmoothQuant and INT4 AWQ can be reused to generate
quantized TRT engines for the LLM component of LLaVA/VILA.
For example,
```bash
python ../quantization/quantize.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \
--dtype float16 \
--qformat int4_awq \
--calib_size 32
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \
--output_dir trt_engines/${MODEL_NAME}/int4_awq/1-gpu \
--gemm_plugin float16 \
--max_batch_size 1 \
--max_input_len 1024 \
--max_output_len 100 \
--max_multimodal_len 576
```
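The note in step 2 says `--use_fused_mlp` should not be used in quantization mode, and the quantized builds above omit it accordingly. One way to keep a single build script for all variants is to add the flag conditionally. This is a sketch only; `fused_mlp_flag` is a hypothetical helper, and the precision labels are the directory names used in this section.

```bash
# Hypothetical helper, not part of TensorRT-LLM: emit --use_fused_mlp only for
# the unquantized fp16 build, per the note in step 2.
fused_mlp_flag() {
  # $1 = precision label as used in the directory names above
  case "$1" in
    fp16) echo "--use_fused_mlp" ;;
    *)    echo "" ;;   # int4_weightonly, int4_awq, ...: leave the flag out
  esac
}

# Dry run: show which builds would receive the flag.
for precision in fp16 int4_weightonly int4_awq; do
  echo "${precision}: trtllm-build --gemm_plugin float16 $(fused_mlp_flag "${precision}")"
done
```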
## Nougat
1. Download Huggingface weights
```bash
export MODEL_NAME="nougat-base" # also nougat-small
git clone https://huggingface.co/facebook/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
```
2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in `examples/enc_dec`.
Nougat uses the mBART architecture but replaces the LLM encoder with a Swin Transformer encoder.
To support this, we add an extra `--nougat` flag (compared to the mBART example) to
`convert_checkpoint.py` in `examples/enc_dec` and `trtllm-build`.
```bash
python ../enc_dec/convert_checkpoint.py --model_type bart \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
--tp_size 1 \
--pp_size 1 \
--weight_data_type float32 \
--dtype bfloat16 \
--nougat
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/decoder \
--output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/decoder \
--paged_kv_cache disable \
--moe_plugin disable \
--enable_xqa disable \
--use_custom_all_reduce disable \
--gemm_plugin bfloat16 \
--bert_attention_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--remove_input_padding enable \
--max_beam_width 1 \
--max_batch_size 1 \
--max_output_len 100 \
--max_input_len 1 \
--max_encoder_input_len 588 # 1 (max_batch_size) * 588 (num_visual_features)
```
3. Generate TensorRT engines for visual components and combine everything into final pipeline.
```bash
python build_visual_engine.py --model_type nougat --model_path tmp/hf_models/${MODEL_NAME}
python run.py \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1
```
Note: Nougat models usually do not need a text prompt.
## Enabling tensor parallelism for multi-GPU
The LLM part of the pipeline can be run on multiple GPUs using tensor parallelism.
The visual encoder will be replicated on each GPU and operate in a data parallel fashion.
To enable tensor parallelism, both the weight conversion step (from Huggingface to TRT-LLM checkpoint format)
and the engine building step need additional arguments. Finally, `run.py` should be prefixed
with `mpirun -n NUM_GPUS --allow-run-as-root`.
The full set of commands to enable 2-way tensor parallelism for LLaVA is:
```bash
export MODEL_NAME="llava-1.5-7b-hf"
python ../llama/convert_checkpoint.py \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/fp16/2-gpu \
--dtype float16 --tp_size 2
trtllm-build \
--checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/2-gpu \
--output_dir trt_engines/${MODEL_NAME}/fp16/2-gpu \
--gemm_plugin float16 \
--max_batch_size 1 \
--max_input_len 2048 \
--max_output_len 512 \
--max_multimodal_len 576
python build_visual_engine.py --model_type llava --model_path tmp/hf_models/${MODEL_NAME}
mpirun -n 2 --allow-run-as-root \
python run.py \
--max_new_tokens 30 \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--visual_engine_dir visual_engines/${MODEL_NAME} \
--llm_engine_dir trt_engines/${MODEL_NAME}/fp16/2-gpu
```
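In the commands above, the GPU count appears in three places that must agree: `--tp_size`, the `N-gpu` directory names, and `mpirun -n`. Below is a sketch of deriving all three from one value. `TP_SIZE`, `CKPT_DIR`, `ENGINE_DIR`, and `MPI_PREFIX` are illustrative variable names, and the script only prints what it would use.

```bash
TP_SIZE="${TP_SIZE:-2}"   # illustrative variable: number of GPUs for tensor parallelism
MODEL_NAME="${MODEL_NAME:-llava-1.5-7b-hf}"

CKPT_DIR="tmp/trt_models/${MODEL_NAME}/fp16/${TP_SIZE}-gpu"
ENGINE_DIR="trt_engines/${MODEL_NAME}/fp16/${TP_SIZE}-gpu"
MPI_PREFIX="mpirun -n ${TP_SIZE} --allow-run-as-root"

# Dry run: print the three places where the GPU count must agree.
echo "convert_checkpoint.py: --tp_size ${TP_SIZE} --output_dir ${CKPT_DIR}"
echo "trtllm-build:          --checkpoint_dir ${CKPT_DIR} --output_dir ${ENGINE_DIR}"
echo "run.py:                ${MPI_PREFIX} python run.py --llm_engine_dir ${ENGINE_DIR}"
```

The visual engine build is left out because it takes no tensor-parallel argument in the commands above; the visual encoder is replicated on each GPU.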