from io import BytesIO
from typing import List, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor

# Output formats accepted by ``load_image`` / ``load_video``.
_FORMATS = ("pt", "pil")


def _check_format(format: str) -> None:
    """Raise ``ValueError`` if *format* is not one of the supported output formats."""
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if format not in _FORMATS:
        raise ValueError("format must be either Pytorch or PIL")


def load_image(
    image: str, format: str = "pt", device: str = "cuda"
) -> Union[Image.Image, torch.Tensor]:
    """Load an image from a local path or an HTTP(S) URL and convert it to RGB.

    Args:
        image: Local file path, or a URL starting with ``http://`` or ``https://``.
        format: ``"pt"`` to return a tensor, ``"pil"`` to return a PIL image.
        device: Device the tensor is moved to when ``format == "pt"``.

    Returns:
        When ``format == "pt"``, the result of ``torchvision.transforms.ToTensor``
        on the RGB image, moved to *device*; otherwise the RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If *format* is not ``"pt"`` or ``"pil"``.
        requests.HTTPError: If a URL answers with an error status code.
    """
    _check_format(format)

    if image.startswith(("http://", "https://")):
        # Read via ``content`` (not ``raw``) so Content-Encoding such as gzip is
        # decoded, and use the response as a context manager so the connection
        # is released. Fail loudly on HTTP errors instead of handing an error
        # page to PIL.
        with requests.get(image, timeout=10) as response:
            response.raise_for_status()
            source = BytesIO(response.content)
    else:
        source = image

    # ``convert`` always returns a new image (a copy when the mode already
    # matches), so the source file can be closed once it has been converted.
    with Image.open(source) as opened:
        rgb_image = opened.convert("RGB")

    if format == "pt":
        return ToTensor()(rgb_image).to(device=device)
    return rgb_image


def load_video(
    video: str, num_frames: int = 10, format: str = "pt", device: str = "cuda"
) -> Union[List[Image.Image], List[torch.Tensor]]:
    """Sample up to *num_frames* evenly spaced RGB frames from a video file.

    Frame indices are spread uniformly (rounded ``np.linspace``) between the
    first frame and the last frame that can actually be grabbed. Duplicate
    indices are decoded once but appear in the output as many times as they
    were sampled. Frames that fail to decode are skipped, so the result may be
    shorter than *num_frames*.

    Args:
        video: Path (or any source string accepted by ``cv2.VideoCapture``).
        num_frames: Number of indices to sample.
        format: ``"pt"`` for a list of tensors, ``"pil"`` for a list of PIL images.
        device: Device each tensor is moved to when ``format == "pt"``.

    Returns:
        A list of per-frame tensors (``ToTensor`` output moved to *device*) or
        RGB PIL images, in sampling order.

    Raises:
        ValueError: If *format* is invalid, the video cannot be opened, or no
            frame can be grabbed.
    """
    _check_format(format)

    vidcap = cv2.VideoCapture(video)
    try:
        if not vidcap.isOpened():
            raise ValueError(
                f"Video '{video}' could not be opened. Make sure opencv is installed with video support."
            )

        # CAP_PROP_FRAME_COUNT may be inaccurate, so walk back from it until a
        # frame can actually be grabbed.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        while frame_count > 0:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            if vidcap.grab():
                break
            frame_count -= 1
        else:
            raise ValueError(f"Video '{video}' has no frames.")

        # Extract frames uniformly; each distinct index is decoded once.
        indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
        frames = {}
        for index in indices:
            if index in frames:
                continue
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                continue
            frames[index] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, including on the error paths above.
        vidcap.release()

    if format == "pt":
        to_tensor = ToTensor()
        return [
            to_tensor(frames[index]).to(device=device)
            for index in indices
            if index in frames
        ]
    return [frames[index] for index in indices if index in frames]