mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
70 lines
2.1 KiB
Python
70 lines
2.1 KiB
Python
from typing import List, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision.transforms import ToTensor
|
|
|
|
|
|
def load_image(image: str,
               format: str = "pt",
               device: str = "cuda") -> Union[Image.Image, torch.Tensor]:
    """Load an image from a local path or an HTTP(S) URL as RGB.

    Args:
        image: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        format: ``"pt"`` to return a tensor, ``"pil"`` to return a PIL image.
        device: Device the tensor is moved to; only used when
            ``format == "pt"``.

    Returns:
        The image converted to RGB: a PIL image for ``"pil"``, otherwise the
        result of torchvision's ``ToTensor`` moved to ``device``.

    Raises:
        AssertionError: If ``format`` is neither ``"pt"`` nor ``"pil"``.
        requests.HTTPError: If the URL answers with an error status code.
    """
    assert format in ["pt", "pil"], "format must be either Pytorch or PIL"

    if image.startswith(("http://", "https://")):
        from io import BytesIO

        # Download the whole body inside a context manager so the connection
        # is always released, and fail on HTTP errors instead of handing an
        # error page to PIL. Reading `.content` (rather than `.raw`) also lets
        # requests undo any gzip/deflate transfer encoding.
        with requests.get(image, timeout=10) as response:
            response.raise_for_status()
            source = BytesIO(response.content)
    else:
        source = image

    # `convert` returns an independent in-memory copy, so the underlying file
    # handle can be closed as soon as the conversion is done.
    with Image.open(source) as opened:
        rgb_image = opened.convert("RGB")

    if format == "pt":
        return ToTensor()(rgb_image).to(device=device)
    return rgb_image
|
|
|
|
|
|
def load_video(
        video: str,
        num_frames: int = 10,
        format: str = "pt",
        device: str = "cuda") -> Union[List[Image.Image], List[torch.Tensor]]:
    """Sample frames at evenly spaced positions from a video.

    Args:
        video: Path (or other source string) passed to ``cv2.VideoCapture``.
        num_frames: Number of sample positions spread between the first and
            the last readable frame.
        format: ``"pt"`` to return tensors, ``"pil"`` to return PIL images.
        device: Device the tensors are moved to; only used when
            ``format == "pt"``.

    Returns:
        The sampled RGB frames in temporal order. Positions whose frame could
        not be read are skipped, so the list may be shorter than
        ``num_frames``. When ``num_frames`` exceeds the number of frames,
        positions repeat and the same frame appears more than once.

    Raises:
        AssertionError: If ``format`` is neither ``"pt"`` nor ``"pil"``.
        ValueError: If the video cannot be opened or has no readable frame.
    """
    assert format in ["pt", "pil"], "format must be either Pytorch or PIL"

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video)
    try:
        if not vidcap.isOpened():
            raise ValueError(
                f"Video '{video}' could not be opened. Make sure opencv is installed with video support."
            )

        # Find the last frame as frame count might not be accurate
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        while frame_count > 0:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            if vidcap.grab():
                break
            frame_count -= 1
        else:
            # Loop ended without a successful grab: nothing is decodable.
            raise ValueError(f"Video '{video}' has no frames.")

        # Extract frames uniformly
        indices = np.round(np.linspace(0, frame_count - 1,
                                       num_frames)).astype(int)
        frames = {}
        for index in indices:
            # Rounding can yield the same position twice; decode it only once.
            if index in frames:
                continue
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                continue
            # OpenCV decodes to BGR; PIL expects RGB channel order.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames[index] = Image.fromarray(frame)
    finally:
        # Always free the decoder/file handle, including on the error paths.
        vidcap.release()

    if format == "pil":
        return [frames[index] for index in indices if index in frames]

    to_tensor = ToTensor()
    return [
        to_tensor(frames[index]).to(device=device) for index in indices
        if index in frames
    ]
|