TensorRT-LLM/examples/visual_gen/hf_wan.py

#!/usr/bin/env python3
"""Baseline test for WAN using official diffusers library."""
import sys
import torch
from output_handler import OutputHandler, postprocess_hf_video_tensor
from tensorrt_llm._torch.visual_gen import MediaOutput


def test_wan_baseline(
    model_path: str,
    output_path: str,
    prompt: str = "A cute cat playing piano",
    height: int = 480,
    width: int = 832,
    num_frames: int = 33,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    seed: int = 42,
):
    """Test WAN video generation with official diffusers."""
    from diffusers import WanPipeline

    print("=" * 80)
    print("WAN Baseline Test (Official Diffusers)")
    print("=" * 80)
    print()

    # Load pipeline
    print(f"Loading WAN pipeline from {model_path}...")
    pipe = WanPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    print("✅ Pipeline loaded")
    print()

    # Check model states
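    # (from_pretrained is expected to leave every submodule in eval mode, i.e.
    # training=False; the prints below surface that in the log)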
print("Model Training States:")
print(f" text_encoder.training: {pipe.text_encoder.training}")
print(f" transformer.training: {pipe.transformer.training}")
print(f" vae.training: {pipe.vae.training}")
print()
# Generate video
print(f"Generating video: '{prompt}'")
print(
f"Parameters: {height}x{width}, {num_frames} frames, {num_inference_steps} steps, guidance={guidance_scale}"
)
print()
# Set random seed
generator = torch.Generator(device="cuda").manual_seed(seed)
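    # NOTE: return_dict=False makes the pipeline return a plain tuple instead of an
    # output dataclass; the generated frames tensor is its first element (result[0]).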
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        output_type="pt",
        return_dict=False,
    )
    video = result[0]

    # Post-process video tensor: (B, T, C, H, W) -> (T, H, W, C) uint8
    video = postprocess_hf_video_tensor(video, remove_batch_dim=True)

    print("=" * 80)
    print("Generation Complete!")
    print("=" * 80)
    print(f"Video shape: {video.shape}")
    print(f"Video dtype: {video.dtype}")
    print()

    # Save output
    print(f"Saving output to {output_path}...")
    OutputHandler.save(output=MediaOutput(video=video), output_path=output_path, frame_rate=24.0)
    print(f"✅ Saved to {output_path}")
    print()

    print("=" * 80)
    print("WAN BASELINE TEST PASSED ✅")
    print("=" * 80)

    return video
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="HuggingFace Baseline - WAN Text-to-Video Generation"
)
# Model & Input
parser.add_argument(
"--model_path",
type=str,
default="/llm-models/Wan2.1-T2V-1.3B-Diffusers/",
help="Path to WAN model",
)
parser.add_argument(
"--prompt", type=str, default="A cute cat playing piano", help="Text prompt for generation"
)
parser.add_argument(
"--output_path", type=str, default="wan_baseline.gif", help="Output file path"
)
# Generation parameters
parser.add_argument("--height", type=int, default=480, help="Video height")
parser.add_argument("--width", type=int, default=832, help="Video width")
parser.add_argument("--num_frames", type=int, default=33, help="Number of frames to generate")
parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps")
parser.add_argument(
"--guidance_scale", type=float, default=7.0, help="Classifier-free guidance scale"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
args = parser.parse_args()
try:
test_wan_baseline(
args.model_path,
args.output_path,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
)
except Exception as e:
print(f"\n❌ ERROR: {e}")
import traceback
traceback.print_exc()
sys.exit(1)