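# Example invocations (illustrative sketch only; the script name and the
# --model_dir / --prompt flags are assumed to come from add_llm_args() in
# quickstart_advanced.py and may differ in your checkout):
#   python quickstart_multimodal.py --model_dir <hf_model_or_checkpoint_dir> \
#       --modality image --media <image_path_or_url> --prompt "Describe the image."
#   python quickstart_multimodal.py --model_dir <hf_model_or_checkpoint_dir> \
#       --modality video --media <video_path_or_url> --num_frames 8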
import argparse
import json
import os

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
                                 default_multimodal_input_loader)

example_images = [
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
    "Describe the natural environment in the image.",
    "Describe the object and the weather condition in the image.",
    "Describe the traffic condition on the road in the image.",
]
example_videos = [
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
    "Tell me what you see in the video briefly.",
    "Describe the scene in the video briefly.",
]


def add_multimodal_args(parser):
    parser.add_argument("--model_type",
                        type=str,
                        choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
                        help="Model type.")
    parser.add_argument("--modality",
                        type=str,
                        choices=["image", "video"],
                        default="image",
                        help="Media type.")
    parser.add_argument("--media",
                        type=str,
                        nargs="+",
                        help="A single or a list of media filepaths / urls.")
    parser.add_argument("--num_frames",
                        type=int,
                        default=8,
                        help="The number of video frames to be sampled.")
    parser.add_argument("--image_format",
                        type=str,
                        choices=["pt", "pil"],
                        default="pt",
                        help="The format of the image.")
    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser = add_multimodal_args(parser)
    args = parser.parse_args()

    args.disable_kv_cache_reuse = True  # kv cache reuse does not work for multimodal, force overwrite
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # lower the default kv cache fraction for multimodal

    return args


def main():
    args = parse_arguments()
    # set prompts and media to example prompts and images if they are not provided
    if args.prompt is None:
        args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
    if args.media is None:
        args.media = example_images if args.modality == "image" else example_videos

    llm, sampling_params = setup_llm(args)

    image_format = args.image_format
    if args.model_type is not None:
        model_type = args.model_type
    else:
        model_type = json.load(
            open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
    assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"

    device = "cuda"
    inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
                                             model_dir=llm._hf_model_dir,
                                             model_type=model_type,
                                             modality=args.modality,
                                             prompts=args.prompt,
                                             media=args.media,
                                             image_data_format=image_format,
                                             num_frames=args.num_frames,
                                             device=device)

    outputs = llm.generate(inputs, sampling_params)

    for i, output in enumerate(outputs):
        prompt = args.prompt[i]
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()