TensorRT-LLM/examples/pytorch/quickstart_multimodal.py
rakib-hasan ff3b741045
feat: adding multimodal (only image for now) support in trtllm-bench (#3490)

* feat: adding multimodal (only image for now) support in trtllm-bench
* fix: add  in load_dataset() calls to maintain the v2.19.2 behavior
* re-adding prompt_token_ids and using that for prompt_len
* updating the datasets version in examples as well
* api changes are not needed
* moving datasets requirement and removing a missed api change
* addressing review comments
* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
2025-04-18 07:06:16 +08:00

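"""Quickstart example for running multimodal (image or video) models with the PyTorch workflow.

Illustrative invocations (a sketch, not part of the original file): --modality,
--media, --model_type and --num_frames are defined in add_multimodal_args()
below; --model_dir and --prompt are assumed to be provided by add_llm_args()
in quickstart_advanced.py, matching the attribute names read in main().
<model_dir> is a placeholder for a local or Hugging Face checkpoint path.

  # Image input
  python quickstart_multimodal.py --model_dir <model_dir> --modality image --prompt "Describe the natural environment in the image." --media "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png"

  # Video input
  python quickstart_multimodal.py --model_dir <model_dir> --modality video --num_frames 8 --prompt "Tell me what you see in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4"
"""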

import argparse
import json
import os
from typing import Any, Dict, List

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
                                 default_video_loader)

example_images = [
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
    "Describe the natural environment in the image.",
    "Describe the object and the weather condition in the image.",
    "Describe the traffic condition on the road in the image.",
]
example_videos = [
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
    "Tell me what you see in the video briefly.",
    "Describe the scene in the video briefly.",
]


def prepare_multimodal_inputs(model_dir: str,
                              model_type: str,
                              modality: str,
                              prompts: List[str],
                              media: List[str],
                              image_data_format: str = "pt",
                              num_frames: int = 8) -> List[Dict[str, Any]]:
    # Load the raw media (local paths or URLs) into the generic multimodal
    # input format, then let the model-specific formatter arrange the prompts
    # and media for the target model.
    inputs = []
    if modality == "image":
        inputs = default_image_loader(prompts, media, image_data_format)
    elif modality == "video":
        inputs = default_video_loader(prompts, media, image_data_format,
                                      num_frames)
    else:
        raise ValueError(f"Unsupported modality: {modality}")

    inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)

    return inputs
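
# Illustrative usage (a sketch, not part of the original script): build image
# inputs from the bundled example prompts/URLs above. "<model_dir>" and
# "<model_type>" are placeholders; model_type must be a key of
# INPUT_FORMATTER_MAP (it usually matches the "model_type" field in the
# checkpoint's config.json).
#
#   inputs = prepare_multimodal_inputs("<model_dir>", "<model_type>", "image",
#                                      example_image_prompts, example_images)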


def add_multimodal_args(parser):
    parser.add_argument("--model_type",
                        type=str,
                        choices=INPUT_FORMATTER_MAP.keys(),
                        help="Model type.")
    parser.add_argument("--modality",
                        type=str,
                        choices=["image", "video"],
                        default="image",
                        help="Media type.")
    parser.add_argument("--media",
                        type=str,
                        nargs="+",
                        help="A single or a list of media filepaths / urls.")
    parser.add_argument("--num_frames",
                        type=int,
                        default=8,
                        help="The number of video frames to be sampled.")
    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser = add_multimodal_args(parser)
    args = parser.parse_args()
    args.kv_cache_enable_block_reuse = False  # kv cache reuse does not work for multimodal, force overwrite
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # lower the default kv cache fraction for multimodal
    return args


def main():
    args = parse_arguments()
    llm, sampling_params = setup_llm(args)

    image_format = "pt"  # ["pt", "pil"]

    if args.model_type is not None:
        model_type = args.model_type
    else:
        # Infer the model type from the checkpoint's Hugging Face config.
        model_type = json.load(
            open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
    assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"

    # Fall back to the bundled example prompts/media when none are given on
    # the command line.
    if args.media is None:
        args.media = example_images if args.modality == "image" else example_videos
    if args.prompt is None:
        args.prompt = (example_image_prompts
                       if args.modality == "image" else example_video_prompts)

    inputs = prepare_multimodal_inputs(args.model_dir, model_type,
                                       args.modality, args.prompt, args.media,
                                       image_format, args.num_frames)

    outputs = llm.generate(inputs, sampling_params)

    for i, output in enumerate(outputs):
        prompt = inputs[i]['prompt']
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()