#!/usr/bin/env python3
"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""
import argparse
import time

from output_handler import OutputHandler

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams

# Set logger level to ensure timing logs are printed
logger.set_level("info")


def parse_args():
    parser = argparse.ArgumentParser(
        description="TRTLLM VisualGen - Wan Text-to-Video Inference Example "
        "(supports Wan 2.1 and Wan 2.2)"
    )
    # Model & Input
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="HuggingFace Hub revision (branch, tag, or commit SHA)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Default is model-specific.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image/video frame",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale_2",
        type=float,
        default=None,
        help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
    )
    parser.add_argument(
        "--boundary_ratio",
        type=float,
        default=None,
        help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # TeaCache Arguments
    parser.add_argument(
        "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
    )
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="TeaCache similarity threshold (rel_l1_thresh)",
    )

    # Quantization
    parser.add_argument(
        "--linear_type",
        type=str,
        default="default",
        choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
        help="Linear layer quantization type",
    )

    # Attention Backend
    parser.add_argument(
        "--attention_backend",
        type=str,
        default="VANILLA",
        choices=["VANILLA", "TRTLLM"],
        help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
        "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
    )

    # Parallelism
    parser.add_argument(
        "--cfg_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="CFG parallel size (1 or 2). "
        "Distributes positive/negative prompts across GPUs. "
        "Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="Ulysses sequence parallel size within each CFG group. "
        "Distributes sequence across GPUs for longer sequences. "
        "Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
        "Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
        "2 CFG groups × 2 Ulysses ranks = 4 GPUs total.",
    )
    return parser.parse_args()
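
# Example single-GPU invocation (a sketch, not from the original script: the file
# name "wan_t2v.py" and the prompt text are assumptions; the model ID is the one
# cited in the --model_path help above; all other flags keep their defaults):
#
#   python wan_t2v.py \
#       --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
#       --prompt "A cat surfing a wave at sunset"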
" "Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> " "2 CFG groups × 2 Ulysses ranks = 4 GPUs total.", ) return parser.parse_args() def main(): args = parse_args() # Total workers: cfg_size × ulysses_size # See ParallelConfig in config.py for detailed parallelism strategy and examples n_workers = args.cfg_size * args.ulysses_size # Log Ulysses configuration (validation happens in setup_sequence_parallelism) if args.ulysses_size > 1: num_heads = 12 # WAN has 12 attention heads logger.info( f"Using Ulysses sequence parallelism: " f"{num_heads} heads / {args.ulysses_size} ranks = " f"{num_heads // args.ulysses_size} heads per GPU" ) # Convert linear_type to quant_config quant_config = None if args.linear_type == "trtllm-fp8-per-tensor": quant_config = {"quant_algo": "FP8", "dynamic": True} elif args.linear_type == "trtllm-fp8-blockwise": quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True} elif args.linear_type == "svd-nvfp4": quant_config = {"quant_algo": "NVFP4", "dynamic": True} # 1. Setup Configuration diffusion_config = { "model_type": "wan2", "revision": args.revision, "attention": { "backend": args.attention_backend, }, "teacache": { "enable_teacache": args.enable_teacache, "teacache_thresh": args.teacache_thresh, }, "parallel": { "dit_cfg_size": args.cfg_size, "dit_ulysses_size": args.ulysses_size, }, } # Add quant_config if specified if quant_config is not None: diffusion_config["quant_config"] = quant_config # 2. Initialize VisualGen logger.info( f"Initializing VisualGen: world_size={n_workers} " f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})" ) visual_gen = VisualGen( model_path=args.model_path, n_workers=n_workers, diffusion_config=diffusion_config, ) try: # 2. Run Inference logger.info(f"Generating video for prompt: '{args.prompt}'") logger.info(f"Negative prompt: '{args.negative_prompt}'") logger.info( f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}" ) start_time = time.time() output = visual_gen.generate( inputs={ "prompt": args.prompt, "negative_prompt": args.negative_prompt, }, params=VisualGenParams( height=args.height, width=args.width, num_inference_steps=args.steps, guidance_scale=args.guidance_scale, seed=args.seed, num_frames=args.num_frames, guidance_scale_2=args.guidance_scale_2, boundary_ratio=args.boundary_ratio, ), ) end_time = time.time() logger.info(f"Generation completed in {end_time - start_time:.2f}s") # 3. Save Output OutputHandler.save(output, args.output_path, frame_rate=16.0) finally: # 4. Shutdown visual_gen.shutdown() if __name__ == "__main__": main()