# TensorRT-LLM: tensorrt_llm/inputs/utils.py

import base64
import tempfile
from io import BytesIO
from pathlib import Path
from typing import List, Union
from urllib.parse import urlparse

import aiohttp
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoProcessor


def _load_and_convert_image(image):
    image = Image.open(image)
    image.load()
    return image.convert("RGB")


def load_image(image: str,
               format: str = "pt",
               device: str = "cuda") -> Union[Image.Image, torch.Tensor]:
    assert format in ["pt", "pil"], "format must be either PyTorch or PIL"

    parsed_url = urlparse(image)

    if parsed_url.scheme in ["http", "https"]:
        # Remote image: stream the response body straight into PIL.
        image = requests.get(image, stream=True, timeout=10).raw
        image = _load_and_convert_image(image)
    elif parsed_url.scheme == "data":
        # Inline image passed as a data URL, e.g. "data:image/png;base64,...".
        data_spec, data = parsed_url.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        content = base64.b64decode(data)
        image = _load_and_convert_image(BytesIO(content))
    else:
        # Local file path.
        image = _load_and_convert_image(image)

    if format == "pt":
        return ToTensor()(image).to(device=device)
    else:
        return image
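
# Illustrative usage (not part of the original module; the paths and URL below
# are placeholders): `load_image` accepts a local file path, an http(s) URL,
# or a base64 data URL, and returns either a PIL image or a CHW float tensor.
#
#   pil_img = load_image("/path/to/cat.png", format="pil")
#   gpu_tensor = load_image("https://example.com/cat.png", format="pt",
#                           device="cuda")
#   inline_img = load_image("data:image/png;base64,<BASE64_BYTES>", format="pil")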


async def async_load_image(
        image: str,
        format: str = "pt",
        device: str = "cuda") -> Union[Image.Image, torch.Tensor]:
    assert format in ["pt", "pil"], "format must be either PyTorch or PIL"

    parsed_url = urlparse(image)

    if parsed_url.scheme in ["http", "https"]:
        # Remote image: fetch asynchronously, then decode from memory.
        async with aiohttp.ClientSession() as session:
            async with session.get(image) as response:
                content = await response.read()
                image = _load_and_convert_image(BytesIO(content))
    elif parsed_url.scheme == "data":
        data_spec, data = parsed_url.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        content = base64.b64decode(data)
        image = _load_and_convert_image(BytesIO(content))
    else:
        # Local file path.
        image = _load_and_convert_image(Path(parsed_url.path))

    if format == "pt":
        return ToTensor()(image).to(device=device)
    else:
        return image
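
# Usage sketch (placeholder URL): the async variant must be awaited inside an
# event loop, e.g.
#
#   image = await async_load_image("https://example.com/cat.png", format="pil")
#
# or driven from synchronous code with `asyncio.run(async_load_image(...))`.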


def load_video(
        video: str,
        num_frames: int = 10,
        format: str = "pt",
        device: str = "cuda") -> Union[List[Image.Image], List[torch.Tensor]]:
    assert format in ["pt", "pil"], "format must be either PyTorch or PIL"

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video)

    if not vidcap.isOpened():
        raise ValueError(
            f"Video '{video}' could not be opened. Make sure opencv is installed with video support."
        )

    # Find the last frame as frame count might not be accurate
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video}' has no frames.")

    # Extract frames uniformly
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = Image.fromarray(frame)

    return [
        ToTensor()(frames[index]).to(
            device=device) if format == "pt" else frames[index]
        for index in indices if index in frames
    ]
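
# Usage sketch (placeholder path): sample 8 frames uniformly from a local
# video and keep them as PIL images. Indices that fail to decode are skipped,
# so fewer than `num_frames` frames may be returned.
#
#   frames = load_video("/path/to/clip.mp4", num_frames=8, format="pil")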


async def async_load_video(
        video: str,
        num_frames: int = 10,
        format: str = "pt",
        device: str = "cuda") -> Union[List[Image.Image], List[torch.Tensor]]:
    assert format in ["pt", "pil"], "format must be either PyTorch or PIL"

    parsed_url = urlparse(video)

    if parsed_url.scheme in ["http", "https"]:
        # Remote video: download to a temporary file so OpenCV can open it.
        async with aiohttp.ClientSession() as session:
            async with session.get(video) as response:
                with tempfile.NamedTemporaryFile(delete=False,
                                                 suffix='.mp4') as tmp:
                    tmp.write(await response.content.read())
                    video_path = tmp.name
    # TODO: add case for video encoded in base64
    else:
        video_path = video

    return load_video(video_path, num_frames, format, device)


# Copied from https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_client_for_multimodal.py#L38
def encode_base64_content_from_url(content_url: str) -> str:
    """Encode content retrieved from a remote URL to base64 format."""

    with requests.get(content_url, timeout=10) as response:
        response.raise_for_status()
        result = base64.b64encode(response.content).decode('utf-8')

    return result
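
# Usage sketch (placeholder URL): the base64 string returned here can be
# embedded in a data URL and fed back to `load_image` / `async_load_image`,
# which handle the "data:<media>;base64,<payload>" scheme above.
#
#   b64 = encode_base64_content_from_url("https://example.com/cat.png")
#   image = load_image(f"data:image/png;base64,{b64}", format="pil")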
"""
VLM input preparation.
"""
def format_vila_input(model_dir, inputs):
"""
This function formats the input for the VILA/NVILA VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
def add_media_token(prompt, multi_modal_data):
mm_tokens = ""
if "image" in multi_modal_data:
for _ in multi_modal_data["image"]:
mm_tokens += "<image>"
elif "video" in multi_modal_data:
for _ in multi_modal_data["video"]:
mm_tokens += "<vila/video>"
return mm_tokens + prompt
for input in inputs:
input["prompt"] = add_media_token(input["prompt"],
input["multi_modal_data"])
return inputs
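
# Example of the transformation performed above: with one image attached, the
# prompt "Describe the image." becomes "<image>Describe the image."; videos
# are prefixed with one "<vila/video>" token per clip instead.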


def format_generic_input(model_dir, inputs):
    """
    This function formats the input for the LLaVA-NeXT VL model.

    Arguments:
        model_dir: The directory of the model to load any preprocessor.
        inputs: The list of inputs to format.

    Returns:
        A list of dictionaries where "prompt" data is modified to a TextPrompt
        that combines the text prompt and multimodal data.
    """
    processor = AutoProcessor.from_pretrained(model_dir)

    # Single-image inference chat template. For multi-image template,
    # see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
    def apply_template(prompt, multimodal_data):
        conversation = [
            {
                "role":
                "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    *[{
                        "type": "image"
                    } for _ in multimodal_data["image"]],
                ],
            },
        ]
        return processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
        )

    for input in inputs:
        input["prompt"] = apply_template(input["prompt"],
                                         input["multi_modal_data"])
    return inputs


def format_qwen2_vl_input(model_dir, inputs):
    """
    This function formats the input for the Qwen2/Qwen2.5 VL model.

    Arguments:
        model_dir: The directory of the model to load any preprocessor.
        inputs: The list of inputs to format.

    Returns:
        A list of dictionaries where "prompt" data is modified to a TextPrompt
        that combines the text prompt and multimodal data.
    """
    processor = AutoProcessor.from_pretrained(model_dir)

    def apply_template(prompt, multimodal_data):
        # One {"type": <media_type>} entry per media item, followed by the
        # text prompt.
        content = [{
            "type": media_type
        } for media_type, items in multimodal_data.items()
            for _ in items] + [{
                "type": "text",
                "text": prompt
            }]
        conversation = [{"role": "user", "content": content}]
        return processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    for input in inputs:
        input["prompt"] = apply_template(input["prompt"],
                                         input["multi_modal_data"])
    return inputs
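
# For a request with one image and the prompt "What is in the image?", the
# conversation built above is
#
#   [{"role": "user",
#     "content": [{"type": "image"},
#                 {"type": "text", "text": "What is in the image?"}]}]
#
# and `apply_chat_template` then renders it with the processor's own
# model-specific chat markup.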


def default_image_loader(prompts: List[str],
                         images: Union[List[List[str]], List[str]],
                         image_data_format: str = "pt"):
    if len(images) > len(prompts) and len(prompts) == 1:
        # 1 prompt + N media
        images = [images]
    assert len(images) == len(prompts)

    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {
            "image": [
                load_image(i, format=image_data_format, device="cuda")
                for i in image
            ] if isinstance(image, list) else
            [load_image(image, format=image_data_format, device="cuda")]
        }
    } for prompt, image in zip(prompts, images)]
    return inputs
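
# Usage sketch (placeholder paths): one prompt paired with one image each, or
# a single prompt paired with several images.
#
#   inputs = default_image_loader(
#       prompts=["Describe the image."],
#       images=["/path/to/cat.png", "/path/to/dog.png"],
#       image_data_format="pil")
#
# Each returned dict has the form
# {"prompt": ..., "multi_modal_data": {"image": [<PIL.Image or torch.Tensor>, ...]}}.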


def default_video_loader(prompts: List[str],
                         videos: Union[List[List[str]], List[str]],
                         image_data_format: str = "pt",
                         num_frames: int = 8):
    if len(videos) > len(prompts) and len(prompts) == 1:
        # 1 prompt + N media
        videos = [videos]
    assert len(videos) == len(prompts)

    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {
            "video": [
                load_video(
                    i, num_frames, format=image_data_format, device="cuda")
                for i in video
            ] if isinstance(video, list) else [
                load_video(
                    video, num_frames, format=image_data_format, device="cuda")
            ]
        }
    } for prompt, video in zip(prompts, videos)]
    return inputs


INPUT_FORMATTER_MAP = {
    "llava_llama": format_vila_input,
    "llava_next": format_generic_input,
    "qwen2_vl": format_qwen2_vl_input,
    "qwen2_5_vl": format_qwen2_vl_input,
    "llama4": format_generic_input,
}
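
# Usage sketch (hypothetical model directory): the map dispatches on the model
# type, so callers can build model-ready prompts without knowing the specific
# chat template, e.g.
#
#   inputs = default_image_loader(["Describe the image."],
#                                 ["/path/to/cat.png"])
#   inputs = INPUT_FORMATTER_MAP["qwen2_vl"](
#       "/path/to/Qwen2-VL-7B-Instruct", inputs)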