TensorRT-LLMs/tensorrt_llm/serve/media_storage.py
Chang Liu 26901e4aa0
[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
2026-02-14 06:11:11 +08:00

427 lines
15 KiB
Python

#!/usr/bin/env python
"""Media Storage for generated images and videos.
This module provides storage handlers for persisting generated media assets
(videos, images) and their associated metadata.
"""
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
import torch
from PIL import Image
from tensorrt_llm.logger import logger
class MediaStorage:
    """Handler for storing images and videos in various formats.

    All methods are static. Images are expected as ``torch.Tensor`` of shape
    (H, W, C) with dtype uint8 and videos as (T, H, W, C) uint8 tensors. Tensors
    are detached and copied to the CPU before encoding.
    """

    @staticmethod
    def save_image(
        image: Any, output_path: str, format: Optional[str] = None, quality: int = 95
    ) -> str:
        """Save an image tensor to a file.

        Args:
            image: torch.Tensor (H, W, C) uint8. A trailing singleton channel
                (H, W, 1) is saved as grayscale.
            output_path: Path to save the image (``str`` or ``os.PathLike``).
            format: Image format (png, jpg/jpeg, webp). If None, it is inferred
                from the file extension; an unknown extension falls back to PNG
                and the returned path gets a ``.png`` extension instead.
            quality: Quality for lossy formats (1-100, higher is better).

        Returns:
            Path where the image was saved.
        """
        output_path = os.fspath(output_path)
        # Ensure output directory exists
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        pil_image = MediaStorage._to_pil_image(image)

        if format is None:
            root, ext = os.path.splitext(output_path)
            ext = ext.lower()
            if ext == ".png":
                format = "PNG"
            elif ext in (".jpg", ".jpeg"):
                format = "JPEG"
            elif ext == ".webp":
                format = "WEBP"
            else:
                logger.warning(f"Unknown image extension {ext}, defaulting to PNG")
                format = "PNG"
                # splitext only touches the final path component, so dots in
                # directory names are preserved.
                output_path = root + ".png"

        MediaStorage._save_pil_image(pil_image, output_path, format, quality)
        logger.info(f"Saved image to {output_path} (format={format})")
        return output_path

    @staticmethod
    def convert_image_to_bytes(image: Any, format: str = "PNG", quality: int = 95) -> bytes:
        """Encode an image tensor into an in-memory image file.

        Args:
            image: torch.Tensor (H, W, C) uint8
            format: Image format (PNG, JPEG/JPG, WEBP); anything else is encoded as PNG.
            quality: Quality for lossy formats (1-100)

        Returns:
            The encoded image bytes.
        """
        pil_image = MediaStorage._to_pil_image(image)
        buffer = BytesIO()
        MediaStorage._save_pil_image(pil_image, buffer, format, quality)
        return buffer.getvalue()

    @staticmethod
    def _to_pil_image(image: torch.Tensor) -> "Image.Image":
        """Convert a torch.Tensor to a PIL Image.

        Args:
            image: torch.Tensor (H, W, C) uint8, on any device.

        Returns:
            PIL Image built from the tensor's CPU copy.

        Raises:
            ValueError: If ``image`` is not a torch.Tensor.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(image)}")
        # detach() so tensors that require grad can be converted to numpy.
        image_np = image.detach().cpu().numpy()
        # PIL cannot build an image from an (H, W, 1) array; treat it as (H, W) grayscale.
        if image_np.ndim == 3 and image_np.shape[-1] == 1:
            image_np = image_np[..., 0]
        return Image.fromarray(image_np)

    @staticmethod
    def _save_pil_image(
        pil_image: "Image.Image",
        output: Any,  # Can be path string or BytesIO
        format: str,
        quality: int,
    ):
        """Save a PIL Image to a file path or a writable buffer.

        Args:
            pil_image: PIL Image to save
            output: Output path (str) or BytesIO buffer
            format: Image format (PNG, JPEG/JPG, WEBP); anything else is saved as PNG.
            quality: Quality for lossy formats (1-100)
        """
        format_upper = format.upper()
        if format_upper in ("JPEG", "JPG"):
            # JPEG has no alpha channel: composite transparent images onto white,
            # using the alpha channel (including LA's) as the paste mask.
            if pil_image.mode in ("RGBA", "LA", "P"):
                rgba = pil_image.convert("RGBA")
                background = Image.new("RGB", rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.getchannel("A"))
                pil_image = background
            pil_image.save(output, format="JPEG", quality=quality, optimize=True)
        elif format_upper == "WEBP":
            pil_image.save(output, format="WEBP", quality=quality)
        else:  # PNG or default
            pil_image.save(output, format="PNG", optimize=True)

    @staticmethod
    def save_video(
        video: Any,
        output_path: str,
        audio: Optional[Any] = None,
        frame_rate: float = 24.0,
        format: Optional[str] = None,
        audio_sample_rate: int = 24000,
    ) -> str:
        """Save a video to a file with optional audio.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            output_path: Path to save the video (``str`` or ``os.PathLike``).
            audio: Optional audio as torch.Tensor (only used for mp4).
            frame_rate: Frames per second (default: 24.0)
            format: Video format (mp4, gif, png). If None, infer from extension
                (mp4 when there is no extension).
            audio_sample_rate: Sample rate of ``audio`` in Hz (default: 24000).

        Returns:
            Path of the file actually written. This can differ from
            ``output_path``: an unsupported format is written as ``.mp4``, and if
            MP4 encoding fails the middle frame is written as ``.png`` instead.
        """
        output_path = os.fspath(output_path)
        # Ensure output directory exists
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if format is None:
            ext = os.path.splitext(output_path)[1].lower()
            format = ext[1:] if ext else "mp4"
        format = format.lower()

        # Always return the helper's path: _save_mp4 may fall back to a PNG file.
        if format == "mp4":
            return MediaStorage._save_mp4(
                video, audio, output_path, frame_rate, audio_sample_rate
            )
        if format == "gif":
            return MediaStorage._save_gif(video, output_path, frame_rate)
        if format == "png":
            return MediaStorage._save_middle_frame(video, output_path)

        logger.warning(f"Unsupported video format: {format}, defaulting to mp4")
        output_path = os.path.splitext(output_path)[0] + ".mp4"
        return MediaStorage._save_mp4(video, audio, output_path, frame_rate, audio_sample_rate)

    @staticmethod
    def convert_video_to_bytes(
        video: Any,
        audio: Optional[Any] = None,
        frame_rate: float = 24.0,
        format: str = "mp4",
        audio_sample_rate: int = 24000,
    ) -> bytes:
        """Encode a video into bytes by round-tripping through a temporary file.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            audio: Optional audio as torch.Tensor
            frame_rate: Frames per second
            format: Video format (mp4, gif, png)
            audio_sample_rate: Sample rate of ``audio`` in Hz (default: 24000).

        Returns:
            Bytes of the file ``save_video`` actually wrote. If MP4 encoding falls
            back to a PNG of the middle frame, these are PNG bytes.
        """
        import tempfile

        # Only alphanumerics from the caller-supplied format reach the file name,
        # so it cannot introduce path separators.
        extension = "".join(ch for ch in format if ch.isalnum()) or "mp4"
        # A private directory holds every file save_video may produce (including
        # fallback outputs with a different extension) and is removed in one go.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, f"video.{extension}")
            saved_path = MediaStorage.save_video(
                video,
                tmp_path,
                audio,
                frame_rate,
                format,
                audio_sample_rate=audio_sample_rate,
            )
            with open(saved_path, "rb") as f:
                return f.read()

    @staticmethod
    def _require_frames(video: Any) -> None:
        """Raise ValueError unless ``video`` is a tensor with at least one frame."""
        if not isinstance(video, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
        if video.ndim == 0 or video.shape[0] == 0:
            raise ValueError(
                f"Expected a non-empty (T, H, W, C) video tensor, got shape {tuple(video.shape)}"
            )

    @staticmethod
    def _prepare_audio(audio: Any) -> torch.Tensor:
        """Normalize an audio tensor to contiguous (samples, 2) int16 on the CPU.

        Accepted layouts:
            * 1D (samples,) mono.
            * 2D (samples, channels) or (channels, samples). The second layout is
              assumed when there are more than 2 columns and fewer rows than columns.
            * 3D (batch, ...): only the first batch item is used.

        Mono audio is duplicated into both channels and only the first two
        channels are kept. Non-int16 input is treated as float in [-1, 1]. It is
        clipped to that range and scaled by 32767, with the arithmetic done in
        float32.

        Raises:
            ValueError: If ``audio`` is not a torch.Tensor or is not 1D/2D/3D.
        """
        if not isinstance(audio, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor for audio, got {type(audio)}")
        if audio.ndim not in (1, 2, 3):
            raise ValueError(
                f"Unsupported audio tensor shape: {audio.shape}. "
                f"Expected 1D, 2D, or 3D tensor."
            )

        samples = audio.detach().cpu()
        if samples.ndim == 3:
            samples = samples[0]
        if samples.ndim == 1:
            samples = samples.unsqueeze(1)
        elif samples.shape[1] > 2 and samples.shape[0] < samples.shape[1]:
            # Channels-first: the channel axis is the smaller one.
            samples = samples.T

        samples = samples[:, :2]
        if samples.shape[1] == 1:
            # The encoder is configured for stereo; duplicate mono so it is not
            # misread as interleaved L/R pairs.
            samples = samples.expand(-1, 2)

        if samples.dtype != torch.int16:
            samples = (samples.float().clamp(-1.0, 1.0) * 32767.0).to(torch.int16)
        return samples.contiguous()

    @staticmethod
    def _save_mp4(
        video: torch.Tensor,
        audio: Optional[torch.Tensor],
        output_path: str,
        frame_rate: float,
        audio_sample_rate: int = 24000,
    ) -> str:
        """Save a video with optional audio as MP4 (H.264 video, AAC audio) via PyAV.

        On any failure, including PyAV not being installed, the middle frame is
        saved as a PNG next to ``output_path`` instead.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8
            audio: Optional audio as torch.Tensor (see ``_prepare_audio``)
            output_path: Output path for MP4
            frame_rate: Frames per second (non-integer rates are preserved)
            audio_sample_rate: Sample rate of ``audio`` in Hz

        Returns:
            Path where the video (or fallback PNG) was saved.
        """
        try:
            from fractions import Fraction

            import av

            if not isinstance(video, torch.Tensor):
                raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")

            # Convert video tensor to numpy: (T, H, W, C) uint8
            video_np = video.detach().cpu().numpy()
            _, height, width, channels = video_np.shape
            if channels != 3:
                raise ValueError(f"Expected 3-channel RGB video, got {channels} channels")

            audio_samples = MediaStorage._prepare_audio(audio) if audio is not None else None

            # The context manager closes the container even if encoding fails.
            with av.open(output_path, mode="w") as container:
                # Add the H.264 video stream. A Fraction keeps rates such as 23.976
                # or 29.97 instead of truncating them to an int.
                video_stream = container.add_stream(
                    "libx264", rate=Fraction(float(frame_rate)).limit_denominator(1001)
                )
                video_stream.width = width
                video_stream.height = height
                video_stream.pix_fmt = "yuv420p"
                video_stream.options = {"preset": "medium", "crf": "23"}

                # All streams must be registered before the first mux() call
                # (which triggers container header writing).
                audio_stream = None
                if audio_samples is not None:
                    audio_stream = container.add_stream("aac", rate=audio_sample_rate)
                    audio_stream.codec_context.sample_rate = audio_sample_rate
                    audio_stream.codec_context.layout = "stereo"
                    audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)

                # --- Encode video frames ---
                for frame_array in video_np:
                    frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
                    for packet in video_stream.encode(frame):
                        container.mux(packet)
                # Flush video encoder
                for packet in video_stream.encode():
                    container.mux(packet)

                # --- Encode audio (after video is done) ---
                if audio_stream is not None:
                    # Packed s16 expects interleaved samples shaped (1, samples * channels).
                    audio_np = audio_samples.reshape(1, -1).numpy()
                    frame_in = av.AudioFrame.from_ndarray(audio_np, format="s16", layout="stereo")
                    frame_in.sample_rate = audio_sample_rate

                    # Use AudioResampler to convert s16 -> the encoder's native format (fltp for AAC)
                    cc = audio_stream.codec_context
                    audio_resampler = av.audio.resampler.AudioResampler(
                        format=cc.format or "fltp",
                        layout=cc.layout or "stereo",
                        rate=cc.sample_rate or audio_sample_rate,
                    )
                    audio_next_pts = 0
                    for rframe in audio_resampler.resample(frame_in):
                        if rframe.pts is None:
                            rframe.pts = audio_next_pts
                        audio_next_pts += rframe.samples
                        rframe.sample_rate = audio_sample_rate
                        container.mux(audio_stream.encode(rframe))
                    # Flush audio encoder
                    for packet in audio_stream.encode():
                        container.mux(packet)

            logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
            return output_path

        except ImportError:
            logger.warning(
                "PyAV (av) library not available. "
                "Falling back to saving middle frame as PNG. "
                "Install with: pip install av"
            )
            png_path = os.path.splitext(output_path)[0] + ".png"
            return MediaStorage._save_middle_frame(video, png_path)
        except Exception as e:
            logger.error(f"Error encoding video with PyAV: {e}")
            import traceback

            logger.error(traceback.format_exc())
            logger.warning("Falling back to saving middle frame as PNG.")
            png_path = os.path.splitext(output_path)[0] + ".png"
            return MediaStorage._save_middle_frame(video, png_path)

    @staticmethod
    def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float) -> str:
        """Save a video as an infinitely looping animated GIF.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8, T >= 1
            output_path: Output path for GIF
            frame_rate: Frames per second (each frame shown for int(1000 / frame_rate) ms)

        Returns:
            Path where the GIF was saved.

        Raises:
            ValueError: If ``video`` is not a tensor or has no frames.
        """
        MediaStorage._require_frames(video)

        # Move the whole clip to the CPU once; per-frame conversion is then a no-op copy.
        frames = [MediaStorage._to_pil_image(frame) for frame in video.detach().cpu()]

        duration_ms = int(1000 / frame_rate)
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            optimize=False,
            duration=duration_ms,
            loop=0,
        )
        logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
        return output_path

    @staticmethod
    def _save_middle_frame(video: torch.Tensor, output_path: str) -> str:
        """Save the middle frame (index T // 2) of a video as an image.

        The image format is inferred by PIL from ``output_path``'s extension.

        Args:
            video: Video frames as torch.Tensor (T, H, W, C) uint8, T >= 1
            output_path: Output path for PNG

        Returns:
            Path where the frame was saved.

        Raises:
            ValueError: If ``video`` is not a tensor or has no frames.
        """
        MediaStorage._require_frames(video)

        frame_idx = video.shape[0] // 2
        # Index before moving to the CPU so only one frame is copied off-device.
        image = MediaStorage._to_pil_image(video[frame_idx])
        image.save(output_path)
        logger.info(f"Saved frame {frame_idx} to {output_path}")
        return output_path