"""Unified output handler for diffusion model outputs."""

import os
from typing import Optional

import torch
from PIL import Image

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import MediaOutput


def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
    """Post-process video tensor from HuggingFace pipeline output to final format.

    HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W) format,
    which is different from the VAE decoder output format.

    Args:
        video: Video tensor in (B, T, C, H, W) format from HuggingFace pipeline
        remove_batch_dim: Whether to remove the batch dimension. Defaults to True for
            typical single-batch video generation.

    Returns:
        Post-processed video tensor:
        - If remove_batch_dim=True: (T, H, W, C) uint8 tensor
        - If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor

    Note:
        Assumes video values are in [-1, 1] range (standard pipeline output).
    """
    # Remove batch dimension first if requested
    if remove_batch_dim:
        video = video[0]  # (B, T, C, H, W) -> (T, C, H, W)
        video = video.permute(0, 2, 3, 1)  # (T, C, H, W) -> (T, H, W, C)
    else:
        video = video.permute(0, 1, 3, 4, 2)  # (B, T, C, H, W) -> (B, T, H, W, C)

    # Normalize to [0, 1] range
    video = (video / 2 + 0.5).clamp(0, 1)

    # Convert to uint8
    video = (video * 255).round().to(torch.uint8)

    return video


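# Shape sketch (illustrative addition, not from the original module): with the
# default remove_batch_dim=True, a (B=1, T, C, H, W) pipeline tensor in [-1, 1]
# comes back as a (T, H, W, C) uint8 tensor:
#
#   >>> out = postprocess_hf_video_tensor(torch.zeros(1, 8, 3, 32, 32))
#   >>> tuple(out.shape), out.dtype
#   ((8, 32, 32, 3), torch.uint8)

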
def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Post-process image tensor from HuggingFace pipeline output to final format.

    HuggingFace pipelines with output_type="pt" return images in (B, C, H, W) format.

    Args:
        image: Image tensor in (B, C, H, W) or (C, H, W) format from HuggingFace pipeline

    Returns:
        Post-processed image tensor in (H, W, C) uint8 format

    Note:
        Assumes image values are in [-1, 1] range (standard pipeline output).
    """
    # Remove batch dimension if present
    if image.ndim == 4:
        image = image[0]  # (B, C, H, W) -> (C, H, W)

    # Convert to (H, W, C) format
    image = image.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)

    # Normalize to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)

    # Convert to uint8
    image = (image * 255).round().to(torch.uint8)

    return image


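# Shape sketch (illustrative addition, not from the original module): batched
# and unbatched inputs both collapse to a single (H, W, C) uint8 frame that
# feeds straight into PIL:
#
#   >>> img = postprocess_hf_image_tensor(torch.zeros(1, 3, 64, 64))
#   >>> Image.fromarray(img.cpu().numpy()).size
#   (64, 64)

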
class OutputHandler:
    """Handle saving of generated outputs in various formats.

    Supports MediaOutput from all models:
    - Video models (WAN): MediaOutput(video=torch.Tensor)
    - Image models: MediaOutput(image=torch.Tensor)
    - Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor)

    Supported output formats:
    - .png: Save single image or middle frame
    - .gif: Save video as animated GIF (no audio)
    - .mp4: Save video with audio (requires diffusers export_utils)
    """

    @staticmethod
    def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0):
        """Save output based on content type and file extension.

        Args:
            output: MediaOutput containing model outputs (image/video/audio)
            output_path: Path to save the output file
            frame_rate: Frames per second for video output (default: 24.0)
        """
        if not isinstance(output, MediaOutput):
            raise ValueError(f"Expected output to be MediaOutput, got {type(output)}")

        file_ext = os.path.splitext(output_path)[1].lower()

        # Determine content type
        if output.image is not None:
            OutputHandler._save_image(output.image, output_path, file_ext)
        elif output.video is not None:
            OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate)
        else:
            raise ValueError("Unknown output format. MediaOutput has no image or video data.")
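
    # Usage sketch (illustrative addition): dispatch above is driven purely by
    # the extension of output_path, so, following the constructor forms shown
    # in the class docstring:
    #   OutputHandler.save(MediaOutput(image=img), "out.png")
    #   OutputHandler.save(MediaOutput(video=vid, audio=aud), "out.mp4")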

    @staticmethod
    def _save_image(image: torch.Tensor, output_path: str, file_ext: str):
        """Save single image output.

        Args:
            image: Image as torch tensor (H, W, C) uint8
            output_path: Path to save the image
            file_ext: File extension (.png, .jpg, etc.)
        """
        if file_ext not in [".png", ".jpg", ".jpeg"]:
            logger.warning(f"Image output requested with {file_ext}, defaulting to .png")
            # Swap only the trailing extension; str.replace would hit the first
            # occurrence anywhere in the path (and misbehaves on an empty ext)
            output_path = os.path.splitext(output_path)[0] + ".png"

        # Convert torch.Tensor to PIL Image and save
        image_np = image.cpu().numpy()
        Image.fromarray(image_np).save(output_path)
        logger.info(f"Saved image to {output_path}")

    @staticmethod
    def _save_video(
        video: torch.Tensor,
        audio: Optional[torch.Tensor],
        output_path: str,
        file_ext: str,
        frame_rate: float,
    ):
        """Save video output with optional audio.

        Args:
            video: Video frames as torch tensor (T, H, W, C) with dtype uint8
            audio: Optional audio as torch tensor
            output_path: Path to save the video
            file_ext: File extension (.mp4, .gif, .png)
            frame_rate: Frames per second
        """
        if file_ext == ".mp4":
            OutputHandler._save_mp4(video, audio, output_path, frame_rate)
        elif file_ext == ".gif":
            OutputHandler._save_gif(video, output_path, frame_rate)
        elif file_ext == ".png":
            OutputHandler._save_middle_frame(video, output_path)
        else:
            logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png")
            # Swap only the trailing extension rather than the first match in the path
            output_path = os.path.splitext(output_path)[0] + ".png"
            OutputHandler._save_middle_frame(video, output_path)

    @staticmethod
    def _save_mp4(
        video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
    ):
        """Save video with optional audio as MP4.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            audio: Optional audio as torch tensor (float32)
            output_path: Output path for MP4
            frame_rate: Frames per second
        """
        try:
            from diffusers.pipelines.ltx2.export_utils import encode_video

            # Prepare audio if present
            audio_prepared = audio.float() if audio is not None else None

            # encode_video expects (T, H, W, C) uint8 video and float32 audio
            encode_video(
                video,
                fps=frame_rate,
                audio=audio_prepared,
                audio_sample_rate=24000 if audio_prepared is not None else None,
                output_path=output_path,
            )
            logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")

        except ImportError:
            logger.warning(
                "diffusers export_utils (encode_video) not available. "
                "Falling back to saving middle frame as PNG."
            )
            # Swap only the trailing extension rather than the first ".mp4" in the path
            png_path = os.path.splitext(output_path)[0] + ".png"
            OutputHandler._save_middle_frame(video, png_path)

    @staticmethod
    def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float):
        """Save video as animated GIF.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            output_path: Output path for GIF
            frame_rate: Frames per second
        """
        # Convert torch.Tensor to numpy for PIL
        video_np = video.cpu().numpy()

        # Convert to list of PIL Images
        frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]

        # Save as animated GIF
        duration_ms = int(1000 / frame_rate)
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            optimize=False,
            duration=duration_ms,
            loop=0,
        )
        logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
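
    # Worked example (added note): at the default frame_rate of 24.0, the
    # per-frame duration is int(1000 / 24) = 41 ms, so a 48-frame clip plays
    # for roughly two seconds per loop.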

    @staticmethod
    def _save_middle_frame(video: torch.Tensor, output_path: str):
        """Save middle frame of video as PNG.

        Args:
            video: Video frames as torch tensor (T, H, W, C) uint8
            output_path: Output path for PNG
        """
        # Convert torch.Tensor to numpy for PIL
        video_np = video.cpu().numpy()

        # Extract middle frame
        frame_idx = video_np.shape[0] // 2
        Image.fromarray(video_np[frame_idx]).save(output_path)
        logger.info(f"Saved frame {frame_idx} to {output_path}")
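

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative addition, not part of the original module).
# Assumptions: MediaOutput fields that are not passed default to None, per the
# constructor forms in the OutputHandler docstring, and a synthetic
# (1, 8, 3, 32, 32) tensor stands in for real pipeline output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Synthetic "pipeline output": (B, T, C, H, W) float tensor in [-1, 1]
    fake_video = torch.rand(1, 8, 3, 32, 32) * 2 - 1

    # (B, T, C, H, W) float in [-1, 1] -> (T, H, W, C) uint8
    frames = postprocess_hf_video_tensor(fake_video)
    assert frames.shape == (8, 32, 32, 3) and frames.dtype == torch.uint8

    # Save through the unified handler; .gif avoids the diffusers dependency
    OutputHandler.save(MediaOutput(video=frames), "demo_output.gif", frame_rate=8.0)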