from typing import List, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoProcessor


def load_image(image: str,
               format: str = "pt",
               device: str = "cuda") -> Union[Image.Image, torch.Tensor]:
    assert format in ["pt", "pil"], "format must be either 'pt' (PyTorch) or 'pil' (PIL)"

    if image.startswith("http://") or image.startswith("https://"):
        image = Image.open(requests.get(image, stream=True, timeout=10).raw)
    else:
        image = Image.open(image)
    image = image.convert("RGB")
    if format == "pt":
        return ToTensor()(image).to(device=device)
    else:
        return image


def load_video(
        video: str,
        num_frames: int = 10,
        format: str = "pt",
        device: str = "cuda") -> Union[List[Image.Image], List[torch.Tensor]]:
    assert format in ["pt", "pil"], "format must be either 'pt' (PyTorch) or 'pil' (PIL)"

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video)

    if not vidcap.isOpened():
        raise ValueError(
            f"Video '{video}' could not be opened. Make sure opencv is installed with video support."
        )

    # Find the last readable frame, since the reported frame count might not be accurate
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video}' has no frames.")

    # Extract frames uniformly
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = Image.fromarray(frame)

    return [
        ToTensor()(frames[index]).to(
            device=device) if format == "pt" else frames[index]
        for index in indices if index in frames
    ]

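
# Usage sketch (illustrative; the URL and the file path below are placeholders):
#
#     image_tensor = load_image("https://example.com/cat.png", format="pt", device="cuda")
#     frames = load_video("/path/to/clip.mp4", num_frames=8, format="pil")
#
# With format="pt", load_image returns a (3, H, W) float tensor scaled to [0, 1] on the
# requested device; with format="pil", it returns an RGB PIL.Image. load_video returns a
# list with one entry per sampled frame in the chosen format.
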
"""
|
|
VLM input preparation.
|
|
"""
|
|
|
|
|
|
def format_vila_input(model_dir, inputs):
    """
    Formats the input for the VILA/NVILA VL models.

    Arguments:
        model_dir: The directory of the model, used to load any preprocessor.
        inputs: The list of inputs to format.

    Returns:
        The list of inputs where each "prompt" is rewritten into a TextPrompt that
        combines the text prompt with the multimodal data.
    """

    def add_media_token(prompt, multi_modal_data):
        mm_tokens = ""
        if "image" in multi_modal_data:
            for _ in multi_modal_data["image"]:
                mm_tokens += "<image>"
        elif "video" in multi_modal_data:
            for _ in multi_modal_data["video"]:
                mm_tokens += "<vila/video>"
        return mm_tokens + prompt

    for input in inputs:
        input["prompt"] = add_media_token(input["prompt"],
                                          input["multi_modal_data"])
    return inputs

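
# For example, given multi_modal_data with two images and the prompt
# "Describe the images.", add_media_token returns "<image><image>Describe the images.";
# one media token is prepended per image or video item.
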
def format_llava_next_input(model_dir, inputs):
    """
    Formats the input for the LLaVA-NeXT VL model.

    Arguments:
        model_dir: The directory of the model, used to load any preprocessor.
        inputs: The list of inputs to format.

    Returns:
        The list of inputs where each "prompt" is rewritten into a TextPrompt that
        combines the text prompt with the multimodal data.
    """
    processor = AutoProcessor.from_pretrained(model_dir)

    # Single-image inference chat template. For the multi-image template,
    # see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
    def apply_template(prompt, multimodal_data):
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image"
                    },
                ],
            },
        ]
        return processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
        )

    for input in inputs:
        input["prompt"] = apply_template(input["prompt"],
                                         input["multi_modal_data"])
    return inputs


def format_qwen2_vl_input(model_dir, inputs):
    """
    Formats the input for the Qwen2/Qwen2.5 VL models.

    Arguments:
        model_dir: The directory of the model, used to load any preprocessor.
        inputs: The list of inputs to format.

    Returns:
        The list of inputs where each "prompt" is rewritten into a TextPrompt that
        combines the text prompt with the multimodal data.
    """
    processor = AutoProcessor.from_pretrained(model_dir)

    def apply_template(prompt, multimodal_data):
        # One placeholder entry per media item, followed by the text prompt.
        content = [{"type": media_type}
                   for media_type, items in multimodal_data.items()
                   for _ in items]
        content += [{"type": "text", "text": prompt}]

        conversation = [{"role": "user", "content": content}]
        return processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    for input in inputs:
        input["prompt"] = apply_template(input["prompt"],
                                         input["multi_modal_data"])
    return inputs

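
# Illustrative sketch of what apply_template builds for one image and the prompt
# "Describe the image." (the rendered string depends on the model's chat template,
# which inserts the model-specific vision placeholder tokens):
#
#     [{"role": "user",
#       "content": [{"type": "image"},
#                   {"type": "text", "text": "Describe the image."}]}]
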
def default_image_loader(prompts, images, image_data_format="pt"):
    if len(images) > len(prompts) and len(prompts) == 1:
        # 1 prompt + N media
        images = [images]
    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {
            "image": [
                load_image(i, format=image_data_format, device="cuda")
                for i in image
            ] if isinstance(image, list) else
            [load_image(image, format=image_data_format, device="cuda")]
        }
    } for prompt, image in zip(prompts, images)]
    return inputs


def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
    if len(videos) > len(prompts) and len(prompts) == 1:
        # 1 prompt + N media
        videos = [videos]
    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {
            "video": [
                load_video(
                    i, num_frames, format=image_data_format, device="cuda")
                for i in video
            ] if isinstance(video, list) else [
                load_video(
                    video, num_frames, format=image_data_format, device="cuda")
            ]
        }
    } for prompt, video in zip(prompts, videos)]
    return inputs


INPUT_FORMATTER_MAP = {
    "llava_llama": format_vila_input,
    "llava_next": format_llava_next_input,
    "qwen2_vl": format_qwen2_vl_input,
    "qwen2_5_vl": format_qwen2_vl_input,
}
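
# Minimal usage sketch (illustrative only; the model directory, model type, and media
# path below are placeholders, and running it requires a runtime that consumes these
# formatted prompts):
#
#     prompts = ["Describe the image."]
#     media = ["/path/to/image.png"]
#     inputs = default_image_loader(prompts, media, image_data_format="pt")
#     formatter = INPUT_FORMATTER_MAP["qwen2_vl"]
#     inputs = formatter("/path/to/Qwen2-VL-7B-Instruct", inputs)
#     # Each entry now carries a chat-template-formatted "prompt" plus the loaded
#     # "multi_modal_data", ready to be handed to the model runner.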