#!/usr/bin/env python3
"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""
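# Example invocation (illustrative only; the script filename and prompt are
# placeholders, and every flag used here is defined in parse_args() below):
#
#   python wan_t2v_example.py \
#       --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
#       --prompt "A panda riding a bicycle through a bamboo forest" \
#       --num_frames 81 \
#       --output_path output.png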

import argparse
import time

from output_handler import OutputHandler

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams

# Set logger level to ensure timing logs are printed
logger.set_level("info")


def parse_args():
    parser = argparse.ArgumentParser(
        description="TRTLLM VisualGen - Wan Text-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
    )

    # Model & Input
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="HuggingFace Hub revision (branch, tag, or commit SHA)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Default is model-specific.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image/video frame",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale_2",
        type=float,
        default=None,
        help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
    )
    parser.add_argument(
        "--boundary_ratio",
        type=float,
        default=None,
        help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # TeaCache Arguments
    parser.add_argument(
        "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
    )
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="TeaCache similarity threshold (rel_l1_thresh)",
    )

    # Quantization
    parser.add_argument(
        "--linear_type",
        type=str,
        default="default",
        choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
        help="Linear layer quantization type",
    )

    # Attention Backend
    parser.add_argument(
        "--attention_backend",
        type=str,
        default="VANILLA",
        choices=["VANILLA", "TRTLLM"],
        help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
        "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
    )

    # Parallelism
    parser.add_argument(
        "--cfg_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="CFG parallel size (1 or 2). "
        "Distributes positive/negative prompts across GPUs. "
        "Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="Ulysses sequence parallel size within each CFG group. "
        "Distributes sequence across GPUs for longer sequences. "
        "Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
        "Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
        "2 CFG groups × 2 Ulysses ranks = 4 GPUs total.",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    # Total workers: cfg_size × ulysses_size
    # See ParallelConfig in config.py for detailed parallelism strategy and examples
    n_workers = args.cfg_size * args.ulysses_size
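    # Illustrative sizing, following the --ulysses_size help text: cfg_size=2 with
    # ulysses_size=2 -> 2 CFG groups × 2 Ulysses ranks = 4 workers (one per GPU).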

    # Log Ulysses configuration (validation happens in setup_sequence_parallelism)
    if args.ulysses_size > 1:
        num_heads = 12  # WAN has 12 attention heads
        logger.info(
            f"Using Ulysses sequence parallelism: "
            f"{num_heads} heads / {args.ulysses_size} ranks = "
            f"{num_heads // args.ulysses_size} heads per GPU"
        )

    # Convert linear_type to quant_config
    quant_config = None
    if args.linear_type == "trtllm-fp8-per-tensor":
        quant_config = {"quant_algo": "FP8", "dynamic": True}
    elif args.linear_type == "trtllm-fp8-blockwise":
        quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
    elif args.linear_type == "svd-nvfp4":
        quant_config = {"quant_algo": "NVFP4", "dynamic": True}

    # 1. Setup Configuration
    diffusion_config = {
        "model_type": "wan2",
        "revision": args.revision,
        "attention": {
            "backend": args.attention_backend,
        },
        "teacache": {
            "enable_teacache": args.enable_teacache,
            "teacache_thresh": args.teacache_thresh,
        },
        "parallel": {
            "dit_cfg_size": args.cfg_size,
            "dit_ulysses_size": args.ulysses_size,
        },
    }

    # Add quant_config if specified
    if quant_config is not None:
        diffusion_config["quant_config"] = quant_config

    # 2. Initialize VisualGen
    logger.info(
        f"Initializing VisualGen: world_size={n_workers} "
        f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
    )
    visual_gen = VisualGen(
        model_path=args.model_path,
        n_workers=n_workers,
        diffusion_config=diffusion_config,
    )
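
    # Generation runs inside try/finally so the worker processes are shut down
    # even if inference or saving fails.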
    try:
        # 3. Run Inference
        logger.info(f"Generating video for prompt: '{args.prompt}'")
        logger.info(f"Negative prompt: '{args.negative_prompt}'")
        logger.info(
            f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
        )

        start_time = time.time()

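        # "inputs" carries the conditioning text, while VisualGenParams carries the
        # sampling settings; per the argument help text above, options left at None
        # fall back to model-specific defaults.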
output = visual_gen.generate(
|
||
inputs={
|
||
"prompt": args.prompt,
|
||
"negative_prompt": args.negative_prompt,
|
||
},
|
||
params=VisualGenParams(
|
||
height=args.height,
|
||
width=args.width,
|
||
num_inference_steps=args.steps,
|
||
guidance_scale=args.guidance_scale,
|
||
seed=args.seed,
|
||
num_frames=args.num_frames,
|
||
guidance_scale_2=args.guidance_scale_2,
|
||
boundary_ratio=args.boundary_ratio,
|
||
),
|
||
)
|
||
|
||
end_time = time.time()
|
||
logger.info(f"Generation completed in {end_time - start_time:.2f}s")
|
||
|
||
# 3. Save Output
|
||
OutputHandler.save(output, args.output_path, frame_rate=16.0)
|
||
|
||
finally:
|
||
# 4. Shutdown
|
||
visual_gen.shutdown()


if __name__ == "__main__":
    main()