mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-17 08:15:10 +08:00
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com> Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com> Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com> Co-authored-by: Freddy Qi <junq@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
427 lines
15 KiB
Python
427 lines
15 KiB
Python
#!/usr/bin/env python
|
|
"""Media Storage for generated images and videos.
|
|
|
|
This module provides storage handlers for persisting generated media assets
|
|
(videos, images) and their associated metadata.
|
|
"""
|
|
|
|
import os
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from tensorrt_llm.logger import logger
|
|
|
|
|
|
class MediaStorage:
    """Handler for storing images and videos in various formats.

    Public entry points take media as ``torch.Tensor`` objects and either
    persist them to disk (``save_image`` / ``save_video``) or serialize them
    to raw bytes (``convert_image_to_bytes`` / ``convert_video_to_bytes``).
    """

    @staticmethod
    def save_image(
        image: Any, output_path: str, format: Optional[str] = None, quality: int = 95
    ) -> str:
        """Save image to file.

        Args:
            image: torch.Tensor (H, W, C) uint8
            output_path: Path to save the image
            format: Image format (png, jpg, webp). If None, infer from extension
            quality: Quality for lossy formats (1-100, higher is better)

        Returns:
            Path where the image was saved (may differ from ``output_path``
            when an unknown extension forces a fallback to ``.png``)
        """
        # Ensure output directory exists
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Convert to PIL Image if needed
        pil_image = MediaStorage._to_pil_image(image)

        # Determine format from the file extension when not given explicitly
        if format is None:
            ext = os.path.splitext(output_path)[1].lower()
            if ext in [".png"]:
                format = "PNG"
            elif ext in [".jpg", ".jpeg"]:
                format = "JPEG"
            elif ext in [".webp"]:
                format = "WEBP"
            else:
                logger.warning(f"Unknown image extension {ext}, defaulting to PNG")
                format = "PNG"
                # splitext (not rsplit on ".") so a dot in a directory name
                # cannot corrupt the path when the filename has no extension.
                output_path = os.path.splitext(output_path)[0] + ".png"

        # Save image with format-specific handling
        MediaStorage._save_pil_image(pil_image, output_path, format, quality)

        logger.info(f"Saved image to {output_path} (format={format})")
        return output_path

    @staticmethod
    def convert_image_to_bytes(image: Any, format: str = "PNG", quality: int = 95) -> bytes:
        """Convert image to bytes buffer.

        Args:
            image: torch.Tensor (H, W, C) uint8
            format: Image format (PNG, JPEG, WEBP)
            quality: Quality for lossy formats (1-100)

        Returns:
            Image bytes
        """
        pil_image = MediaStorage._to_pil_image(image)

        # Save to an in-memory buffer instead of a file
        buffer = BytesIO()
        MediaStorage._save_pil_image(pil_image, buffer, format, quality)

        return buffer.getvalue()

    @staticmethod
    def _to_pil_image(image: torch.Tensor) -> Image.Image:
        """Convert torch.Tensor to PIL Image.

        Args:
            image: torch.Tensor (H, W, C) uint8

        Returns:
            PIL Image

        Raises:
            ValueError: If ``image`` is not a torch.Tensor.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(image)}")

        # PIL only understands numpy arrays; move to CPU first
        image_np = image.cpu().numpy()
        return Image.fromarray(image_np)

    @staticmethod
    def _save_pil_image(
        pil_image: Image.Image,
        output: Any,  # Can be path string or BytesIO
        format: str,
        quality: int,
    ):
        """Save PIL Image to file or buffer.

        Args:
            pil_image: PIL Image to save
            output: Output path (str) or BytesIO buffer
            format: Image format (PNG, JPEG, WEBP); case-insensitive
            quality: Quality for lossy formats (1-100)
        """
        format_upper = format.upper()

        if format_upper in ["JPEG", "JPG"]:
            # JPEG has no alpha channel: composite transparent images onto
            # a white background before saving.
            if pil_image.mode in ("RGBA", "LA", "P"):
                background = Image.new("RGB", pil_image.size, (255, 255, 255))
                if pil_image.mode == "P":
                    pil_image = pil_image.convert("RGBA")
                background.paste(
                    pil_image, mask=pil_image.split()[-1] if pil_image.mode == "RGBA" else None
                )
                pil_image = background
            pil_image.save(output, format="JPEG", quality=quality, optimize=True)
        elif format_upper == "WEBP":
            pil_image.save(output, format="WEBP", quality=quality)
        else:  # PNG or default
            pil_image.save(output, format="PNG", optimize=True)

    @staticmethod
    def save_video(
        video: Any,
        output_path: str,
        audio: Optional[Any] = None,
        frame_rate: float = 24.0,
        format: Optional[str] = None,
    ) -> str:
        """Save video to file with optional audio.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            output_path: Path to save the video
            audio: Optional audio as torch.Tensor
            frame_rate: Frames per second (default: 24.0)
            format: Video format (mp4, gif, png). If None, infer from extension

        Returns:
            Path where the video was saved (may differ from ``output_path``
            when an unsupported format forces a fallback to ``.mp4``)
        """
        # Accept pathlib.Path transparently
        if isinstance(output_path, Path):
            output_path = str(output_path)

        # Ensure output directory exists
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Determine format from the file extension when not given explicitly
        if format is None:
            ext = os.path.splitext(output_path)[1].lower()
            format = ext[1:] if ext else "mp4"

        format = format.lower()

        # Save based on format
        if format == "mp4":
            MediaStorage._save_mp4(video, audio, output_path, frame_rate)
        elif format == "gif":
            MediaStorage._save_gif(video, output_path, frame_rate)
        elif format == "png":
            MediaStorage._save_middle_frame(video, output_path)
        else:
            logger.warning(f"Unsupported video format: {format}, defaulting to mp4")
            # splitext so a dot in a directory name cannot corrupt the path
            output_path = os.path.splitext(output_path)[0] + ".mp4"
            MediaStorage._save_mp4(video, audio, output_path, frame_rate)

        return output_path

    @staticmethod
    def convert_video_to_bytes(
        video: Any, audio: Optional[Any] = None, frame_rate: float = 24.0, format: str = "mp4"
    ) -> bytes:
        """Convert video to bytes buffer.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            audio: Optional audio as torch.Tensor
            frame_rate: Frames per second
            format: Video format (mp4, gif)

        Returns:
            Video bytes
        """
        import tempfile

        # Encoders (PyAV/PIL) need a real file path, so round-trip through a
        # temporary file. delete=False because save_video reopens the path.
        with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        try:
            # Save to temporary file
            MediaStorage.save_video(video, tmp_path, audio, frame_rate, format)

            # Read bytes back
            with open(tmp_path, "rb") as f:
                video_bytes = f.read()

            return video_bytes
        finally:
            # Clean up temporary file
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    @staticmethod
    def _save_mp4(
        video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
    ) -> str:
        """Save video with optional audio as MP4 (H.264 + AAC).

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            audio: Optional audio as torch.Tensor. Accepted shapes: 1D
                ``(samples,)``, 2D ``(samples, channels)`` or
                ``(channels, samples)``, or 3D batched; float audio is
                assumed to be in [-1, 1] and is converted to int16 PCM.
            output_path: Output path for MP4
            frame_rate: Frames per second (fractional rates are preserved)

        Returns:
            Path where the video was saved; falls back to a middle-frame PNG
            (and returns that path) if PyAV is missing or encoding fails.
        """
        try:
            from fractions import Fraction

            import av

            if not isinstance(video, torch.Tensor):
                raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")

            # Convert video tensor to numpy: (T, H, W, C) uint8
            video_np = video.cpu().numpy()
            num_frames, height, width, channels = video_np.shape

            # Ensure RGB format (3 channels)
            if channels != 3:
                raise ValueError(f"Expected 3-channel RGB video, got {channels} channels")

            # Open output container
            container = av.open(output_path, mode="w")

            # Add video stream (H.264 codec). Use a Fraction so fractional
            # frame rates (e.g. 23.976) are not truncated by int().
            video_stream = container.add_stream(
                "libx264", rate=Fraction(frame_rate).limit_denominator(1001)
            )
            video_stream.width = width
            video_stream.height = height
            video_stream.pix_fmt = "yuv420p"
            video_stream.options = {"preset": "medium", "crf": "23"}

            # Pre-process audio and add audio stream BEFORE any muxing.
            # All streams must be registered before the first mux() call
            # (which triggers container header writing).
            audio_stream = None
            audio_tensor = None
            audio_sample_rate = 24000  # Default sample rate
            if audio is not None:
                if not isinstance(audio, torch.Tensor):
                    raise ValueError(f"Expected torch.Tensor for audio, got {type(audio)}")

                # Normalize the audio tensor to (samples, channels) layout.
                audio_tensor = audio
                if audio_tensor.ndim == 1:
                    # Mono audio: (samples,) -> (samples, 1)
                    audio_tensor = audio_tensor[:, None]
                elif audio_tensor.ndim == 2:
                    # Heuristic: (channels, samples) with 2 channels on the
                    # first axis -> transpose to (samples, channels).
                    if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
                        audio_tensor = audio_tensor.T
                elif audio_tensor.ndim == 3:
                    # Batched audio: keep only the first item.
                    audio_tensor = audio_tensor[0]
                    if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
                        audio_tensor = audio_tensor.T
                else:
                    raise ValueError(
                        f"Unsupported audio tensor shape: {audio_tensor.shape}. "
                        f"Expected 1D, 2D, or 3D tensor."
                    )

                # Keep at most two channels (stereo).
                if audio_tensor.shape[1] > 2:
                    audio_tensor = audio_tensor[:, :2]

                # The AAC stream below is configured as stereo, and the
                # packed s16 frame is interpreted as interleaved stereo.
                # Duplicate a mono channel so the buffer length matches;
                # otherwise mono audio would play at half length with
                # scrambled samples.
                if audio_tensor.shape[1] == 1:
                    audio_tensor = audio_tensor.repeat(1, 2)

                # Convert float audio in [-1, 1] to int16 PCM if needed
                if audio_tensor.dtype != torch.int16:
                    audio_tensor = torch.clip(audio_tensor, -1.0, 1.0)
                    audio_tensor = (audio_tensor * 32767.0).to(torch.int16)

                # Add audio stream now (before any muxing)
                audio_stream = container.add_stream("aac", rate=audio_sample_rate)
                audio_stream.codec_context.sample_rate = audio_sample_rate
                audio_stream.codec_context.layout = "stereo"
                audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)

            # --- Encode video frames ---
            for frame_array in video_np:
                frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
                for packet in video_stream.encode(frame):
                    container.mux(packet)

            # Flush video encoder
            for packet in video_stream.encode():
                container.mux(packet)

            # --- Encode audio (after video is done) ---
            if audio_stream is not None and audio_tensor is not None:
                # Build packed int16 frame: (1, samples*channels)
                audio_np = audio_tensor.contiguous().reshape(1, -1).cpu().numpy()

                frame_in = av.AudioFrame.from_ndarray(audio_np, format="s16", layout="stereo")
                frame_in.sample_rate = audio_sample_rate

                # Use AudioResampler to convert s16→fltp (AAC's native format)
                cc = audio_stream.codec_context
                audio_resampler = av.audio.resampler.AudioResampler(
                    format=cc.format or "fltp",
                    layout=cc.layout or "stereo",
                    rate=cc.sample_rate or audio_sample_rate,
                )

                # Assign monotonically increasing pts when the resampler
                # does not provide one.
                audio_next_pts = 0
                for rframe in audio_resampler.resample(frame_in):
                    if rframe.pts is None:
                        rframe.pts = audio_next_pts
                    audio_next_pts += rframe.samples
                    rframe.sample_rate = audio_sample_rate
                    container.mux(audio_stream.encode(rframe))

                # Flush audio encoder
                for packet in audio_stream.encode():
                    container.mux(packet)

            # Close container (writes the trailer)
            container.close()

            logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
            return output_path

        except ImportError:
            logger.warning(
                "PyAV (av) library not available. "
                "Falling back to saving middle frame as PNG. "
                "Install with: pip install av"
            )
            # splitext rather than str.replace: ".mp4" may also occur in a
            # directory component of the path.
            png_path = os.path.splitext(output_path)[0] + ".png"
            return MediaStorage._save_middle_frame(video, png_path)
        except Exception as e:
            logger.error(f"Error encoding video with PyAV: {e}")
            import traceback

            logger.error(traceback.format_exc())
            logger.warning("Falling back to saving middle frame as PNG.")
            png_path = os.path.splitext(output_path)[0] + ".png"
            return MediaStorage._save_middle_frame(video, png_path)

    @staticmethod
    def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float) -> str:
        """Save video as animated GIF.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            output_path: Output path for GIF
            frame_rate: Frames per second

        Returns:
            Path where the GIF was saved

        Raises:
            ValueError: If ``video`` is not a torch.Tensor or has no frames.
        """
        if not isinstance(video, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")

        # Convert to numpy and then to list of PIL Images
        video_np = video.cpu().numpy()
        if video_np.shape[0] == 0:
            # A clear error instead of an opaque IndexError below
            raise ValueError("Cannot save GIF: video tensor has no frames")
        frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]

        # Save as GIF; duration is per-frame display time in milliseconds
        duration_ms = int(1000 / frame_rate)
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            optimize=False,
            duration=duration_ms,
            loop=0,  # loop forever
        )
        logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
        return output_path

    @staticmethod
    def _save_middle_frame(video: torch.Tensor, output_path: str) -> str:
        """Save middle frame of video as PNG.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            output_path: Output path for PNG

        Returns:
            Path where the frame was saved
        """
        if not isinstance(video, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")

        # Extract middle frame
        video_np = video.cpu().numpy()
        frame_idx = video_np.shape[0] // 2
        image = Image.fromarray(video_np[frame_idx])

        image.save(output_path)
        logger.info(f"Saved frame {frame_idx} to {output_path}")
        return output_path