#!/usr/bin/env python3
"""WAN Image-to-Video generation using TensorRT-LLM Visual Generation."""
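
# Example invocation (illustrative only; the script name and paths are placeholders,
# and the flags are the ones defined in parse_args() below):
#   python wan_i2v_example.py \
#       --model_path /path/to/Wan-I2V-Diffusers \
#       --image_path /path/to/first_frame.png \
#       --prompt "A cat walking across a sunlit garden"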

import argparse
import time

from output_handler import OutputHandler

from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams

# Set logger level to ensure timing logs are printed
logger.set_level("info")


def parse_args():
    parser = argparse.ArgumentParser(
        description="TRTLLM VisualGen - Wan Image-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
    )

    # Model & Input
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to Wan I2V Diffusers model directory (2.1 or 2.2)",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to input image for I2V conditioning",
    )
    parser.add_argument(
        "--last_image_path",
        type=str,
        default=None,
        help="Optional path to last frame image for interpolation (Wan 2.1 only)",
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Default is model-specific.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image/video frame",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
    )
    parser.add_argument(
        "--guidance_scale_2",
        type=float,
        default=None,
        help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
    )
    parser.add_argument(
        "--boundary_ratio",
        type=float,
        default=None,
        help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # TeaCache Arguments
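    # Roughly, TeaCache reuses cached DiT outputs across denoising steps whenever the
    # estimated step-to-step change stays below rel_l1_thresh, trading a small amount
    # of fidelity for speed.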
    parser.add_argument(
        "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
    )
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="TeaCache similarity threshold (rel_l1_thresh)",
    )

    # Quantization
    parser.add_argument(
        "--linear_type",
        type=str,
        default="default",
        choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
        help="Linear layer quantization type",
    )

    # Attention Backend
    parser.add_argument(
        "--attention_backend",
        type=str,
        default="VANILLA",
        choices=["VANILLA", "TRTLLM"],
        help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
        "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
    )

    # Parallelism
    parser.add_argument(
        "--cfg_size",
        type=int,
        default=1,
        choices=[1, 2],
        help="CFG parallel size (1 or 2). Set to 2 for CFG Parallelism.",
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="Ulysses (sequence) parallel size within each CFG group.",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    # world_size = cfg_size * ulysses_size
    # Example: cfg_size=2, ulysses_size=4 -> 8 GPUs
    # GPU 0-3: CFG group 0 (positive prompt), internal Ulysses parallel
    # GPU 4-7: CFG group 1 (negative prompt), internal Ulysses parallel
    n_workers = args.cfg_size * args.ulysses_size

    # Convert linear_type to quant_config
    quant_config = None
    if args.linear_type == "trtllm-fp8-per-tensor":
        quant_config = {"quant_algo": "FP8", "dynamic": True}
    elif args.linear_type == "trtllm-fp8-blockwise":
        quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
    elif args.linear_type == "svd-nvfp4":
        quant_config = {"quant_algo": "NVFP4", "dynamic": True}
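    # With the default linear_type, quant_config stays None and no quantization
    # settings are passed to VisualGen; the transformer's linear layers then run
    # in the checkpoint's native precision.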

    # 1. Setup Configuration
    diffusion_config = {
        "model_type": "wan2",
        "attention": {
            "backend": args.attention_backend,
        },
        "teacache": {
            "enable_teacache": args.enable_teacache,
            "teacache_thresh": args.teacache_thresh,
        },
        "parallel": {
            "dit_cfg_size": args.cfg_size,
            "dit_ulysses_size": args.ulysses_size,
        },
    }

    # Add quant_config if specified
    if quant_config is not None:
        diffusion_config["quant_config"] = quant_config

    # 2. Initialize VisualGen
    logger.info(
        f"Initializing VisualGen: world_size={n_workers} "
        f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
    )
    visual_gen = VisualGen(
        model_path=args.model_path,
        n_workers=n_workers,
        diffusion_config=diffusion_config,
    )

    try:
        # 3. Run Inference
        logger.info(f"Generating video for prompt: '{args.prompt}'")
        logger.info(f"Negative prompt: '{args.negative_prompt}'")
        logger.info(f"Input image: {args.image_path}")
        if args.last_image_path:
            logger.info(f"Last frame image: {args.last_image_path}")
        logger.info(
            f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
        )

        start_time = time.time()

        # Build parameters with explicit I2V and Wan 2.2 support
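        # Note: last_image only applies to Wan 2.1 frame interpolation, while
        # guidance_scale_2 and boundary_ratio only apply to Wan 2.2 two-stage
        # denoising; values left at None are auto-detected per model.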
        output = visual_gen.generate(
            inputs={
                "prompt": args.prompt,
                "negative_prompt": args.negative_prompt,
            },
            params=VisualGenParams(
                height=args.height,
                width=args.width,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                num_frames=args.num_frames,
                input_reference=args.image_path,
                last_image=args.last_image_path if args.last_image_path else None,
                guidance_scale_2=args.guidance_scale_2,
                boundary_ratio=args.boundary_ratio,
            ),
        )

        end_time = time.time()
        logger.info(f"Generation completed in {end_time - start_time:.2f}s")

        # 4. Save Output
        OutputHandler.save(output, args.output_path, frame_rate=16.0)
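        # At frame_rate=16.0, the default 81 frames correspond to roughly five
        # seconds of video.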

    finally:
        # 5. Shutdown
        visual_gen.shutdown()


if __name__ == "__main__":
    main()