# TensorRT-LLMs/tensorrt_llm/inputs/utils.py
from typing import List, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor

def load_image(image: str,
               format: str = "pt",
               device: str = "cuda") -> Union[Image.Image, torch.Tensor]:
    """Load an image from a URL or local path as an RGB PIL image or float tensor."""
    assert format in ["pt", "pil"], "format must be either 'pt' or 'pil'"

    if image.startswith("http://") or image.startswith("https://"):
        image = Image.open(requests.get(image, stream=True, timeout=10).raw)
    else:
        image = Image.open(image)

    image = image.convert("RGB")
    if format == "pt":
        return ToTensor()(image).to(device=device)
    else:
        return image
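
# A minimal usage sketch (illustration only, not part of the original module);
# the URL and file path below are placeholders:
#
#   tensor = load_image("https://example.com/cat.png")       # (3, H, W) float tensor on CUDA
#   pil_img = load_image("/path/to/cat.png", format="pil")   # PIL.Image.Image in RGB mode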

def load_video(
        video: str,
        num_frames: int = 10,
        format: str = "pt",
        device: str = "cuda") -> Union[List[Image.Image], List[torch.Tensor]]:
    """Sample up to `num_frames` uniformly spaced frames from a video file."""
    assert format in ["pt", "pil"], "format must be either 'pt' or 'pil'"

    # Load video frames from a video file.
    vidcap = cv2.VideoCapture(video)
    if not vidcap.isOpened():
        raise ValueError(
            f"Video '{video}' could not be opened. Make sure opencv is installed with video support."
        )

    # Find the last decodable frame, as the reported frame count may be inaccurate.
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video}' has no frames.")

    # Extract frames at uniformly spaced indices, skipping duplicates.
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            continue
        # OpenCV decodes frames as BGR; convert to RGB for PIL/torch.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = Image.fromarray(frame)

    return [
        ToTensor()(frames[index]).to(
            device=device) if format == "pt" else frames[index]
        for index in indices if index in frames
    ]
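
# A minimal usage sketch (illustration only, not part of the original module);
# the video path is a placeholder. The returned list may hold fewer than
# `num_frames` entries if some sampled indices fail to decode:
#
#   frames = load_video("/path/to/clip.mp4", num_frames=8, format="pil")
#   tensors = load_video("/path/to/clip.mp4", num_frames=8)  # list of CUDA tensors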