TensorRT-LLMs/examples/visual_gen/output_handler.py
Chang Liu 26901e4aa0
[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
2026-02-14 06:11:11 +08:00

238 lines
8.4 KiB
Python

"""Unified output handler for diffusion model outputs."""
import os
from typing import Optional
import torch
from PIL import Image
from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import MediaOutput
def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
    """Convert a HuggingFace pipeline video tensor to a displayable uint8 tensor.

    HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W)
    layout, unlike the raw VAE decoder output.

    Args:
        video: Video tensor in (B, T, C, H, W) format from a HuggingFace pipeline.
        remove_batch_dim: Drop the leading batch dimension (default True, the
            common single-batch generation case).

    Returns:
        uint8 tensor laid out channels-last:
        - (T, H, W, C) when remove_batch_dim=True
        - (B, T, H, W, C) when remove_batch_dim=False

    Note:
        Assumes video values are in [-1, 1] range (standard pipeline output).
    """
    if remove_batch_dim:
        # (B, T, C, H, W) -> (T, C, H, W) -> (T, H, W, C)
        frames = video[0].permute(0, 2, 3, 1)
    else:
        # (B, T, C, H, W) -> (B, T, H, W, C)
        frames = video.permute(0, 1, 3, 4, 2)
    # Map [-1, 1] to [0, 1], then quantize to uint8.
    normalized = torch.clamp(frames * 0.5 + 0.5, 0.0, 1.0)
    return (normalized * 255.0).round().to(torch.uint8)
def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Convert a HuggingFace pipeline image tensor to a displayable uint8 tensor.

    HuggingFace pipelines with output_type="pt" return images in (B, C, H, W)
    layout.

    Args:
        image: Image tensor in (B, C, H, W) or (C, H, W) format from a
            HuggingFace pipeline.

    Returns:
        uint8 tensor in (H, W, C) channels-last layout.

    Note:
        Assumes image values are in [-1, 1] range (standard pipeline output).
    """
    # Strip the batch axis if one is present: (B, C, H, W) -> (C, H, W).
    single = image[0] if image.ndim == 4 else image
    # Channels-first to channels-last: (C, H, W) -> (H, W, C).
    hwc = single.permute(1, 2, 0)
    # Map [-1, 1] to [0, 1], then quantize to uint8.
    normalized = torch.clamp(hwc * 0.5 + 0.5, 0.0, 1.0)
    return (normalized * 255.0).round().to(torch.uint8)
class OutputHandler:
    """Handle saving of generated outputs in various formats.

    Supports MediaOutput from all models:
    - Video models (WAN): MediaOutput(video=torch.Tensor)
    - Image models: MediaOutput(image=torch.Tensor)
    - Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor)

    Supported output formats:
    - .png: Save single image or middle frame
    - .gif: Save video as animated GIF (no audio)
    - .mp4: Save video with audio (requires diffusers export_utils)
    """

    @staticmethod
    def _with_extension(output_path: str, new_ext: str) -> str:
        """Return output_path with its file extension swapped for new_ext.

        Fix: the previous str.replace(file_ext, ...) approach rewrote EVERY
        occurrence of the extension substring in the path (e.g. a directory
        named "clips.mp4"); splitext only touches the trailing extension.

        Args:
            output_path: Original file path.
            new_ext: Replacement extension, including the leading dot.
        """
        return os.path.splitext(output_path)[0] + new_ext

    @staticmethod
    def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0):
        """Save output based on content type and file extension.

        Args:
            output: MediaOutput containing model outputs (image/video/audio)
            output_path: Path to save the output file
            frame_rate: Frames per second for video output (default: 24.0)

        Raises:
            ValueError: If output is not a MediaOutput, or carries neither
                image nor video data.
        """
        if not isinstance(output, MediaOutput):
            raise ValueError(f"Expected output to be MediaOutput, got {type(output)}")
        file_ext = os.path.splitext(output_path)[1].lower()
        # Dispatch on content type: image takes precedence over video.
        if output.image is not None:
            OutputHandler._save_image(output.image, output_path, file_ext)
        elif output.video is not None:
            OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate)
        else:
            raise ValueError("Unknown output format. MediaOutput has no image or video data.")

    @staticmethod
    def _save_image(image: torch.Tensor, output_path: str, file_ext: str):
        """Save single image output.

        Args:
            image: Image as torch tensor (H, W, C) uint8
            output_path: Path to save the image
            file_ext: File extension (.png, .jpg, etc.)
        """
        if file_ext not in [".png", ".jpg", ".jpeg"]:
            logger.warning(f"Image output requested with {file_ext}, defaulting to .png")
            output_path = OutputHandler._with_extension(output_path, ".png")
        # Convert torch.Tensor to PIL Image and save
        image_np = image.cpu().numpy()
        Image.fromarray(image_np).save(output_path)
        logger.info(f"Saved image to {output_path}")

    @staticmethod
    def _save_video(
        video: torch.Tensor,
        audio: Optional[torch.Tensor],
        output_path: str,
        file_ext: str,
        frame_rate: float,
    ):
        """Save video output with optional audio.

        Args:
            video: Video frames as torch tensor (T, H, W, C) with dtype uint8
            audio: Optional audio as torch tensor
            output_path: Path to save the video
            file_ext: File extension (.mp4, .gif, .png)
            frame_rate: Frames per second
        """
        if file_ext == ".mp4":
            OutputHandler._save_mp4(video, audio, output_path, frame_rate)
        elif file_ext == ".gif":
            OutputHandler._save_gif(video, output_path, frame_rate)
        elif file_ext == ".png":
            OutputHandler._save_middle_frame(video, output_path)
        else:
            # Unknown container: degrade gracefully to a single PNG frame.
            logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png")
            output_path = OutputHandler._with_extension(output_path, ".png")
            OutputHandler._save_middle_frame(video, output_path)

    @staticmethod
    def _save_mp4(
        video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
    ):
        """Save video with optional audio as MP4.

        Falls back to saving the middle frame as PNG when the diffusers
        encode_video helper is unavailable.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            audio: Optional audio as torch tensor (float32)
            output_path: Output path for MP4
            frame_rate: Frames per second
        """
        try:
            from diffusers.pipelines.ltx2.export_utils import encode_video

            # Prepare audio if present
            audio_prepared = audio.float() if audio is not None else None
            # encode_video expects (T, H, W, C) uint8 video and float32 audio
            encode_video(
                video,
                fps=frame_rate,
                audio=audio_prepared,
                # NOTE(review): 24 kHz sample rate is hard-coded here —
                # confirm it matches the audio generated upstream.
                audio_sample_rate=24000 if audio_prepared is not None else None,
                output_path=output_path,
            )
            logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
        except ImportError:
            logger.warning(
                "diffusers export_utils (encode_video) not available. "
                "Falling back to saving middle frame as PNG."
            )
            png_path = OutputHandler._with_extension(output_path, ".png")
            OutputHandler._save_middle_frame(video, png_path)

    @staticmethod
    def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float):
        """Save video as animated GIF.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            output_path: Output path for GIF
            frame_rate: Frames per second
        """
        # Convert torch.Tensor to numpy for PIL
        video_np = video.cpu().numpy()
        # Convert to list of PIL Images
        frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
        # Save as animated GIF; per-frame duration is in milliseconds.
        duration_ms = int(1000 / frame_rate)
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            optimize=False,
            duration=duration_ms,
            loop=0,  # loop forever
        )
        logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")

    @staticmethod
    def _save_middle_frame(video: torch.Tensor, output_path: str):
        """Save middle frame of video as PNG.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            output_path: Output path for PNG
        """
        # Convert torch.Tensor to numpy for PIL
        video_np = video.cpu().numpy()
        # Extract middle frame
        frame_idx = video_np.shape[0] // 2
        Image.fromarray(video_np[frame_idx]).save(output_path)
        logger.info(f"Saved frame {frame_idx} to {output_path}")