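"""Quickstart example for running multimodal (image / video) models with the
PyTorch workflow.

It reuses add_llm_args()/setup_llm() from quickstart_advanced.py for LLM
construction and sampling, and adds media loading plus model-specific input
formatting via INPUT_FORMATTER_MAP.

Illustrative invocation (the script name and the shared --model_dir / --prompt
flags are assumptions here; they come from add_llm_args() in
quickstart_advanced.py, so check that script for the exact spelling):

    python quickstart_multimodal.py \
        --model_dir <model checkpoint or HF id> \
        --modality image \
        --prompt "Describe the natural environment in the image." \
        --media https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png
"""
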
import argparse
import json
import os
from typing import Any, Dict, List

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
                                 default_video_loader)

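# Sample media URLs and prompts. main() reads --media/--prompt from the CLI, so
# these lists are kept as ready-made example inputs rather than used directly below.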
example_images = [
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
    "Describe the natural environment in the image.",
    "Describe the object and the weather condition in the image.",
    "Describe the traffic condition on the road in the image.",
]
example_videos = [
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
    "Tell me what you see in the video briefly.",
    "Describe the scene in the video briefly.",
]


def prepare_multimodal_inputs(model_dir: str,
                              model_type: str,
                              modality: str,
                              prompts: List[str],
                              media: List[str],
                              image_data_format: str = "pt",
                              num_frames: int = 8) -> List[Dict[str, Any]]:
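    """Load media for the requested modality and pair it with the prompts.

    Media are fetched with the default image/video loaders from
    tensorrt_llm.inputs, then converted by the formatter selected from
    INPUT_FORMATTER_MAP into the list of input dicts passed to llm.generate().
    """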
    inputs = []
    if modality == "image":
        inputs = default_image_loader(prompts, media, image_data_format)
    elif modality == "video":
        inputs = default_video_loader(prompts, media, image_data_format,
                                      num_frames)
    else:
        raise ValueError(f"Unsupported modality: {modality}")

    inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)

    return inputs


def add_multimodal_args(parser):
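    """Add the multimodal-specific CLI arguments to the shared parser."""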
parser.add_argument("--model_type",
|
|
type=str,
|
|
choices=INPUT_FORMATTER_MAP.keys(),
|
|
help="Model type.")
|
|
parser.add_argument("--modality",
|
|
type=str,
|
|
choices=["image", "video"],
|
|
default="image",
|
|
help="Media type.")
|
|
parser.add_argument("--media",
|
|
type=str,
|
|
nargs="+",
|
|
help="A single or a list of media filepaths / urls.")
|
|
parser.add_argument("--num_frames",
|
|
type=int,
|
|
default=8,
|
|
help="The number of video frames to be sampled.")
|
|
return parser
|
|
|
|
|
|
def parse_arguments():
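    """Parse the combined LLM + multimodal arguments and apply
    multimodal-specific KV-cache overrides."""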
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser = add_multimodal_args(parser)
    args = parser.parse_args()

    args.kv_cache_enable_block_reuse = False  # KV cache reuse does not work for multimodal; force it off
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # lower the default KV cache fraction for multimodal

    return args


def main():
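    """Build the LLM, prepare the multimodal inputs, and print generated text."""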
    args = parse_arguments()

    llm, sampling_params = setup_llm(args)

    image_format = "pt"  # supported formats: "pt" (tensors) or "pil" (PIL images)
    if args.model_type is not None:
        model_type = args.model_type
    else:
        # Infer the model type from the Hugging Face checkpoint's config.json.
        with open(os.path.join(llm._hf_model_dir, 'config.json')) as f:
            model_type = json.load(f)['model_type']
    assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"

    inputs = prepare_multimodal_inputs(args.model_dir, model_type,
                                       args.modality, args.prompt, args.media,
                                       image_format, args.num_frames)

    outputs = llm.generate(inputs, sampling_params)

    for i, output in enumerate(outputs):
        prompt = inputs[i]['prompt']
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()