TensorRT-LLMs/examples/pytorch/quickstart_multimodal.py
rakib-hasan d0eb47d33a
[TRTLLM-5053] Refactoring and Unifying the Multimodal input preparation (#4506)
* refactoring the multimodal input prep
* adding out-of-tree override option
* adding exceptional case for llava-next
* fixing typo
* addressing review comments, adding placement option, handling tokenizer variations
* addressing pytest-asyncio behavior change

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
2025-06-03 12:02:07 -07:00

import argparse
import json
import os

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
                                 default_multimodal_input_loader)

example_images = [
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
    "Describe the natural environment in the image.",
    "Describe the object and the weather condition in the image.",
    "Describe the traffic condition on the road in the image.",
]
example_videos = [
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
    "Tell me what you see in the video briefly.",
    "Describe the scene in the video briefly.",
]


def add_multimodal_args(parser):
    parser.add_argument("--model_type",
                        type=str,
                        choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
                        help="Model type.")
    parser.add_argument("--modality",
                        type=str,
                        choices=["image", "video"],
                        default="image",
                        help="Media type.")
    parser.add_argument("--media",
                        type=str,
                        nargs="+",
                        help="A single or a list of media filepaths / urls.")
    parser.add_argument("--num_frames",
                        type=int,
                        default=8,
                        help="The number of video frames to be sampled.")
    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser = add_multimodal_args(parser)
    args = parser.parse_args()

    # kv cache reuse is not supported for multimodal inputs yet, so force-disable it.
    args.disable_kv_cache_reuse = True
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # lower the default kv cache fraction for multimodal
    return args


def main():
    args = parse_arguments()

    # Fall back to the example prompts and media when none are provided.
    if args.prompt is None:
        args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
    if args.media is None:
        args.media = example_images if args.modality == "image" else example_videos

    llm, sampling_params = setup_llm(args)

    image_format = "pt"  # image data format passed to the loader: "pt" (torch tensors) or "pil"

    # Use the explicit --model_type if given, otherwise read it from the HF config.
    if args.model_type is not None:
        model_type = args.model_type
    else:
        with open(os.path.join(llm._hf_model_dir, 'config.json')) as f:
            model_type = json.load(f)['model_type']
    assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
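
    # default_multimodal_input_loader builds one request per prompt: it pairs each
    # prompt with its media item, loads/decodes the media (sampling `num_frames`
    # frames for videos), and applies the prompt format expected by `model_type`.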
    inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
                                             model_dir=llm._hf_model_dir,
                                             model_type=model_type,
                                             modality=args.modality,
                                             prompts=args.prompt,
                                             media=args.media,
                                             image_data_format=image_format,
                                             num_frames=args.num_frames)

    outputs = llm.generate(inputs, sampling_params)

    for i, output in enumerate(outputs):
        prompt = inputs[i]['prompt']
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
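
A minimal example invocation (a sketch: it assumes add_llm_args in quickstart_advanced.py exposes --model_dir and --prompt as in the advanced quickstart, and the checkpoint name is only illustrative):

    python3 quickstart_multimodal.py --model_dir Qwen/Qwen2-VL-7B-Instruct --modality image

Leaving out --prompt and --media makes the script fall back to the example prompts and media defined at the top of the file.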