TensorRT-LLMs/examples/visual_gen/visual_gen_wan_i2v.py

#!/usr/bin/env python3
"""WAN Image-to-Video generation using TensorRT-LLM Visual Generation."""
import argparse
import time

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams

from output_handler import OutputHandler

# Set logger level to ensure timing logs are printed
logger.set_level("info")


def parse_args():
    parser = argparse.ArgumentParser(
        description="TRTLLM VisualGen - Wan Image-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
    )

    # Model & Input
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to Wan I2V Diffusers model directory (2.1 or 2.2)",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to input image for I2V conditioning",
    )
    parser.add_argument(
        "--last_image_path",
        type=str,
        default=None,
        help="Optional path to last frame image for interpolation (Wan 2.1 only)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Default is model-specific.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image/video frame",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale_2",
        type=float,
        default=None,
        help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
    )
    parser.add_argument(
        "--boundary_ratio",
        type=float,
        default=None,
        help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    # TeaCache Arguments
    parser.add_argument(
        "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
    )
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="TeaCache similarity threshold (rel_l1_thresh)",
    )
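    # TeaCache reuses cached DiT outputs when successive denoising steps are
    # similar enough (threshold above); a higher threshold skips more steps and
    # runs faster at some cost in fidelity. This describes the technique in
    # general terms; exact behavior depends on the backend implementation.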
    # Quantization
    parser.add_argument(
        "--linear_type",
        type=str,
        default="default",
        choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
        help="Linear layer quantization type",
    )

    # Attention Backend
    parser.add_argument(
        "--attention_backend",
        type=str,
        default="VANILLA",
        choices=["VANILLA", "TRTLLM"],
        help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
        "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
    )

    # Parallelism
    parser.add_argument(
        "--cfg_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="CFG parallel size (1 or 2). Set to 2 for CFG Parallelism.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="Ulysses (sequence) parallel size within each CFG group.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # world_size = cfg_size * ulysses_size
    # Example: cfg_size=2, ulysses_size=4 -> 8 GPUs
    #   GPU 0-3: CFG group 0 (positive prompt), internal Ulysses parallel
    #   GPU 4-7: CFG group 1 (negative prompt), internal Ulysses parallel
    n_workers = args.cfg_size * args.ulysses_size

    # Convert linear_type to quant_config
    quant_config = None
    if args.linear_type == "trtllm-fp8-per-tensor":
        quant_config = {"quant_algo": "FP8", "dynamic": True}
    elif args.linear_type == "trtllm-fp8-blockwise":
        quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
    elif args.linear_type == "svd-nvfp4":
        quant_config = {"quant_algo": "NVFP4", "dynamic": True}
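    # The "default" choice leaves quant_config as None, so linear layers run
    # in the checkpoint's native precision and no dynamic quantization is applied.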

    # 1. Setup Configuration
    diffusion_config = {
        "model_type": "wan2",
        "attention": {
            "backend": args.attention_backend,
        },
        "teacache": {
            "enable_teacache": args.enable_teacache,
            "teacache_thresh": args.teacache_thresh,
        },
        "parallel": {
            "dit_cfg_size": args.cfg_size,
            "dit_ulysses_size": args.ulysses_size,
        },
    }

    # Add quant_config if specified
    if quant_config is not None:
        diffusion_config["quant_config"] = quant_config
# 2. Initialize VisualGen
logger.info(
f"Initializing VisualGen: world_size={n_workers} "
f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
)
visual_gen = VisualGen(
model_path=args.model_path,
n_workers=n_workers,
diffusion_config=diffusion_config,
)

    try:
        # 3. Run Inference
        logger.info(f"Generating video for prompt: '{args.prompt}'")
        logger.info(f"Negative prompt: '{args.negative_prompt}'")
        logger.info(f"Input image: {args.image_path}")
        if args.last_image_path:
            logger.info(f"Last frame image: {args.last_image_path}")
        logger.info(
            f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
        )

        start_time = time.time()
        # Build parameters with explicit I2V and Wan 2.2 support
        output = visual_gen.generate(
            inputs={
                "prompt": args.prompt,
                "negative_prompt": args.negative_prompt,
            },
            params=VisualGenParams(
                height=args.height,
                width=args.width,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                num_frames=args.num_frames,
                input_reference=args.image_path,
                last_image=args.last_image_path if args.last_image_path else None,
                guidance_scale_2=args.guidance_scale_2,
                boundary_ratio=args.boundary_ratio,
            ),
        )
        end_time = time.time()
        logger.info(f"Generation completed in {end_time - start_time:.2f}s")
        # 4. Save Output
        OutputHandler.save(output, args.output_path, frame_rate=16.0)
    finally:
        # 5. Shutdown
        visual_gen.shutdown()


if __name__ == "__main__":
    main()