TensorRT-LLMs/examples/visual_gen/visual_gen_wan_t2v.py

#!/usr/bin/env python3
"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""

import argparse
import time

from output_handler import OutputHandler

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams

# Set logger level to ensure timing logs are printed
logger.set_level("info")
def parse_args():
    parser = argparse.ArgumentParser(
        description="TRTLLM VisualGen - Wan Text-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
    )

    # Model & Input
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="HuggingFace Hub revision (branch, tag, or commit SHA)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Default is model-specific.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image/video frame",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale_2",
        type=float,
        default=None,
        help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
    )
    parser.add_argument(
        "--boundary_ratio",
        type=float,
        default=None,
        help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # TeaCache Arguments
    parser.add_argument(
        "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
    )
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="TeaCache similarity threshold (rel_l1_thresh)",
    )

    # Quantization
    parser.add_argument(
        "--linear_type",
        type=str,
        default="default",
        choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
        help="Linear layer quantization type",
    )

    # Attention Backend
    parser.add_argument(
        "--attention_backend",
        type=str,
        default="VANILLA",
        choices=["VANILLA", "TRTLLM"],
        help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
        "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
    )

    # Parallelism
    parser.add_argument(
        "--cfg_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="CFG parallel size (1 or 2). "
        "Distributes positive/negative prompts across GPUs. "
        "Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="Ulysses sequence parallel size within each CFG group. "
        "Distributes sequence across GPUs for longer sequences. "
        "Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
        "Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
        "2 CFG groups × 2 Ulysses ranks = 4 GPUs total.",
    )

    return parser.parse_args()
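
# Example invocations (illustrative sketches only; the model ID below is the one mentioned
# in the --model_path help text, and the prompt text and output filename are assumptions --
# adjust to your environment):
#
#   # Single-GPU run with the defaults (720x1280, 81 frames):
#   python visual_gen_wan_t2v.py \
#       --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
#       --prompt "A cat surfing a wave at sunset" \
#       --output_path output.mp4
#
#   # Same run with TeaCache acceleration and FP8 per-tensor quantized linear layers:
#   python visual_gen_wan_t2v.py \
#       --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
#       --prompt "A cat surfing a wave at sunset" \
#       --enable_teacache --teacache_thresh 0.2 \
#       --linear_type trtllm-fp8-per-tensor \
#       --output_path output.mp4
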
def main():
    args = parse_args()

    # Total workers: cfg_size × ulysses_size
    # See ParallelConfig in config.py for detailed parallelism strategy and examples
    n_workers = args.cfg_size * args.ulysses_size

    # Log Ulysses configuration (validation happens in setup_sequence_parallelism)
    if args.ulysses_size > 1:
        num_heads = 12  # WAN has 12 attention heads
        logger.info(
            f"Using Ulysses sequence parallelism: "
            f"{num_heads} heads / {args.ulysses_size} ranks = "
            f"{num_heads // args.ulysses_size} heads per GPU"
        )

    # Convert linear_type to quant_config
    quant_config = None
    if args.linear_type == "trtllm-fp8-per-tensor":
        quant_config = {"quant_algo": "FP8", "dynamic": True}
    elif args.linear_type == "trtllm-fp8-blockwise":
        quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
    elif args.linear_type == "svd-nvfp4":
        quant_config = {"quant_algo": "NVFP4", "dynamic": True}

    # 1. Setup Configuration
    diffusion_config = {
        "model_type": "wan2",
        "revision": args.revision,
        "attention": {
            "backend": args.attention_backend,
        },
        "teacache": {
            "enable_teacache": args.enable_teacache,
            "teacache_thresh": args.teacache_thresh,
        },
        "parallel": {
            "dit_cfg_size": args.cfg_size,
            "dit_ulysses_size": args.ulysses_size,
        },
    }
    # Add quant_config if specified
    if quant_config is not None:
        diffusion_config["quant_config"] = quant_config

    # 2. Initialize VisualGen
    logger.info(
        f"Initializing VisualGen: world_size={n_workers} "
        f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
    )
    visual_gen = VisualGen(
        model_path=args.model_path,
        n_workers=n_workers,
        diffusion_config=diffusion_config,
    )

    try:
        # 3. Run Inference
        logger.info(f"Generating video for prompt: '{args.prompt}'")
        logger.info(f"Negative prompt: '{args.negative_prompt}'")
        logger.info(
            f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
        )
        start_time = time.time()
        output = visual_gen.generate(
            inputs={
                "prompt": args.prompt,
                "negative_prompt": args.negative_prompt,
            },
            params=VisualGenParams(
                height=args.height,
                width=args.width,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                num_frames=args.num_frames,
                guidance_scale_2=args.guidance_scale_2,
                boundary_ratio=args.boundary_ratio,
            ),
        )
        end_time = time.time()
        logger.info(f"Generation completed in {end_time - start_time:.2f}s")

        # 4. Save Output
        OutputHandler.save(output, args.output_path, frame_rate=16.0)
    finally:
        # 5. Shutdown
        visual_gen.shutdown()
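
# Multi-GPU sketch (an assumption, not a verified launch recipe: since VisualGen is created
# with n_workers = cfg_size * ulysses_size, the script is shown launched with plain `python`
# rather than an external launcher; GPU count follows the example in the --ulysses_size help):
#
#   # 4 GPUs: 2 CFG groups (positive/negative prompt) × 2 Ulysses ranks each
#   python visual_gen_wan_t2v.py \
#       --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
#       --prompt "A cat surfing a wave at sunset" \
#       --cfg_size 2 --ulysses_size 2
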
if __name__ == "__main__":
    main()