TensorRT-LLM/examples/pytorch/quickstart_multimodal.py
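"""Quickstart example for multimodal (image / video) inference with the PyTorch workflow.

Reuses the base LLM arguments from quickstart_advanced.py (add_llm_args / setup_llm)
and adds multimodal-specific options (--model_type, --modality, --media, --num_frames).
When no prompts or media are given, the example prompts and images/videos defined
below are used as defaults.

Illustrative invocation (the base flags such as --model_dir and --prompt are assumed
to come from quickstart_advanced.add_llm_args):

    python quickstart_multimodal.py --model_dir <hf-model-or-local-path> --modality image
"""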

import argparse
import json
import os
from typing import Any, Dict, List

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
                                 default_video_loader)

example_images = [
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
"Describe the natural environment in the image.",
"Describe the object and the weather condition in the image.",
"Describe the traffic condition on the road in the image.",
]
example_videos = [
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
"Tell me what you see in the video briefly.",
"Describe the scene in the video briefly.",
]


def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
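    """Build formatted multimodal inputs for the given model type.

    Loads the media with the default image/video loader, then applies the
    model-specific input formatter from INPUT_FORMATTER_MAP to pair each
    prompt with its loaded media in the format the model expects.
    """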
inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
else:
raise ValueError(f"Unsupported modality: {modality}")
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
    return inputs


def add_multimodal_args(parser):
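    """Add the multimodal-specific CLI arguments on top of the base LLM arguments."""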
parser.add_argument("--model_type",
type=str,
choices=INPUT_FORMATTER_MAP.keys(),
help="Model type.")
parser.add_argument("--modality",
type=str,
choices=["image", "video"],
default="image",
help="Media type.")
parser.add_argument("--media",
type=str,
nargs="+",
help="A single or a list of media filepaths / urls.")
parser.add_argument("--num_frames",
type=int,
default=8,
help="The number of video frames to be sampled.")
    return parser


def parse_arguments():
parser = argparse.ArgumentParser(
description="Multimodal models with the PyTorch workflow.")
parser = add_llm_args(parser)
parser = add_multimodal_args(parser)
args = parser.parse_args()
    # KV cache block reuse is not supported for multimodal models, so force-disable it.
    args.kv_cache_enable_block_reuse = False
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # use a lower default KV cache memory fraction for multimodal
    return args


def main():
args = parse_arguments()
    # Fall back to the example prompts and media (images or videos) when none are provided.
if args.prompt is None:
args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
if args.media is None:
args.media = example_images if args.modality == "image" else example_videos
llm, sampling_params = setup_llm(args)
image_format = "pt" # ["pt", "pil"]
if args.model_type is not None:
model_type = args.model_type
else:
        config_path = os.path.join(llm._hf_model_dir, 'config.json')
        with open(config_path) as f:
            model_type = json.load(f)['model_type']
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"
inputs = prepare_multimodal_inputs(args.model_dir, model_type,
args.modality, args.prompt, args.media,
image_format, args.num_frames)
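    # Run batched generation over all multimodal inputs with the configured sampling params.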
outputs = llm.generate(inputs, sampling_params)
for i, output in enumerate(outputs):
prompt = inputs[i]['prompt']
generated_text = output.outputs[0].text
print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
main()