diff --git a/README.md b/README.md
index 31ecc45440..25533f6552 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,6 @@ TensorRT LLM
TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports
state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.
-š TensorRT LLM is experimenting with Image&Video Generation models in [TensorRT-LLM/feat/visual_gen](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch.
-This branch is a prototype and not stable for production use. PRs are not accepted.
-
[](https://nvidia.github.io/TensorRT-LLM/)
[](https://deepwiki.com/NVIDIA/TensorRT-LLM)
[](https://www.python.org/downloads/release/python-3123/)
diff --git a/examples/visual_gen/README.md b/examples/visual_gen/README.md
new file mode 100644
index 0000000000..4dfd5e07e0
--- /dev/null
+++ b/examples/visual_gen/README.md
@@ -0,0 +1,172 @@
+# Visual Generation Examples
+
+Quick reference for running visual generation models (WAN).
+
+## Prerequisites
+
+```bash
+# Install dependencies (from repository root)
+pip install -r requirements-dev.txt
+pip install git+https://github.com/huggingface/diffusers.git
+pip install av
+```
+
+## Quick Start
+
+```bash
+# Set MODEL_ROOT to your model directory (required for examples)
+export MODEL_ROOT=/llm-models
+# Optional: PROJECT_ROOT defaults to repo root when run from examples/visual_gen
+
+# Run all examples (auto-detects GPUs)
+cd examples/visual_gen
+./visual_gen_examples.sh
+```
+
+
+## Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `PROJECT_ROOT` | Auto-detected | Path to the repository root (auto-detected when run from `examples/visual_gen`; set it explicitly otherwise) |
+| `MODEL_ROOT` | `/llm-models` | Path to model directory |
+| `TLLM_LOG_LEVEL` | `INFO` | Logging level |
+
+---
+
+## WAN (Text-to-Video)
+
+### Basic Usage
+
+**Single GPU:**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --output_path output.mp4
+```
+
+**With TeaCache:**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --enable_teacache \
+ --output_path output.mp4
+```
+
+### Multi-GPU Parallelism
+
+WAN supports two parallelism modes that can be combined:
+- **CFG Parallelism**: Split positive/negative prompts across GPUs
+- **Ulysses Parallelism**: Split sequence across GPUs for longer sequences
+
+
+**Ulysses Only (2 GPUs):**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --attention_backend TRTLLM \
+ --cfg_size 1 --ulysses_size 2 \
+ --output_path output.mp4
+```
+GPU Layout: GPU 0-1 share sequence (6 heads each)
+
+**CFG Only (2 GPUs):**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 --ulysses_size 1 \
+ --output_path output.mp4
+```
+GPU Layout: GPU 0 (positive) | GPU 1 (negative)
+
+**CFG + Ulysses (4 GPUs):**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 --ulysses_size 2 \
+ --output_path output.mp4
+```
+GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses)
+
+**Large-Scale (8 GPUs):**
+```bash
+python visual_gen_wan_t2v.py \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --height 480 --width 832 --num_frames 33 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 --ulysses_size 4 \
+ --output_path output.mp4
+```
+GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative)
+
+---
+
+## Common Arguments
+
+| Argument | WAN | Default | Description |
+|----------|-----|---------|-------------|
+| `--height` | ✓ | 720 | Output height |
+| `--width` | ✓ | 1280 | Output width |
+| `--num_frames` | ✓ | 81 | Number of frames |
+| `--steps` | ✓ | 50 | Denoising steps |
+| `--guidance_scale` | ✓ | 5.0 | CFG guidance strength |
+| `--seed` | ✓ | 42 | Random seed |
+| `--enable_teacache` | ✓ | False | Cache optimization |
+| `--teacache_thresh` | ✓ | 0.2 | TeaCache similarity threshold |
+| `--attention_backend` | ✓ | VANILLA | VANILLA or TRTLLM |
+| `--cfg_size` | ✓ | 1 | CFG parallelism |
+| `--ulysses_size` | ✓ | 1 | Sequence parallelism |
+| `--linear_type` | ✓ | default | Quantization type |
+
+## Troubleshooting
+
+**Out of Memory:**
+- Use quantization: `--linear_type trtllm-fp8-blockwise`
+- Reduce resolution or frames
+- Enable TeaCache: `--enable_teacache`
+- Use Ulysses parallelism with more GPUs
+
+**Slow Inference:**
+- Enable TeaCache: `--enable_teacache`
+- Use TRTLLM backend: `--attention_backend TRTLLM`
+- Use multi-GPU: `--cfg_size 2` or `--ulysses_size 2`
+
+**Import Errors:**
+- Run from repository root
+- Install necessary dependencies, e.g., `pip install -r requirements-dev.txt`
+
+**Ulysses Errors:**
+- `ulysses_size` must evenly divide 12 (the number of WAN attention heads)
+- Total GPUs must equal `cfg_size × ulysses_size` (a quick sanity check is sketched below)
+- The sequence length must be divisible by `ulysses_size`
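+
+As a quick pre-flight check, these constraints can be verified before launching a run. The helper below is an illustrative sketch only (`check_wan_parallel_config` is not part of TensorRT LLM); the head count of 12 corresponds to the WAN model used in these examples:
+
+```python
+def check_wan_parallel_config(cfg_size: int, ulysses_size: int, num_gpus: int,
+                              num_heads: int = 12) -> None:
+    """Sanity-check a CFG/Ulysses parallel configuration (illustrative only)."""
+    if num_heads % ulysses_size != 0:
+        raise ValueError(f"ulysses_size={ulysses_size} must divide {num_heads} attention heads")
+    if cfg_size * ulysses_size != num_gpus:
+        raise ValueError(f"cfg_size * ulysses_size = {cfg_size * ulysses_size}, "
+                         f"but {num_gpus} GPU(s) are available")
+    print(f"OK: {num_heads // ulysses_size} heads per GPU across {num_gpus} GPU(s)")
+
+
+check_wan_parallel_config(cfg_size=2, ulysses_size=2, num_gpus=4)
+```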
+
+## Output Formats
+
+- **WAN**: `.mp4` (video), `.gif` (animated), `.png` (single frame)
+
+## Baseline Validation
+
+Compare with official HuggingFace Diffusers implementation:
+
+```bash
+# Run HuggingFace baselines
+./hf_examples.sh
+
+# Or run individual models
+python hf_wan.py --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers
+```
+
+Generate both outputs with the same seed and compare them to verify correctness.
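+
+For a quick numerical check, individual frames can be diffed directly. This is a minimal sketch, assuming each run saved a single frame as a PNG (for example by passing a `.png` `--output_path`); the filenames below are placeholders, and bit-exact equality is not expected across backends:
+
+```python
+import numpy as np
+from PIL import Image
+
+# Placeholder filenames: one frame from the TensorRT LLM run, one from the HF Diffusers baseline
+ours = np.asarray(Image.open("wan_cat_piano.png"), dtype=np.float32)
+ref = np.asarray(Image.open("wan_baseline.png"), dtype=np.float32)
+
+assert ours.shape == ref.shape, f"shape mismatch: {ours.shape} vs {ref.shape}"
+diff = np.abs(ours - ref)
+print(f"mean abs diff: {diff.mean():.2f} / 255, max abs diff: {diff.max():.0f} / 255")
+```
+
+Small pixel-level differences are normal when optimizations such as TeaCache or FP8 quantization are enabled; large structural differences usually point to a configuration mismatch.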
diff --git a/examples/visual_gen/cat_piano.png b/examples/visual_gen/cat_piano.png
new file mode 100644
index 0000000000..3b60bf25be
Binary files /dev/null and b/examples/visual_gen/cat_piano.png differ
diff --git a/examples/visual_gen/hf_examples.sh b/examples/visual_gen/hf_examples.sh
new file mode 100755
index 0000000000..192983015d
--- /dev/null
+++ b/examples/visual_gen/hf_examples.sh
@@ -0,0 +1,128 @@
+#!/bin/bash
+# HuggingFace Baseline Tests - Official Diffusers Implementation
+#
+# Usage:
+# export PROJECT_ROOT=/path/to/tekit
+# export MODEL_ROOT=/path/to/models
+# ./hf_examples.sh
+#
+# Or inline:
+# PROJECT_ROOT=/workspace/gitlab/tekit-b200 MODEL_ROOT=/llm-models ./hf_examples.sh
+
+set -e # Exit on error
+
+# Environment variables with defaults
+PROJECT_ROOT=${PROJECT_ROOT:-"/workspace/gitlab/tekit-b200"}
+MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
+
+# Log configuration
+export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
+
+echo "============================================"
+echo "HuggingFace Diffusers Baseline Tests"
+echo "============================================"
+echo "PROJECT_ROOT: $PROJECT_ROOT"
+echo "MODEL_ROOT: $MODEL_ROOT"
+echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
+echo ""
+echo "Purpose: Establish baseline results using"
+echo " official diffusers implementations"
+echo "============================================"
+echo ""
+
+# Check Python dependencies
+echo "Checking dependencies..."
+MISSING_DEPS=""
+
+if ! python -c "import diffusers" 2>/dev/null; then
+    echo "❌ ERROR: diffusers not found"
+ MISSING_DEPS="$MISSING_DEPS diffusers"
+fi
+
+if ! python -c "import torch" 2>/dev/null; then
+    echo "❌ ERROR: torch not found"
+ MISSING_DEPS="$MISSING_DEPS torch"
+fi
+
+if [ -n "$MISSING_DEPS" ]; then
+ echo ""
+    echo "❌ Missing required dependencies:$MISSING_DEPS"
+ echo "Install with: pip install$MISSING_DEPS"
+ exit 1
+fi
+
+echo "✅ All required dependencies found"
+echo ""
+
+# Detect GPU
+if command -v nvidia-smi &> /dev/null; then
+ GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+ echo "Detected $GPU_COUNT GPU(s)"
+ GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
+ echo "GPU: $GPU_NAME"
+else
+    echo "⚠️ WARNING: nvidia-smi not found"
+ echo " Continuing with CPU (very slow!)"
+ GPU_COUNT=0
+fi
+echo ""
+
+# Create output directory (in current directory)
+OUTPUT_DIR="./baseline_outputs"
+mkdir -p "$OUTPUT_DIR"
+echo "Output directory: $OUTPUT_DIR ($(pwd)/baseline_outputs)"
+echo ""
+
+#############################################
+# WAN (Wan2.1) Baseline Test
+#############################################
+
+echo "============================================"
+echo "1/1: WAN Baseline Test"
+echo "============================================"
+echo ""
+
+WAN_MODEL="${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/"
+WAN_OUTPUT="${OUTPUT_DIR}/wan_baseline.gif"
+
+if [ -d "$WAN_MODEL" ]; then
+ echo "Testing WAN with official diffusers..."
+ python ${PROJECT_ROOT}/examples/visual_gen/hf_wan.py \
+ --model_path "$WAN_MODEL" \
+ --output_path "$WAN_OUTPUT" \
+ --prompt "A cute cat playing piano" \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --steps 50 \
+ --guidance_scale 7.0 \
+ --seed 42
+ echo ""
+    echo "✅ WAN baseline test completed"
+ echo " Output: $WAN_OUTPUT"
+else
+    echo "⚠️ SKIPPED: WAN model not found at $WAN_MODEL"
+fi
+
+echo ""
+
+#############################################
+# Summary
+#############################################
+
+echo "============================================"
+echo "Baseline Tests Complete!"
+echo "============================================"
+echo ""
+echo "Output files saved to: $OUTPUT_DIR"
+echo ""
+ls -lh "$OUTPUT_DIR" 2>/dev/null || echo "No outputs generated"
+echo ""
+echo "Next Steps:"
+echo " 1. Verify outputs are correct (images/videos generated)"
+echo " 2. Compare with custom implementation outputs"
+echo " 3. Use these as reference/baseline for debugging"
+echo ""
+echo "Comparison command:"
+echo " diff -r $OUTPUT_DIR "
+echo "============================================"
diff --git a/examples/visual_gen/hf_wan.py b/examples/visual_gen/hf_wan.py
new file mode 100755
index 0000000000..3919794052
--- /dev/null
+++ b/examples/visual_gen/hf_wan.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+"""Baseline test for WAN using official diffusers library."""
+
+import sys
+
+import torch
+from output_handler import OutputHandler, postprocess_hf_video_tensor
+
+from tensorrt_llm._torch.visual_gen import MediaOutput
+
+
+def test_wan_baseline(
+ model_path: str,
+ output_path: str,
+ prompt: str = "A cute cat playing piano",
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 33,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.0,
+ seed: int = 42,
+):
+ """Test WAN video generation with official diffusers."""
+ from diffusers import WanPipeline
+
+ print("=" * 80)
+ print("WAN Baseline Test (Official Diffusers)")
+ print("=" * 80)
+ print()
+
+ # Load pipeline
+ print(f"Loading WAN pipeline from {model_path}...")
+ pipe = WanPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
+ pipe.to("cuda")
+ print("ā
Pipeline loaded")
+ print()
+
+ # Check model states
+ print("Model Training States:")
+ print(f" text_encoder.training: {pipe.text_encoder.training}")
+ print(f" transformer.training: {pipe.transformer.training}")
+ print(f" vae.training: {pipe.vae.training}")
+ print()
+
+ # Generate video
+ print(f"Generating video: '{prompt}'")
+ print(
+ f"Parameters: {height}x{width}, {num_frames} frames, {num_inference_steps} steps, guidance={guidance_scale}"
+ )
+ print()
+
+ # Set random seed
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+
+ result = pipe(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ output_type="pt",
+ return_dict=False,
+ )
+
+ video = result[0]
+
+ # Post-process video tensor: (B, T, C, H, W) -> (T, H, W, C) uint8
+ video = postprocess_hf_video_tensor(video, remove_batch_dim=True)
+
+ print("=" * 80)
+ print("Generation Complete!")
+ print("=" * 80)
+ print(f"Video shape: {video.shape}")
+ print(f"Video dtype: {video.dtype}")
+ print()
+
+ # Save output
+ print(f"Saving output to {output_path}...")
+ OutputHandler.save(output=MediaOutput(video=video), output_path=output_path, frame_rate=24.0)
+ print(f"ā
Saved to {output_path}")
+ print()
+
+ print("=" * 80)
+ print("WAN BASELINE TEST PASSED ā
")
+ print("=" * 80)
+ return video
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description="HuggingFace Baseline - WAN Text-to-Video Generation"
+ )
+
+ # Model & Input
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="/llm-models/Wan2.1-T2V-1.3B-Diffusers/",
+ help="Path to WAN model",
+ )
+ parser.add_argument(
+ "--prompt", type=str, default="A cute cat playing piano", help="Text prompt for generation"
+ )
+ parser.add_argument(
+ "--output_path", type=str, default="wan_baseline.gif", help="Output file path"
+ )
+
+ # Generation parameters
+ parser.add_argument("--height", type=int, default=480, help="Video height")
+ parser.add_argument("--width", type=int, default=832, help="Video width")
+ parser.add_argument("--num_frames", type=int, default=33, help="Number of frames to generate")
+ parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps")
+ parser.add_argument(
+ "--guidance_scale", type=float, default=7.0, help="Classifier-free guidance scale"
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+ args = parser.parse_args()
+
+ try:
+ test_wan_baseline(
+ args.model_path,
+ args.output_path,
+ prompt=args.prompt,
+ height=args.height,
+ width=args.width,
+ num_frames=args.num_frames,
+ num_inference_steps=args.steps,
+ guidance_scale=args.guidance_scale,
+ seed=args.seed,
+ )
+ except Exception as e:
+ print(f"\nā ERROR: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/examples/visual_gen/output_handler.py b/examples/visual_gen/output_handler.py
new file mode 100644
index 0000000000..a360d681f9
--- /dev/null
+++ b/examples/visual_gen/output_handler.py
@@ -0,0 +1,237 @@
+"""Unified output handler for diffusion model outputs."""
+
+import os
+from typing import Optional
+
+import torch
+from PIL import Image
+
+from tensorrt_llm import logger
+from tensorrt_llm.llmapi.visual_gen import MediaOutput
+
+
+def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
+ """Post-process video tensor from HuggingFace pipeline output to final format.
+
+ HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W) format,
+ which is different from VAE decoder output format.
+
+ Args:
+ video: Video tensor in (B, T, C, H, W) format from HuggingFace pipeline
+ remove_batch_dim: Whether to remove batch dimension. Default True for typical
+ single-batch video generation.
+
+ Returns:
+ Post-processed video tensor:
+ - If remove_batch_dim=True: (T, H, W, C) uint8 tensor
+ - If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor
+
+ Note:
+ Assumes video values are in [-1, 1] range (standard pipeline output).
+ """
+ # Remove batch dimension first if requested
+ if remove_batch_dim:
+ video = video[0] # (B, T, C, H, W) -> (T, C, H, W)
+ video = video.permute(0, 2, 3, 1) # (T, C, H, W) -> (T, H, W, C)
+ else:
+ video = video.permute(0, 1, 3, 4, 2) # (B, T, C, H, W) -> (B, T, H, W, C)
+
+ # Normalize to [0, 1] range
+ video = (video / 2 + 0.5).clamp(0, 1)
+
+ # Convert to uint8
+ video = (video * 255).round().to(torch.uint8)
+
+ return video
+
+
+def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor:
+ """Post-process image tensor from HuggingFace pipeline output to final format.
+
+ HuggingFace pipelines with output_type="pt" return images in (B, C, H, W) format.
+
+ Args:
+ image: Image tensor in (B, C, H, W) or (C, H, W) format from HuggingFace pipeline
+
+ Returns:
+ Post-processed image tensor in (H, W, C) uint8 format
+
+ Note:
+ Assumes image values are in [-1, 1] range (standard pipeline output).
+ """
+ # Remove batch dimension if present
+ if image.ndim == 4:
+ image = image[0] # (B, C, H, W) -> (C, H, W)
+
+ # Convert to (H, W, C) format
+ image = image.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
+
+ # Normalize to [0, 1] range
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # Convert to uint8
+ image = (image * 255).round().to(torch.uint8)
+
+ return image
+
+
+class OutputHandler:
+ """Handle saving of generated outputs in various formats.
+
+ Supports MediaOutput from all models:
+ - Video models (WAN): MediaOutput(video=torch.Tensor)
+ - Image models: MediaOutput(image=torch.Tensor)
+ - Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
+
+ Supported output formats:
+ - .png: Save single image or middle frame
+ - .gif: Save video as animated GIF (no audio)
+ - .mp4: Save video with audio (requires diffusers export_utils)
+ """
+
+ @staticmethod
+ def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0):
+ """Save output based on content type and file extension.
+
+ Args:
+ output: MediaOutput containing model outputs (image/video/audio)
+ output_path: Path to save the output file
+ frame_rate: Frames per second for video output (default: 24.0)
+ """
+ if not isinstance(output, MediaOutput):
+ raise ValueError(f"Expected output to be MediaOutput, got {type(output)}")
+
+ file_ext = os.path.splitext(output_path)[1].lower()
+
+ # Determine content type
+ if output.image is not None:
+ OutputHandler._save_image(output.image, output_path, file_ext)
+ elif output.video is not None:
+ OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate)
+ else:
+ raise ValueError("Unknown output format. MediaOutput has no image or video data.")
+
+ @staticmethod
+ def _save_image(image: torch.Tensor, output_path: str, file_ext: str):
+ """Save single image output.
+
+ Args:
+ image: Image as torch tensor (H, W, C) uint8
+ output_path: Path to save the image
+ file_ext: File extension (.png, .jpg, etc.)
+ """
+ if file_ext not in [".png", ".jpg", ".jpeg"]:
+ logger.warning(f"Image output requested with {file_ext}, defaulting to .png")
+ output_path = output_path.replace(file_ext, ".png")
+
+ # Convert torch.Tensor to PIL Image and save
+ image_np = image.cpu().numpy()
+ Image.fromarray(image_np).save(output_path)
+ logger.info(f"Saved image to {output_path}")
+
+ @staticmethod
+ def _save_video(
+ video: torch.Tensor,
+ audio: Optional[torch.Tensor],
+ output_path: str,
+ file_ext: str,
+ frame_rate: float,
+ ):
+ """Save video output with optional audio.
+
+ Args:
+ video: Video frames as torch tensor (T, H, W, C) with dtype uint8
+ audio: Optional audio as torch tensor
+ output_path: Path to save the video
+ file_ext: File extension (.mp4, .gif, .png)
+ frame_rate: Frames per second
+ """
+ if file_ext == ".mp4":
+ OutputHandler._save_mp4(video, audio, output_path, frame_rate)
+ elif file_ext == ".gif":
+ OutputHandler._save_gif(video, output_path, frame_rate)
+ elif file_ext == ".png":
+ OutputHandler._save_middle_frame(video, output_path)
+ else:
+ logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png")
+ output_path = output_path.replace(file_ext, ".png")
+ OutputHandler._save_middle_frame(video, output_path)
+
+ @staticmethod
+ def _save_mp4(
+ video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
+ ):
+ """Save video with optional audio as MP4.
+
+ Args:
+ video: Video frames as torch tensor (T, H, W, C) uint8
+ audio: Optional audio as torch tensor (float32)
+ output_path: Output path for MP4
+ frame_rate: Frames per second
+ """
+ try:
+ from diffusers.pipelines.ltx2.export_utils import encode_video
+
+ # Prepare audio if present
+ audio_prepared = audio.float() if audio is not None else None
+
+ # encode_video expects (T, H, W, C) uint8 video and float32 audio
+ encode_video(
+ video,
+ fps=frame_rate,
+ audio=audio_prepared,
+ audio_sample_rate=24000 if audio_prepared is not None else None,
+ output_path=output_path,
+ )
+ logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
+
+ except ImportError:
+ logger.warning(
+ "diffusers export_utils (encode_video) not available. "
+ "Falling back to saving middle frame as PNG."
+ )
+ png_path = output_path.replace(".mp4", ".png")
+ OutputHandler._save_middle_frame(video, png_path)
+
+ @staticmethod
+ def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float):
+ """Save video as animated GIF.
+
+ Args:
+ video: Video frames as torch tensor (T, H, W, C) uint8
+ output_path: Output path for GIF
+ frame_rate: Frames per second
+ """
+ # Convert torch.Tensor to numpy for PIL
+ video_np = video.cpu().numpy()
+
+ # Convert to list of PIL Images
+ frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
+
+ # Save as animated GIF
+ duration_ms = int(1000 / frame_rate)
+ frames[0].save(
+ output_path,
+ save_all=True,
+ append_images=frames[1:],
+ optimize=False,
+ duration=duration_ms,
+ loop=0,
+ )
+ logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
+
+ @staticmethod
+ def _save_middle_frame(video: torch.Tensor, output_path: str):
+ """Save middle frame of video as PNG.
+
+ Args:
+ video: Video frames as torch tensor (T, H, W, C) uint8
+ output_path: Output path for PNG
+ """
+ # Convert torch.Tensor to numpy for PIL
+ video_np = video.cpu().numpy()
+
+ # Extract middle frame
+ frame_idx = video_np.shape[0] // 2
+ Image.fromarray(video_np[frame_idx]).save(output_path)
+ logger.info(f"Saved frame {frame_idx} to {output_path}")
diff --git a/examples/visual_gen/serve/README.md b/examples/visual_gen/serve/README.md
new file mode 100644
index 0000000000..b68dc7f2a2
--- /dev/null
+++ b/examples/visual_gen/serve/README.md
@@ -0,0 +1,322 @@
+# Visual Generation API Examples
+
+This directory contains example scripts that demonstrate how to use the TensorRT-LLM Visual Generation API endpoints for image and video generation.
+
+## Overview
+
+These examples show how to interact with the visual generation server using both the OpenAI Python SDK and standard HTTP requests. The API provides endpoints for:
+
+- **Image Generation**: Text-to-image generation (T2I)
+- **Video Generation**:
+ - Text-to-video generation (T2V) - generate videos from text prompts only
+ - Text+Image-to-video generation (TI2V) - generate videos from text + reference image
+ - Both synchronous and asynchronous modes supported
+ - Multipart/form-data support for file uploads
+- **Video Management**: Retrieving and deleting generated videos
+
+## Prerequisites
+
+Before running these examples, ensure you have:
+
+1. **Dependencies installed**: Install the required packages before running the examples:
+
+ ```bash
+ pip install git+https://github.com/huggingface/diffusers.git
+ pip install av
+ ```
+
+2. **Server Running**: The TensorRT-LLM visual generation server must be running
+ ```bash
+   trtllm-serve <model_path> --extra_visual_gen_options <visual_gen_config.yml>
+ ```
+
+   For example:
+
+ ```bash
+ trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml
+
+   # Run the server in the background:
+ trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml > /tmp/serve.log 2>&1 &
+
+   ## Check that the server has started
+ tail -f /tmp/serve.log
+
+ ```
+
+## Examples
+
+Currently supported and tested models:
+
+1. WAN T2V/I2V for video generation (t2v, ti2v, delete_video)
+
+### 1. Synchronous Image Generation (`sync_image_gen.py`)
+
+Demonstrates synchronous text-to-image generation using the OpenAI SDK.
+
+**Features:**
+- Generates images from text prompts
+- Supports configurable image size and quality
+- Returns base64-encoded images or URLs
+- Saves generated images to disk
+
+**Usage:**
+```bash
+# Use default localhost server
+python sync_image_gen.py
+
+# Specify custom server URL
+python sync_image_gen.py http://your-server:8000/v1
+```
+
+**API Endpoint:** `POST /v1/images/generations`
+
+**Output:** Saves generated image to `output_generation.png` (or numbered files for multiple images)
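+
+At its core, `sync_image_gen.py` is a single OpenAI SDK call. The trimmed-down sketch below mirrors the script; the `wan` model name and the `tensorrt_llm` API key are just the defaults used throughout these examples, so substitute whatever your server was launched with:
+
+```python
+import base64
+
+import openai
+
+client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")
+
+response = client.images.generate(
+    model="wan",                  # use the model your server was launched with
+    prompt="A lovely cat lying on a sofa",
+    n=1,
+    size="512x512",
+    response_format="b64_json",   # return the image inline as base64
+)
+
+with open("output_generation.png", "wb") as f:
+    f.write(base64.b64decode(response.data[0].b64_json))
+```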
+
+---
+
+### 2. Synchronous Video Generation with T2V and TI2V Modes (`sync_video_gen.py`)
+
+Demonstrates synchronous video generation using direct HTTP requests. Waits for completion and returns the video file directly.
+
+**Features:**
+- **T2V Mode**: Generate videos from text prompts only
+- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data)
+- Waits for video generation to complete before returning
+- Returns video file directly in response
+- Command-line interface for easy testing
+
+**Usage:**
+
+```bash
+# Text-to-Video (T2V) - No reference image
+python sync_video_gen.py --mode t2v \
+ --prompt "A cute cat playing with a ball in the park" \
+ --duration 4.0 --fps 24 --size 256x256
+
+# Text+Image-to-Video (TI2V) - With reference image
+## Note: longer durations and larger sizes significantly increase generation time
+python sync_video_gen.py --mode ti2v \
+ --prompt "She turns around and smiles, then slowly walks out of the frame" \
+ --image ./media/woman_skyline_original_720p.jpeg \
+ --duration 4.0 --fps 24 --size 512x512
+
+# Custom parameters
+python sync_video_gen.py --mode t2v \
+ --prompt "A serene sunset over the ocean" \
+ --duration 5.0 --fps 30 --size 512x512 \
+ --output my_video.mp4
+```
+
+**Command-Line Arguments:**
+- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
+- `--prompt` - Text prompt for video generation (required)
+- `--image` - Path to reference image (required for ti2v mode)
+- `--base-url` - API server URL (default: http://localhost:8000/v1)
+- `--model` - Model name (default: wan)
+- `--duration` - Video duration in seconds (default: 4.0)
+- `--fps` - Frames per second (default: 24)
+- `--size` - Video resolution in WxH format (default: 256x256)
+- `--output` - Output video file path (default: output_sync.mp4)
+
+**API Endpoint:** `POST /v1/videos/generations`
+
+**API Details:**
+- T2V uses JSON `Content-Type: application/json`
+- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
+
+**Output:** Saves generated video to specified output file
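+
+For T2V, the synchronous endpoint is a single blocking HTTP request whose response body is the video file itself. The sketch below is a stripped-down version of what `sync_video_gen.py` sends in `t2v` mode (URL and parameter values mirror the script defaults):
+
+```python
+import requests
+
+resp = requests.post(
+    "http://localhost:8000/v1/videos/generations",
+    json={
+        "model": "wan",
+        "prompt": "A cute cat playing with a ball in the park",
+        "size": "256x256",
+        "seconds": 4.0,
+        "fps": 24,
+    },
+)
+resp.raise_for_status()
+
+# The response body is the generated video, not JSON
+with open("output_sync.mp4", "wb") as f:
+    f.write(resp.content)
+```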
+
+---
+
+### 3. Async Video Generation with T2V and TI2V Modes (`async_video_gen.py`)
+
+Demonstrates asynchronous video generation supporting both Text-to-Video (T2V) and Text+Image-to-Video (TI2V) modes: the request returns a job ID immediately, and the client polls for completion before downloading the video.
+
+**Features:**
+- **T2V Mode**: Generate videos from text prompts only (JSON request)
+- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data with file upload)
+- Command-line interface for easy testing
+- Automatic mode detection
+- Comprehensive parameter control
+
+**Usage:**
+
+```bash
+# Text-to-Video (T2V) - No reference image
+python async_video_gen.py --mode t2v \
+ --prompt "A cool cat on a motorcycle in the night" \
+ --duration 4.0 --fps 24 --size 256x256
+
+# Text+Image-to-Video (TI2V) - With reference image
+python async_video_gen.py --mode ti2v \
+ --prompt "She turns around and smiles, then slowly walks out of the frame" \
+ --image ./media/woman_skyline_original_720p.jpeg \
+ --duration 4.0 --fps 24 --size 512x512
+
+# Custom parameters
+python async_video_gen.py --mode t2v \
+ --prompt "A serene sunset over the ocean" \
+ --duration 5.0 --fps 30 --size 512x512 \
+ --output my_video.mp4
+```
+
+**Command-Line Arguments:**
+- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
+- `--prompt` - Text prompt for video generation (required)
+- `--image` - Path to reference image (required for ti2v mode)
+- `--base-url` - API server URL (default: http://localhost:8000/v1)
+- `--model` - Model name (default: wan)
+- `--duration` - Video duration in seconds (default: 4.0)
+- `--fps` - Frames per second (default: 24)
+- `--size` - Video resolution in WxH format (default: 256x256)
+- `--output` - Output video file path (default: output_async.mp4)
+
+**API Details:**
+- T2V uses JSON `Content-Type: application/json`
+- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
+
+**Output:** Saves generated video to specified output file
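+
+The asynchronous flow is create, poll, then download. The sketch below condenses what `async_video_gen.py` does with the OpenAI SDK; the one-second polling interval and the `wan` model name follow the script defaults:
+
+```python
+import time
+
+import openai
+
+client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")
+
+# 1. Create the job; the call returns immediately with a job ID
+job = client.videos.create(model="wan", prompt="A cool cat on a motorcycle",
+                           size="256x256", seconds=4.0, extra_body={"fps": 24})
+
+# 2. Poll until the background generation finishes
+while job.status not in ("completed", "failed"):
+    time.sleep(1)
+    job = client.videos.retrieve(job.id)
+
+# 3. Download the finished video
+if job.status == "completed":
+    client.videos.download_content(job.id, variant="video").write_to_file("output_async.mp4")
+```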
+
+---
+
+### 4. Video Deletion (`delete_video.py`)
+
+Demonstrates the complete lifecycle of video generation and deletion.
+
+**Features:**
+- Creates a test video generation job
+- Waits for completion
+- Deletes the generated video
+- Verifies deletion by attempting to retrieve the deleted video
+- Tests error handling for non-existent videos
+
+**Usage:**
+```bash
+# Use default localhost server
+python delete_video.py
+
+# Specify custom server URL
+python delete_video.py http://your-server:8000/v1
+```
+
+**API Endpoints:**
+- `POST /v1/videos` - Create video job
+- `GET /v1/videos/{video_id}` - Check video status
+- `DELETE /v1/videos/{video_id}` - Delete video
+
+**Test Flow** (condensed in the sketch below):
+1. Create a video generation job
+2. Wait for completion
+3. Delete the video
+4. Verify that the deleted video now returns `NotFoundError`
+5. Test deletion of a non-existent video
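+
+Steps 3 and 4 of this flow reduce to two SDK calls. The sketch below mirrors `delete_video.py`; `delete_and_verify` is a hypothetical helper, and `video_id` is assumed to come from an earlier `client.videos.create(...)` call:
+
+```python
+import openai
+
+client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")
+
+
+def delete_and_verify(video_id: str) -> None:
+    """Delete a generated video and confirm it is no longer retrievable."""
+    result = client.videos.delete(video_id)
+    print(f"deleted: {result.deleted}")
+    try:
+        client.videos.retrieve(video_id)
+        print("unexpected: video is still retrievable after deletion")
+    except openai.NotFoundError:
+        print("video correctly returns NotFoundError")
+```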
+
+---
+
+## API Configuration
+
+All examples use the following default configuration:
+
+- **Base URL**: `http://localhost:8000/v1`
+- **API Key**: `"tensorrt_llm"` (authentication token)
+- **Timeout**: 300 seconds for async operations
+
+You can customize these by:
+1. Passing the base URL as a command-line argument
+2. Modifying the default parameters in each script's function
+
+## Common Parameters
+
+### Image Generation
+- `model`: Model identifier (e.g., "wan")
+- `prompt`: Text description
+- `n`: Number of images to generate
+- `size`: Image dimensions (e.g., "512x512", "1024x1024")
+- `quality`: "standard" or "hd"
+- `response_format`: "b64_json" or "url"
+
+### Video Generation
+- `model`: Model identifier (e.g., "wan")
+- `prompt`: Text description
+- `size`: Video resolution (e.g., "256x256", "512x512")
+- `seconds`: Duration in seconds
+- `fps`: Frames per second
+- `input_reference`: Reference image file (for TI2V mode)
+
+## Quick Reference - curl Examples
+
+### Text-to-Video (JSON)
+```bash
+curl -X POST "http://localhost:8000/v1/videos" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "A cool cat on a motorcycle",
+ "seconds": 4.0,
+ "fps": 24,
+ "size": "256x256"
+ }'
+```
+
+### Text+Image-to-Video (Multipart with File Upload)
+```bash
+curl -X POST "http://localhost:8000/v1/videos" \
+ -F "prompt=She turns around and smiles" \
+ -F "input_reference=@./media/woman_skyline_original_720p.jpeg" \
+ -F "seconds=4.0" \
+ -F "fps=24" \
+ -F "size=256x256" \
+ -F "guidance_scale=5.0"
+```
+
+### Check Video Status
+```bash
+curl -X GET "http://localhost:8000/v1/videos/{video_id}"
+```
+
+### Download Video
+```bash
+curl -X GET "http://localhost:8000/v1/videos/{video_id}/content" -o output.mp4
+```
+
+### Delete Video
+```bash
+curl -X DELETE "http://localhost:8000/v1/videos/{video_id}"
+```
+
+## API Endpoints Summary
+
+| Endpoint | Method | Mode | Content-Type | Purpose |
+|----------|--------|------|--------------|---------|
+| `/v1/videos` | POST | Async | JSON or Multipart | Create video job (T2V/TI2V) |
+| `/v1/videos/generations` | POST | Sync | JSON or Multipart | Generate video sync (T2V/TI2V) |
+| `/v1/videos/{id}` | GET | - | - | Get video status/metadata |
+| `/v1/videos/{id}/content` | GET | - | - | Download video file |
+| `/v1/videos/{id}` | DELETE | - | - | Delete video |
+| `/v1/videos` | GET | - | - | List all videos |
+| `/v1/images/generations` | POST | - | JSON | Generate images (T2I) |
+
+**Note:** Both `/v1/videos` (async) and `/v1/videos/generations` (sync) support:
+- **JSON**: Standard text-to-video (T2V)
+- **Multipart/Form-Data**: Text+image-to-video (TI2V) with file upload
+
+## Error Handling
+
+All examples include comprehensive error handling:
+
+- Connection errors (server not running)
+- API errors (invalid parameters, model not found)
+- Timeout errors (generation taking too long)
+- Resource errors (video not found for deletion)
+
+Errors are displayed with full stack traces for debugging.
+
+## Output Files
+
+Generated files are saved to the current working directory:
+
+- `output_generation.png` - Synchronous image generation (`sync_image_gen.py`)
+- `output_sync.mp4` - Synchronous video generation (`sync_video_gen.py`)
+- `output_async.mp4` - Asynchronous video generation (`async_video_gen.py`)
+- `output_multipart.mp4` - Multipart example output (`multipart_example.py`)
+
+**Note:** You can customize output filenames using the `--output` parameter in all scripts.
diff --git a/examples/visual_gen/serve/async_video_gen.py b/examples/visual_gen/serve/async_video_gen.py
new file mode 100755
index 0000000000..dec93bf3fa
--- /dev/null
+++ b/examples/visual_gen/serve/async_video_gen.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python
+"""Test script for asynchronous video generation endpoint.
+
+Tests POST /v1/videos endpoint which returns immediately with a job ID.
+The video is generated in the background and can be retrieved later.
+
+Supports two modes:
+ - Text-to-Video (T2V): Generate video from text prompt only
+ - Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
+
+Examples:
+ # Text-to-Video (T2V)
+ python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
+
+ # Text+Image-to-Video (TI2V)
+ python async_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
+"""
+
+import argparse
+import sys
+import time
+from pathlib import Path
+
+import openai
+
+
+def test_async_video_generation(
+ base_url: str = "http://localhost:8000/v1",
+ model: str = "wan",
+ prompt: str = "A video of a cool cat on a motorcycle in the night",
+ input_reference: str = None,
+ duration: float = 4.0,
+ fps: int = 24,
+ size: str = "256x256",
+ output_file: str = "output_async.mp4",
+):
+ """Test asynchronous video generation with OpenAI SDK.
+
+ Args:
+ base_url: Base URL of the API server
+ model: Model name to use
+ prompt: Text prompt for generation
+ input_reference: Path to reference image (optional, for TI2V mode)
+ duration: Video duration in seconds
+ fps: Frames per second
+ size: Video resolution (WxH format)
+ output_file: Output video file path
+ """
+ mode = "TI2V" if input_reference else "T2V"
+ print("=" * 80)
+ print(f"Testing Async Video Generation API - {mode} Mode")
+ print("=" * 80)
+
+ # Initialize client
+ client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
+
+ print("\n1. Creating video generation job...")
+ print(f" Mode: {mode}")
+ print(f" Prompt: {prompt}")
+ if input_reference:
+ print(f" Input Reference: {input_reference}")
+ print(f" Duration: {duration}s")
+ print(f" FPS: {fps}")
+ print(f" Size: {size}")
+
+ try:
+ # Prepare request parameters
+ create_params = {
+ "model": model,
+ "prompt": prompt,
+ "size": size,
+ "seconds": duration,
+ "extra_body": {
+ "fps": fps,
+ },
+ }
+
+ # Add input reference if provided (TI2V mode)
+ if input_reference:
+ if not Path(input_reference).exists():
+ print(f"\nā Error: Input reference image not found: {input_reference}")
+ return False
+ create_params["input_reference"] = open(input_reference, "rb")
+
+ # Create video generation job
+ job = client.videos.create(**create_params)
+
+ print("Video generation started: \n", job.model_dump_json(indent=2))
+
+ video_id = job.id
+ print("\nā Job created successfully!")
+ print(f" Video ID: {video_id}")
+ print(f" Status: {job.status}")
+
+ # Poll for completion
+ print("\n2. Polling for completion...")
+ max_attempts = 300 # 5 minutes with 1s intervals
+ attempt = 0
+
+ while attempt < max_attempts:
+ attempt += 1
+
+ # Get job status using SDK's get method
+ job = client.videos.retrieve(video_id)
+ status = job.status
+
+ print(f" [{attempt:3d}] Status: {status}", end="\r")
+
+ if status == "completed":
+ print("\n\nā Video generation completed!")
+ print(f" Completion time: {job.completed_at}")
+ break
+ elif status == "failed":
+ print("\n\nā Video generation failed!")
+ print(f" Error: {job.error}")
+ return False
+
+ time.sleep(1)
+ else:
+ print(f"\n\nā Timeout waiting for completion (>{max_attempts}s)")
+ return False
+
+ # Download video
+ print("\n3. Downloading video...")
+ # For binary content, use the underlying HTTP client
+ content = client.videos.download_content(video_id, variant="video")
+ content.write_to_file(output_file)
+ print(f" ā Saved to: {output_file}")
+
+ print("\n" + "=" * 80)
+ print("ā Async video generation test completed successfully!")
+ print("=" * 80)
+ return True
+
+ except Exception as e:
+ print(f"\nā Error: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Test async video generation API with T2V and TI2V modes",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Text-to-Video (T2V)
+ python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
+
+ # Text+Image-to-Video (TI2V)
+ python async_video_gen.py --mode ti2v \\
+ --prompt "She turns around and smiles, then slowly walks out of the frame" \\
+ --image ./media/woman_skyline_original_720p.jpeg
+
+ # Custom parameters
+ python async_video_gen.py --mode t2v \\
+ --prompt "A serene sunset over the ocean" \\
+ --duration 5.0 --fps 30 --size 512x512 \\
+ --output my_video.mp4
+ """,
+ )
+
+ # Mode selection
+ parser.add_argument(
+ "--mode",
+ choices=["t2v", "ti2v"],
+ default="t2v",
+ help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
+ )
+
+ # Required parameters
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="A video of a cool cat on a motorcycle in the night",
+ help="Text prompt for video generation",
+ )
+
+ # TI2V mode parameters
+ parser.add_argument(
+ "--image",
+ "--input-reference",
+ type=str,
+ default=None,
+ help="Path to reference image (required for ti2v mode)",
+ )
+
+ # Optional parameters
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default="http://localhost:8000/v1",
+ help="Base URL of the API server",
+ )
+ parser.add_argument("--model", type=str, default="wan", help="Model name to use")
+ parser.add_argument(
+ "--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
+ )
+ parser.add_argument("--fps", type=int, default=24, help="Frames per second")
+ parser.add_argument(
+ "--size",
+ type=str,
+ default="256x256",
+ help="Video resolution in WxH format (e.g., 1280x720)",
+ )
+ parser.add_argument(
+ "--output", type=str, default="output_async.mp4", help="Output video file path"
+ )
+
+ args = parser.parse_args()
+
+ # Validate ti2v mode requirements
+ if args.mode == "ti2v" and not args.image:
+ parser.error("--image is required when using --mode ti2v")
+
+ # Display configuration
+ print("\n" + "=" * 80)
+ print("OpenAI SDK - Async Video Generation Test")
+ print("=" * 80)
+ print(f"Base URL: {args.base_url}")
+ print(f"Mode: {args.mode.upper()}")
+ print()
+
+ # Test async video generation
+ success = test_async_video_generation(
+ base_url=args.base_url,
+ model=args.model,
+ prompt=args.prompt,
+ input_reference=args.image,
+ duration=args.duration,
+ fps=args.fps,
+ size=args.size,
+ output_file=args.output,
+ )
+
+ sys.exit(0 if success else 1)
diff --git a/examples/visual_gen/serve/configs/wan.yml b/examples/visual_gen/serve/configs/wan.yml
new file mode 100644
index 0000000000..7dc65e6214
--- /dev/null
+++ b/examples/visual_gen/serve/configs/wan.yml
@@ -0,0 +1,8 @@
+linear:
+ type: default
+teacache:
+ enable_teacache: true
+ teacache_thresh: 0.2
+parallel:
+ dit_cfg_size: 1
+ dit_ulysses_size: 1
diff --git a/examples/visual_gen/serve/delete_video.py b/examples/visual_gen/serve/delete_video.py
new file mode 100755
index 0000000000..d44b8f046e
--- /dev/null
+++ b/examples/visual_gen/serve/delete_video.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+"""Test script for DELETE /v1/videos/{video_id} endpoint.
+
+Tests the video deletion functionality by:
+1. Creating a video generation job
+2. Waiting for completion
+3. Deleting the video
+4. Verifying the deletion
+"""
+
+import sys
+import time
+
+import openai
+
+
+def test_delete_video(
+ base_url: str = "http://localhost:8000/v1",
+ model: str = "wan",
+ prompt: str = "A simple test video for deletion",
+ duration: float = 2.0,
+ fps: int = 8,
+ size: str = "256x256",
+):
+ """Test video deletion endpoint using OpenAI SDK."""
+ print("=" * 80)
+ print("Testing DELETE /v1/videos/{video_id} Endpoint")
+ print("=" * 80)
+
+ # Initialize OpenAI client
+ client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
+
+ video_id = None
+
+ try:
+ # Step 1: Create a video generation job
+ print("\n1. Creating video generation job...")
+ print(f" Prompt: {prompt}")
+ print(f" Duration: {duration}s")
+ print(f" FPS: {fps}")
+ print(f" Size: {size}")
+
+ job = client.videos.create(
+ model=model,
+ prompt=prompt,
+ size=size,
+ seconds=duration,
+ extra_body={
+ "fps": fps,
+ },
+ )
+
+ video_id = job.id
+ print(f" ā Video job created with ID: {video_id}")
+ print(f" Status: {job.status}")
+
+ # Step 2: Wait for video completion
+ print("\n2. Waiting for video generation to complete...")
+ max_attempts = 60 # attempts with 1s intervals
+ attempt = 0
+
+ while attempt < max_attempts:
+ attempt += 1
+
+ # Get job status using SDK's retrieve method
+ job = client.videos.retrieve(video_id)
+ status = job.status
+
+ print(f" [{attempt:3d}] Status: {status}", end="\r")
+
+ if status == "completed":
+ print(" ā Video generation completed!")
+ break
+ elif status == "failed":
+ print(" ā Video generation failed!")
+ return False
+
+ time.sleep(1)
+ else:
+ print(" ā Timeout waiting for video completion")
+ # Continue with deletion anyway
+
+ # Step 3: Delete the video
+ print(f"\n3. Deleting video {video_id}...")
+
+ delete_result = client.videos.delete(video_id)
+
+ print(f" Response: {delete_result.model_dump_json(indent=2)}")
+
+ if delete_result.deleted:
+ print(" ā Video deleted successfully!")
+ else:
+ print(" ā Video deletion returned False")
+ return False
+
+ # Step 4: Verify the video is gone
+ print("\n4. Verifying video deletion...")
+
+ try:
+ verify_job = client.videos.retrieve(video_id)
+ print(f" ā Video still exists after deletion: {verify_job.status}")
+ return False
+ except openai.NotFoundError as e:
+ print(" ā Video correctly returns NotFoundError")
+ print(f" Error message: {e.message}")
+ except Exception as e:
+ print(f" ā Unexpected error: {type(e).__name__}: {e}")
+
+ # Step 5: Test deleting non-existent video
+ print("\n5. Testing deletion of non-existent video...")
+
+ fake_id = "nonexistent_video_id"
+
+ try:
+ fake_delete_result = client.videos.delete(fake_id)
+ print(" ā Deletion of non-existent video did not raise error")
+ print(f" Response: {fake_delete_result.model_dump_json(indent=2)}")
+ except openai.NotFoundError as e:
+ print(" ā Correctly raises NotFoundError for non-existent video")
+ print(f" Error message: {e.message}")
+ except Exception as e:
+ print(f" ā Unexpected error: {type(e).__name__}: {e}")
+
+ print("\n" + "=" * 80)
+ print("ā Video deletion test completed successfully!")
+ print("=" * 80)
+ return True
+
+ except Exception as e:
+ print(f"\nā Error: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return False
+
+
+if __name__ == "__main__":
+ # Parse command line arguments
+ base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
+
+ print("\n" + "=" * 80)
+ print("OpenAI SDK - Video Deletion Test")
+ print("=" * 80)
+ print(f"Base URL: {base_url}")
+ print()
+
+ # Test video deletion
+ success = test_delete_video(base_url=base_url)
+
+ # Exit with appropriate code
+ sys.exit(0 if success else 1)
diff --git a/examples/visual_gen/serve/media/woman_skyline_original_720p.jpeg b/examples/visual_gen/serve/media/woman_skyline_original_720p.jpeg
new file mode 100644
index 0000000000..44a7ed5c3a
Binary files /dev/null and b/examples/visual_gen/serve/media/woman_skyline_original_720p.jpeg differ
diff --git a/examples/visual_gen/serve/sync_image_gen.py b/examples/visual_gen/serve/sync_image_gen.py
new file mode 100755
index 0000000000..ca3c33d543
--- /dev/null
+++ b/examples/visual_gen/serve/sync_image_gen.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+"""Test script for image generation endpoints.
+
+Tests:
+- POST /v1/images/generations - Generate images from text
+- POST /v1/images/edits - Edit images with text prompts
+"""
+
+import base64
+import sys
+
+import openai
+
+
+def test_image_generation(
+ base_url: str = "http://localhost:8000/v1",
+ model: str = "flux2",
+ prompt: str = "A lovely cat lying on a sofa",
+ n: int = 1,
+ size: str = "512x512",
+ quality: str = "standard",
+ response_format: str = "b64_json",
+ output_file: str = "output_generation.png",
+):
+ """Test image generation endpoint."""
+ print("=" * 80)
+ print("Testing Image Generation API (POST /v1/images/generations)")
+ print("=" * 80)
+
+ # Initialize client
+ client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
+
+ print("\n1. Generating image...")
+ print(f" Prompt: {prompt}")
+ print(f" Size: {size}")
+ print(f" Quality: {quality}")
+ print(f" Number of images: {n}")
+
+ try:
+ # Use OpenAI SDK's images.generate() method
+ response = client.images.generate(
+ model=model,
+ prompt=prompt,
+ n=n,
+ size=size,
+ quality=quality,
+ response_format=response_format,
+ )
+
+ print("\nā Image generated successfully!")
+ print(f" Number of images: {len(response.data)}")
+
+ # Save images
+ for i, image in enumerate(response.data):
+ if response_format == "b64_json":
+ # Decode base64 image
+ image_data = base64.b64decode(image.b64_json)
+ output = f"{output_file.rsplit('.', 1)[0]}_{i}.png" if n > 1 else output_file
+
+ with open(output, "wb") as f:
+ f.write(image_data)
+
+ print(f" ā Saved image {i + 1} to: {output} ({len(image_data)} bytes)")
+ else:
+ print(f" Image {i + 1} URL: {image.url}")
+
+ print("\n" + "=" * 80)
+ print("ā Image generation test completed successfully!")
+ print("=" * 80)
+ return True
+
+ except Exception as e:
+ print(f"\nā Error: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return False
+
+
+if __name__ == "__main__":
+ # Parse command line arguments
+ base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
+
+ print("\n" + "=" * 80)
+ print("OpenAI SDK - Image Generation Tests")
+ print("=" * 80)
+ print(f"Base URL: {base_url}")
+ print()
+
+ # Test image generation
+ test_image_generation(base_url=base_url)
diff --git a/examples/visual_gen/serve/sync_video_gen.py b/examples/visual_gen/serve/sync_video_gen.py
new file mode 100755
index 0000000000..2f349e47b6
--- /dev/null
+++ b/examples/visual_gen/serve/sync_video_gen.py
@@ -0,0 +1,224 @@
+#!/usr/bin/env python
+"""Test script for synchronous video generation endpoint.
+
+Tests POST /v1/videos/generations endpoint which waits for completion and returns video data.
+The video is generated synchronously and the response contains the video file.
+
+Supports two modes:
+ - Text-to-Video (T2V): Generate video from text prompt only
+ - Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
+
+Examples:
+ # Text-to-Video (T2V)
+ python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
+
+ # Text+Image-to-Video (TI2V)
+ python sync_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import requests
+
+
+def test_sync_video_generation(
+ base_url: str = "http://localhost:8000/v1",
+ model: str = "wan",
+ prompt: str = "A video of a cute cat playing with a ball in the park",
+ input_reference: str = None,
+ duration: float = 4.0,
+ fps: int = 24,
+ size: str = "256x256",
+ output_file: str = "output_sync.mp4",
+):
+ """Test synchronous video generation with direct HTTP requests.
+
+ Args:
+ base_url: Base URL of the API server
+ model: Model name to use
+ prompt: Text prompt for generation
+ input_reference: Path to reference image (optional, for TI2V mode)
+ duration: Video duration in seconds
+ fps: Frames per second
+ size: Video resolution (WxH format)
+ output_file: Output video file path
+ """
+ mode = "TI2V" if input_reference else "T2V"
+ print("=" * 80)
+ print(f"Testing Sync Video Generation API - {mode} Mode")
+ print("=" * 80)
+
+ print("\n1. Generating video (waiting for completion)...")
+ print(f" Mode: {mode}")
+ print(f" Prompt: {prompt}")
+ if input_reference:
+ print(f" Input Reference: {input_reference}")
+ print(f" Duration: {duration}s")
+ print(f" FPS: {fps}")
+ print(f" Size: {size}")
+
+ try:
+ endpoint = f"{base_url}/videos/generations"
+
+ if input_reference:
+ # TI2V mode - Use multipart/form-data with file upload
+ if not Path(input_reference).exists():
+ print(f"\nā Error: Input reference image not found: {input_reference}")
+ return False
+
+ # Prepare form data (all values as strings for multipart)
+ form_data = {
+ "model": model,
+ "prompt": prompt,
+ "size": size,
+ "seconds": str(duration),
+ "fps": str(fps),
+ }
+
+ # Add the file
+ ## Note: The content-type must be multipart/form-data.
+ files = {
+ "input_reference": (
+ Path(input_reference).name,
+ open(input_reference, "rb"),
+ "multipart/form-data",
+ )
+ }
+
+ print("\n Uploading reference image and generating video...")
+ response_video = requests.post(endpoint, data=form_data, files=files)
+ else:
+ # T2V mode - Use JSON
+ response_video = requests.post(
+ endpoint,
+ json={
+ "model": model,
+ "prompt": prompt,
+ "size": size,
+ "seconds": duration,
+ "fps": fps,
+ },
+ )
+
+ print(f"\nStatus code: {response_video.status_code}")
+
+ if response_video.status_code == 200:
+ with open(output_file, "wb") as f:
+ f.write(response_video.content)
+ print(f"ā Video saved to: {output_file}")
+
+ print("\n" + "=" * 80)
+ print("ā Sync video generation test completed successfully!")
+ print("=" * 80)
+ return True
+ else:
+ print(f"\nā Error: Server returned status {response_video.status_code}")
+ print(f"Response: {response_video.text}")
+ return False
+
+ except Exception as e:
+ print(f"\nā Error: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Test synchronous video generation API with T2V and TI2V modes",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Text-to-Video (T2V)
+ python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
+
+ # Text+Image-to-Video (TI2V)
+ python sync_video_gen.py --mode ti2v \\
+ --prompt "She turns around and smiles, then slowly walks out of the frame" \\
+ --image ./media/woman_skyline_original_720p.jpeg
+
+ # Custom parameters
+ python sync_video_gen.py --mode t2v \\
+ --prompt "A serene sunset over the ocean" \\
+ --duration 5.0 --fps 30 --size 512x512 \\
+ --output my_video.mp4
+ """,
+ )
+
+ # Mode selection
+ parser.add_argument(
+ "--mode",
+ choices=["t2v", "ti2v"],
+ default="t2v",
+ help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
+ )
+
+ # Required parameters
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="A video of a cute cat playing with a ball in the park",
+ help="Text prompt for video generation",
+ )
+
+ # TI2V mode parameters
+ parser.add_argument(
+ "--image",
+ "--input-reference",
+ type=str,
+ default=None,
+ help="Path to reference image (required for ti2v mode)",
+ )
+
+ # Optional parameters
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default="http://localhost:8000/v1",
+ help="Base URL of the API server",
+ )
+ parser.add_argument("--model", type=str, default="wan", help="Model name to use")
+ parser.add_argument(
+ "--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
+ )
+ parser.add_argument("--fps", type=int, default=24, help="Frames per second")
+ parser.add_argument(
+ "--size",
+ type=str,
+ default="256x256",
+ help="Video resolution in WxH format (e.g., 1280x720)",
+ )
+ parser.add_argument(
+ "--output", type=str, default="output_sync.mp4", help="Output video file path"
+ )
+
+ args = parser.parse_args()
+
+ # Validate ti2v mode requirements
+ if args.mode == "ti2v" and not args.image:
+ parser.error("--image is required when using --mode ti2v")
+
+ # Display configuration
+ print("\n" + "=" * 80)
+ print("Synchronous Video Generation Test")
+ print("=" * 80)
+ print(f"Base URL: {args.base_url}")
+ print(f"Mode: {args.mode.upper()}")
+ print()
+
+ # Test sync video generation
+ success = test_sync_video_generation(
+ base_url=args.base_url,
+ model=args.model,
+ prompt=args.prompt,
+ input_reference=args.image,
+ duration=args.duration,
+ fps=args.fps,
+ size=args.size,
+ output_file=args.output,
+ )
+
+ sys.exit(0 if success else 1)
diff --git a/examples/visual_gen/visual_gen_examples.sh b/examples/visual_gen/visual_gen_examples.sh
new file mode 100755
index 0000000000..b769760203
--- /dev/null
+++ b/examples/visual_gen/visual_gen_examples.sh
@@ -0,0 +1,238 @@
+#!/bin/bash
+# Visual Generation Examples - Test different models and configurations
+#
+# This script runs a comprehensive suite of visual generation examples including:
+# - WAN T2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
+# - WAN I2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
+#
+# The script automatically detects GPU count and runs appropriate examples:
+# - 1 GPU: Single-GPU examples only
+# - 2 GPUs: + CFG parallelism, Ulysses parallelism
+# - 4 GPUs: + CFG + Ulysses combined
+# - 8 GPUs: + Large-scale high-resolution examples
+#
+# Usage:
+# export MODEL_ROOT=/path/to/models # required
+# # Optional: PROJECT_ROOT auto-detected when run from examples/visual_gen
+# cd examples/visual_gen && ./visual_gen_examples.sh
+#
+# Or inline:
+# MODEL_ROOT=/llm-models ./visual_gen_examples.sh
+
+set -e # Exit on error
+
+# Environment variables with defaults
+# PROJECT_ROOT: auto-detect repo root when run from examples/visual_gen
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+PROJECT_ROOT=${PROJECT_ROOT:-"$(cd "${SCRIPT_DIR}/../.." && pwd)"}
+MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
+
+# Log configuration
+export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
+
+echo "============================================"
+echo "Visual Generation Examples"
+echo "============================================"
+echo "PROJECT_ROOT: $PROJECT_ROOT"
+echo "MODEL_ROOT: $MODEL_ROOT"
+echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
+echo "============================================"
+echo ""
+
+
+# Detect GPU count
+if command -v nvidia-smi &> /dev/null; then
+ GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+ echo "Detected $GPU_COUNT GPU(s)"
+ if [ "$GPU_COUNT" -lt 2 ]; then
+ echo "Note: Multi-GPU examples will be skipped"
+ SKIP_MULTI_GPU=1
+ elif [ "$GPU_COUNT" -ge 8 ]; then
+ echo "Note: Will run all examples including 8-GPU configurations"
+ elif [ "$GPU_COUNT" -ge 4 ]; then
+ echo "Note: Will run examples up to 4-GPU configurations"
+ else
+ echo "Note: Will run 2-GPU examples only"
+ fi
+else
+ echo "WARNING: nvidia-smi not found. Assuming single GPU."
+ GPU_COUNT=1
+ SKIP_MULTI_GPU=1
+fi
+echo ""
+
+#############################################
+# WAN (Wan2.1) Text-to-Video Examples
+#############################################
+# Demonstrates:
+# - Single GPU: Baseline and TeaCache
+# - 2 GPUs: CFG only, Ulysses only
+# - 4 GPUs: CFG + Ulysses combined
+# - 8 GPUs: Large-scale parallelism
+#############################################
+
+echo "=== WAN Example 1: Baseline (no optimization) ==="
+python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_cat_piano.png
+
+echo ""
+echo "=== WAN Example 2: With TeaCache ==="
+python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_cat_piano_teacache.png \
+ --enable_teacache
+
+if [ -z "$SKIP_MULTI_GPU" ]; then
+ echo ""
+ echo "=== WAN Example 3: CFG Only (2 GPUs) ==="
+ python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_cfg_2gpu.mp4 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 \
+ --ulysses_size 1
+else
+ echo ""
+ echo "=== WAN Example 3: Skipped (requires 2 GPUs) ==="
+fi
+
+if [ -z "$SKIP_MULTI_GPU" ]; then
+ echo ""
+ echo "=== WAN Example 4: Ulysses Only (2 GPUs) ==="
+ python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_ulysses_2gpu.mp4 \
+ --attention_backend TRTLLM \
+ --cfg_size 1 \
+ --ulysses_size 2
+else
+ echo ""
+ echo "=== WAN Example 4: Skipped (requires 2 GPUs) ==="
+fi
+
+if [ "$GPU_COUNT" -ge 4 ]; then
+ echo ""
+ echo "=== WAN Example 5: CFG + Ulysses (4 GPUs) ==="
+ python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_cfg_ulysses_4gpu.mp4 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 \
+ --ulysses_size 2
+else
+ echo ""
+ echo "=== WAN Example 5: Skipped (requires 4 GPUs) ==="
+fi
+
+if [ "$GPU_COUNT" -ge 8 ]; then
+ echo ""
+ echo "=== WAN Example 6: Large-Scale (8 GPUs) ==="
+ python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
+ --prompt "A cute cat playing piano" \
+ --output_path wan_cfg_ulysses_8gpu.mp4 \
+ --attention_backend TRTLLM \
+ --cfg_size 2 \
+ --ulysses_size 4
+else
+ echo ""
+ echo "=== WAN Example 6: Skipped (requires 8 GPUs) ==="
+fi
+
+#############################################
+# WAN 2.2 (Two-Stage) Text-to-Video Examples
+#############################################
+
+echo ""
+echo "=== WAN 2.2 T2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
+python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
+ --height 720 \
+ --width 1280 \
+ --num_frames 81 \
+ --model_path ${MODEL_ROOT}/Wan2.2-T2V-A14B-Diffusers \
+ --prompt "A cute cat playing piano" \
+ --output_path wan22_t2v_cat_piano_optimized.gif \
+ --linear_type trtllm-fp8-blockwise \
+ --attention_backend TRTLLM \
+ --enable_teacache \
+ --teacache_thresh 0.2 \
+ --guidance_scale 3.0 \
+ --guidance_scale_2 2.5 \
+ --boundary_ratio 0.85
+
+#############################################
+# WAN 2.1 Image-to-Video Examples
+#############################################
+
+echo ""
+echo "=== WAN 2.1 I2V Example: Single-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
+python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 33 \
+ --model_path ${MODEL_ROOT}/Wan2.1-I2V-14B-480P-Diffusers \
+ --image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
+ --prompt "It snows as the cat plays piano, lots of snow \
+ appearing all over the screen, snowflakes, blizzard,
+ gradually more snow" \
+ --negative_prompt "blurry, low quality" \
+ --output_path wan21_i2v_cat_piano_optimized.gif \
+ --linear_type trtllm-fp8-per-tensor \
+ --attention_backend TRTLLM \
+ --enable_teacache \
+ --teacache_thresh 0.2 \
+ --guidance_scale 6.0
+
+#############################################
+# WAN 2.2 (Two-Stage) Image-to-Video Examples
+#############################################
+
+echo ""
+echo "=== WAN 2.2 I2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
+python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
+ --height 480 \
+ --width 832 \
+ --num_frames 81 \
+ --model_path ${MODEL_ROOT}/Wan2.2-I2V-A14B-Diffusers \
+ --image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
+ --prompt "It snows as the cat plays piano, lots of snow \
+ appearing all over the screen, snowflakes, blizzard,
+ gradually more snow" \
+ --negative_prompt "blurry, low quality" \
+ --output_path wan22_i2v_cat_piano_optimized.gif \
+ --linear_type trtllm-fp8-blockwise \
+ --attention_backend TRTLLM \
+ --enable_teacache \
+ --teacache_thresh 0.2 \
+ --guidance_scale 6.0 \
+ --guidance_scale_2 5.0 \
+ --boundary_ratio 0.85
+
+echo ""
+echo "============================================"
+echo "All examples completed successfully!"
+echo "============================================"
diff --git a/examples/visual_gen/visual_gen_wan_i2v.py b/examples/visual_gen/visual_gen_wan_i2v.py
new file mode 100644
index 0000000000..3b76470bb6
--- /dev/null
+++ b/examples/visual_gen/visual_gen_wan_i2v.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python3
+"""WAN Image-to-Video generation using TensorRT-LLM Visual Generation."""
+
+import argparse
+import time
+
+from output_handler import OutputHandler
+
+from tensorrt_llm import logger
+from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
+
+# Set logger level to ensure timing logs are printed
+logger.set_level("info")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="TRTLLM VisualGen - Wan Image-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
+ )
+
+ # Model & Input
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ required=True,
+ help="Path to Wan I2V Diffusers model directory (2.1 or 2.2)",
+ )
+ parser.add_argument(
+ "--image_path",
+ type=str,
+ required=True,
+ help="Path to input image for I2V conditioning",
+ )
+ parser.add_argument(
+ "--last_image_path",
+ type=str,
+ default=None,
+ help="Optional path to last frame image for interpolation (Wan 2.1 only)",
+ )
+ parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
+ parser.add_argument(
+ "--negative_prompt",
+ type=str,
+ default=None,
+ help="Negative prompt. Default is model-specific.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="output.png",
+ help="Path to save the output image/video frame",
+ )
+
+ # Generation Params
+ parser.add_argument("--height", type=int, default=720, help="Video height")
+ parser.add_argument("--width", type=int, default=1280, help="Video width")
+ parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
+ parser.add_argument(
+ "--steps",
+ type=int,
+ default=None,
+ help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=None,
+ help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
+ )
+ parser.add_argument(
+ "--guidance_scale_2",
+ type=float,
+ default=None,
+ help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
+ )
+ parser.add_argument(
+ "--boundary_ratio",
+ type=float,
+ default=None,
+ help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+ # TeaCache Arguments
+ parser.add_argument(
+ "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
+ )
+ parser.add_argument(
+ "--teacache_thresh",
+ type=float,
+ default=0.2,
+ help="TeaCache similarity threshold (rel_l1_thresh)",
+ )
+
+ # Quantization
+ parser.add_argument(
+ "--linear_type",
+ type=str,
+ default="default",
+ choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
+ help="Linear layer quantization type",
+ )
+
+ # Attention Backend
+ parser.add_argument(
+ "--attention_backend",
+ type=str,
+ default="VANILLA",
+ choices=["VANILLA", "TRTLLM"],
+ help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
+ "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
+ )
+
+ # Parallelism
+ parser.add_argument(
+ "--cfg_size",
+ type=int,
+ default=1,
+ choices=[1, 2],
+ help="CFG parallel size (1 or 2). Set to 2 for CFG Parallelism.",
+ )
+ parser.add_argument(
+ "--ulysses_size",
+ type=int,
+ default=1,
+ help="Ulysses (sequence) parallel size within each CFG group.",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ # world_size = cfg_size * ulysses_size
+ # Example: cfg_size=2, ulysses_size=4 -> 8 GPUs
+ # GPU 0-3: CFG group 0 (positive prompt), internal Ulysses parallel
+ # GPU 4-7: CFG group 1 (negative prompt), internal Ulysses parallel
+ n_workers = args.cfg_size * args.ulysses_size
+
+ # Convert linear_type to quant_config
+ quant_config = None
+ if args.linear_type == "trtllm-fp8-per-tensor":
+ quant_config = {"quant_algo": "FP8", "dynamic": True}
+ elif args.linear_type == "trtllm-fp8-blockwise":
+ quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
+ elif args.linear_type == "svd-nvfp4":
+ quant_config = {"quant_algo": "NVFP4", "dynamic": True}
+
+ # 1. Setup Configuration
+ diffusion_config = {
+ "model_type": "wan2",
+ "attention": {
+ "backend": args.attention_backend,
+ },
+ "teacache": {
+ "enable_teacache": args.enable_teacache,
+ "teacache_thresh": args.teacache_thresh,
+ },
+ "parallel": {
+ "dit_cfg_size": args.cfg_size,
+ "dit_ulysses_size": args.ulysses_size,
+ },
+ }
+
+ # Add quant_config if specified
+ if quant_config is not None:
+ diffusion_config["quant_config"] = quant_config
+
+ # 2. Initialize VisualGen
+ logger.info(
+ f"Initializing VisualGen: world_size={n_workers} "
+ f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
+ )
+ visual_gen = VisualGen(
+ model_path=args.model_path,
+ n_workers=n_workers,
+ diffusion_config=diffusion_config,
+ )
+
+ try:
+        # 3. Run Inference
+ logger.info(f"Generating video for prompt: '{args.prompt}'")
+ logger.info(f"Negative prompt: '{args.negative_prompt}'")
+ logger.info(f"Input image: {args.image_path}")
+ if args.last_image_path:
+ logger.info(f"Last frame image: {args.last_image_path}")
+ logger.info(
+ f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
+ )
+
+ start_time = time.time()
+
+ # Build parameters with explicit I2V and Wan 2.2 support
+ output = visual_gen.generate(
+ inputs={
+ "prompt": args.prompt,
+ "negative_prompt": args.negative_prompt,
+ },
+ params=VisualGenParams(
+ height=args.height,
+ width=args.width,
+ num_inference_steps=args.steps,
+ guidance_scale=args.guidance_scale,
+ seed=args.seed,
+ num_frames=args.num_frames,
+ input_reference=args.image_path,
+ last_image=args.last_image_path if args.last_image_path else None,
+ guidance_scale_2=args.guidance_scale_2,
+ boundary_ratio=args.boundary_ratio,
+ ),
+ )
+
+ end_time = time.time()
+ logger.info(f"Generation completed in {end_time - start_time:.2f}s")
+
+        # 4. Save Output
+ OutputHandler.save(output, args.output_path, frame_rate=16.0)
+
+ finally:
+        # 5. Shutdown
+ visual_gen.shutdown()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/visual_gen/visual_gen_wan_t2v.py b/examples/visual_gen/visual_gen_wan_t2v.py
new file mode 100755
index 0000000000..30c55e4a17
--- /dev/null
+++ b/examples/visual_gen/visual_gen_wan_t2v.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""
+
+import argparse
+import time
+
+from output_handler import OutputHandler
+
+from tensorrt_llm import logger
+from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
+
+# Set logger level to ensure timing logs are printed
+logger.set_level("info")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="TRTLLM VisualGen - Wan Text-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
+ )
+
+ # Model & Input
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ required=True,
+ help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ help="HuggingFace Hub revision (branch, tag, or commit SHA)",
+ )
+ parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
+ parser.add_argument(
+ "--negative_prompt",
+ type=str,
+ default=None,
+ help="Negative prompt. Default is model-specific.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="output.png",
+ help="Path to save the output image/video frame",
+ )
+
+ # Generation Params
+ parser.add_argument("--height", type=int, default=720, help="Video height")
+ parser.add_argument("--width", type=int, default=1280, help="Video width")
+ parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
+ parser.add_argument(
+ "--steps",
+ type=int,
+ default=None,
+ help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=None,
+ help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
+ )
+ parser.add_argument(
+ "--guidance_scale_2",
+ type=float,
+ default=None,
+ help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
+ )
+ parser.add_argument(
+ "--boundary_ratio",
+ type=float,
+ default=None,
+ help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+ # TeaCache Arguments
+ parser.add_argument(
+ "--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
+ )
+ parser.add_argument(
+ "--teacache_thresh",
+ type=float,
+ default=0.2,
+ help="TeaCache similarity threshold (rel_l1_thresh)",
+ )
+
+ # Quantization
+ parser.add_argument(
+ "--linear_type",
+ type=str,
+ default="default",
+ choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
+ help="Linear layer quantization type",
+ )
+
+ # Attention Backend
+ parser.add_argument(
+ "--attention_backend",
+ type=str,
+ default="VANILLA",
+ choices=["VANILLA", "TRTLLM"],
+ help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
+ "Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
+ )
+
+ # Parallelism
+ parser.add_argument(
+ "--cfg_size",
+ type=int,
+ default=1,
+ choices=[1, 2],
+ help="CFG parallel size (1 or 2). "
+ "Distributes positive/negative prompts across GPUs. "
+ "Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
+ )
+ parser.add_argument(
+ "--ulysses_size",
+ type=int,
+ default=1,
+ help="Ulysses sequence parallel size within each CFG group. "
+ "Distributes sequence across GPUs for longer sequences. "
+ "Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
+ "Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
+ "2 CFG groups Ć 2 Ulysses ranks = 4 GPUs total.",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+    # Total workers: cfg_size × ulysses_size
+ # See ParallelConfig in config.py for detailed parallelism strategy and examples
+ n_workers = args.cfg_size * args.ulysses_size
+
+ # Log Ulysses configuration (validation happens in setup_sequence_parallelism)
+ if args.ulysses_size > 1:
+ num_heads = 12 # WAN has 12 attention heads
+ logger.info(
+ f"Using Ulysses sequence parallelism: "
+ f"{num_heads} heads / {args.ulysses_size} ranks = "
+ f"{num_heads // args.ulysses_size} heads per GPU"
+ )
+
+ # Convert linear_type to quant_config
+ quant_config = None
+ if args.linear_type == "trtllm-fp8-per-tensor":
+ quant_config = {"quant_algo": "FP8", "dynamic": True}
+ elif args.linear_type == "trtllm-fp8-blockwise":
+ quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
+ elif args.linear_type == "svd-nvfp4":
+ quant_config = {"quant_algo": "NVFP4", "dynamic": True}
+
+ # 1. Setup Configuration
+ diffusion_config = {
+ "model_type": "wan2",
+ "revision": args.revision,
+ "attention": {
+ "backend": args.attention_backend,
+ },
+ "teacache": {
+ "enable_teacache": args.enable_teacache,
+ "teacache_thresh": args.teacache_thresh,
+ },
+ "parallel": {
+ "dit_cfg_size": args.cfg_size,
+ "dit_ulysses_size": args.ulysses_size,
+ },
+ }
+
+ # Add quant_config if specified
+ if quant_config is not None:
+ diffusion_config["quant_config"] = quant_config
+
+ # 2. Initialize VisualGen
+ logger.info(
+ f"Initializing VisualGen: world_size={n_workers} "
+ f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
+ )
+ visual_gen = VisualGen(
+ model_path=args.model_path,
+ n_workers=n_workers,
+ diffusion_config=diffusion_config,
+ )
+
+ try:
+        # 3. Run Inference
+ logger.info(f"Generating video for prompt: '{args.prompt}'")
+ logger.info(f"Negative prompt: '{args.negative_prompt}'")
+ logger.info(
+ f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
+ )
+
+ start_time = time.time()
+
+ output = visual_gen.generate(
+ inputs={
+ "prompt": args.prompt,
+ "negative_prompt": args.negative_prompt,
+ },
+ params=VisualGenParams(
+ height=args.height,
+ width=args.width,
+ num_inference_steps=args.steps,
+ guidance_scale=args.guidance_scale,
+ seed=args.seed,
+ num_frames=args.num_frames,
+ guidance_scale_2=args.guidance_scale_2,
+ boundary_ratio=args.boundary_ratio,
+ ),
+ )
+
+ end_time = time.time()
+ logger.info(f"Generation completed in {end_time - start_time:.2f}s")
+
+        # 4. Save Output
+ OutputHandler.save(output, args.output_path, frame_rate=16.0)
+
+ finally:
+        # 5. Shutdown
+ visual_gen.shutdown()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
index 0bada4a2d4..d2aa81843d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -83,3 +83,4 @@ llist
cuda-tile>=1.0.1
nvidia-cuda-tileiras>=13.1
etcd-sdk-python==0.0.7
+python-multipart
diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py
index b8bfe4ffdf..2dafa88bf1 100644
--- a/tensorrt_llm/_torch/distributed/__init__.py
+++ b/tensorrt_llm/_torch/distributed/__init__.py
@@ -4,10 +4,11 @@ from .communicator import Distributed, MPIDist, TorchDist
from .moe_alltoall import MoeAlltoAll
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams,
- allgather, alltoall_helix, cp_allgather, reducescatter,
- userbuffers_allreduce_finalize)
+ all_to_all_4d, allgather, alltoall_helix, cp_allgather,
+ reducescatter, userbuffers_allreduce_finalize)
__all__ = [
+ "all_to_all_4d",
"allgather",
"alltoall_helix",
"cp_allgather",
diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index 84468dc612..525a825a3f 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -959,3 +959,126 @@ class MoEAllReduce(nn.Module):
nranks=self.mapping.tp_size,
eps=all_reduce_params.eps,
)
+
+
+def all_to_all_4d(
+ input: torch.Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ process_group: Optional[torch.distributed.ProcessGroup] = None,
+) -> torch.Tensor:
+ """
+ All-to-all for 4D tensors (batch, seq, heads, head_dim).
+
+ Redistributes a 4D tensor along two dimensions using all-to-all communication.
+ This is used for Ulysses-style sequence parallelism to transform between:
+    - Sequence sharding [B, S/P, H, D] -> Head sharding [B, S, H/P, D]
+    - Head sharding [B, S, H/P, D] -> Sequence sharding [B, S/P, H, D]
+
+ Args:
+ input: Input tensor with shape [batch, seq, heads, head_dim]
+ scatter_dim: Dimension to split and scatter (1 for seq, 2 for heads)
+ gather_dim: Dimension to gather (1 for seq, 2 for heads)
+ process_group: PyTorch distributed process group. If None, uses default process group.
+
+ Returns:
+        Redistributed tensor with the same number of elements; the scatter_dim
+        is split across ranks and the gather_dim is gathered (e.g. [B, S/P, H, D]
+        becomes [B, S, H/P, D] when scatter_dim=2, gather_dim=1)
+
+ Example:
+ # Transform from sequence sharding to head sharding
+ # Input: [B, S/P, H, D] (each rank has S/P sequence)
+ output = all_to_all_4d(input, scatter_dim=2, gather_dim=1, process_group=pg)
+ # Output: [B, S, H/P, D] (each rank has H/P heads)
+
+ # Transform back from head sharding to sequence sharding
+ output = all_to_all_4d(input, scatter_dim=1, gather_dim=2, process_group=pg)
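+
+        # Concrete shapes (illustrative; assumes world_size P=2, 12 heads, head_dim 128):
+        # x per rank: [1, 8, 12, 128]  (S/P=8 local tokens, all 12 heads)
+        y = all_to_all_4d(x, scatter_dim=2, gather_dim=1, process_group=pg)
+        # y: [1, 16, 6, 128]  (full sequence, 6 heads per rank)
+        z = all_to_all_4d(y, scatter_dim=1, gather_dim=2, process_group=pg)
+        # z: [1, 8, 12, 128]  (sequence sharding restored)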
+ """
+ # Only support PyTorch distributed mode (not MPI mode)
+ if not mpi_disabled():
+ raise NotImplementedError(
+ "all_to_all_4d currently only supports PyTorch distributed mode. "
+ "MPI mode is not supported.")
+
+ # Get world size from process group
+ world_size = torch.distributed.get_world_size(group=process_group)
+
+ # If world_size is 1, no communication needed
+ if world_size == 1:
+ return input
+
+ # Validate dimensions
+ assert scatter_dim in [1, 2], "scatter_dim must be 1 (seq) or 2 (heads)"
+ assert gather_dim in [1, 2], "gather_dim must be 1 (seq) or 2 (heads)"
+ assert scatter_dim != gather_dim, "scatter_dim and gather_dim must be different"
+
+ batch, seq, heads, head_dim = input.shape
+
+ # Validate that the scatter dimension is divisible by world_size
+ scatter_size = input.shape[scatter_dim]
+ assert scatter_size % world_size == 0, \
+ f"Dimension {scatter_dim} size {scatter_size} must be divisible by world_size {world_size}"
+
+ # For all-to-all, we need to:
+ # 1. Split input along scatter_dim into world_size chunks
+ # 2. Send chunk i to rank i
+ # 3. Receive chunk from each rank and concatenate along gather_dim
+
+ # Reshape for all-to-all: move scatter_dim chunks to a new dimension
+ if scatter_dim == 1: # Scatter along seq dimension
+ # [B, S, H, D] -> [B, P, S/P, H, D] where P = world_size
+ input_reshaped = input.view(batch, world_size, seq // world_size, heads,
+ head_dim)
+ # Transpose to group by destination rank: [B, P, S/P, H, D] -> [P, B, S/P, H, D]
+ input_transposed = input_reshaped.permute(1, 0, 2, 3, 4).contiguous()
+ else: # scatter_dim == 2, scatter along heads dimension
+ # [B, S, H, D] -> [B, S, P, H/P, D] where P = world_size
+ input_reshaped = input.view(batch, seq, world_size, heads // world_size,
+ head_dim)
+ # Transpose to group by destination rank: [B, S, P, H/P, D] -> [P, B, S, H/P, D]
+ input_transposed = input_reshaped.permute(2, 0, 1, 3, 4).contiguous()
+
+ # Flatten to [P * ...] for all-to-all communication
+ # Shape: [P, B, ...] -> [P * B * ...]
+ input_flat = input_transposed.flatten()
+ output_flat = torch.empty_like(input_flat)
+
+ # Perform all-to-all communication using PyTorch distributed
+ # all_to_all_single splits input into world_size chunks and exchanges them
+ torch.distributed.all_to_all_single(output_flat,
+ input_flat,
+ group=process_group)
+
+ # Reshape output back to [P, B, ...] form
+ output_transposed = output_flat.view_as(input_transposed)
+
+ # Transpose back and reshape to final form
+ if gather_dim == 1: # Gather along seq dimension
+ # [P, B, S/P, H, D] -> [B, P, S/P, H, D]
+ output_reshaped = output_transposed.permute(1, 0, 2, 3, 4).contiguous()
+ # [B, P, S/P, H, D] -> [B, S, H, D] where S = P * (S/P)
+ # When scattering heads and gathering seq: seq needs to be multiplied, heads needs to be divided
+ if scatter_dim == 2:
+ # Scattered heads, so we have H/P heads and need to gather S/P -> S sequence
+ gathered_seq = seq * world_size
+ sharded_heads = heads // world_size
+ output = output_reshaped.view(batch, gathered_seq, sharded_heads,
+ head_dim)
+ else:
+ # Scattered seq (should be impossible if gather_dim == 1), keep as is
+ output = output_reshaped.view(batch, seq, heads, head_dim)
+ else: # gather_dim == 2, gather along heads dimension
+ # [P, B, S, H/P, D] -> [B, S, P, H/P, D]
+ output_reshaped = output_transposed.permute(1, 2, 0, 3, 4).contiguous()
+ # [B, S, P, H/P, D] -> [B, S, H, D] where H = P * (H/P)
+ # When scattering seq and gathering heads: heads needs to be multiplied, seq needs to be divided
+ if scatter_dim == 1:
+ # Scattered seq, so we have S/P seq and need to gather H/P -> H heads
+ gathered_heads = heads * world_size
+ sharded_seq = seq // world_size
+ output = output_reshaped.view(batch, sharded_seq, gathered_heads,
+ head_dim)
+ else:
+ # Scattered heads (should be impossible if gather_dim == 2), keep as is
+ output = output_reshaped.view(batch, seq, heads, head_dim)
+
+ return output
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 65811569ca..34581ff224 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -556,6 +556,13 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
+
+ # Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
+ # GEMM ops require 2D matrices
+ original_shape = input.shape
+ if input.dim() > 2:
+ input = input.reshape(-1, input.shape[-1])
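+            # e.g. (illustrative shapes) a [batch, seq, hidden] = [2, 4096, 1536]
+            # activation becomes [8192, 1536] for the 2D GEMM and is reshaped
+            # back to [2, 4096, out_features] after the matmul below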
+
cur_input_scale = module.input_scale
if input.dtype != torch.float8_e4m3fn:
if module.input_scale is not None and not module.force_dynamic_quantization:
@@ -591,6 +598,11 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
bias=None,
out_dtype=module.dtype or input.dtype,
)
+
+ # Reshape output back to original shape (with out_features as last dim)
+ if len(original_shape) > 2:
+ output = output.reshape(*original_shape[:-1], output.shape[-1])
+
if bias is not None:
output = output + bias
return output
@@ -975,6 +987,12 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
+ # Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
+ # GEMM ops require 2D matrices
+ original_shape = input.shape
+ if input.dim() > 2:
+ input = input.reshape(-1, input.shape[-1])
+
if input.dtype == torch.float8_e4m3fn:
input = input.to(torch.bfloat16) * module.input_scale
assert input.dtype == torch.bfloat16
@@ -1003,6 +1021,10 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
+ # Reshape output back to original shape (with out_features as last dim)
+ if len(original_shape) > 2:
+ output = output.reshape(*original_shape[:-1], output.shape[-1])
+
if bias is not None:
output = output + bias
return output
@@ -1212,6 +1234,15 @@ class NVFP4LinearMethod(LinearMethodBase):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
+ # Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden).
+ # GEMM requires 2D. Only plain tensors support for now, skip for
+ # tuple and Fp4QuantizedTensor.
+ original_shape = None
+ if not isinstance(input,
+ (tuple, Fp4QuantizedTensor)) and input.dim() > 2:
+ original_shape = input.shape
+ input = input.reshape(-1, input.shape[-1])
+
act_fp4, act_sf = self._input_prepare(module, input)
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
# Convert list to comma-separated string for torch.compile compatibility
@@ -1229,6 +1260,9 @@ class NVFP4LinearMethod(LinearMethodBase):
if output.shape[-1] > module.out_features:
output = output[..., :module.out_features].contiguous()
+ if original_shape is not None:
+ output = output.reshape(*original_shape[:-1], output.shape[-1])
+
if bias is not None:
output = output + bias
return output
diff --git a/tensorrt_llm/_torch/visual_gen/__init__.py b/tensorrt_llm/_torch/visual_gen/__init__.py
new file mode 100644
index 0000000000..c612522540
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/__init__.py
@@ -0,0 +1,45 @@
+"""Visual generation module for diffusion models."""
+
+from tensorrt_llm._torch.visual_gen.executor import (
+ DiffusionExecutor,
+ DiffusionRequest,
+ DiffusionResponse,
+)
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+
+# Checkpoint loading
+from .checkpoints import WeightLoader
+from .config import (
+ AttentionConfig,
+ DiffusionArgs,
+ DiffusionModelConfig,
+ ParallelConfig,
+ PipelineComponent,
+ PipelineConfig,
+ TeaCacheConfig,
+ discover_pipeline_components,
+)
+from .models import AutoPipeline, BasePipeline, WanPipeline
+from .pipeline_loader import PipelineLoader
+
+__all__ = [
+ # Config classes
+ "DiffusionArgs",
+ "DiffusionModelConfig",
+ "ParallelConfig",
+ "PipelineComponent",
+ "TeaCacheConfig",
+ # Checkpoint loading
+ "WeightLoader",
+ # Model loading
+ "PipelineLoader",
+ # Execution
+ "DiffusionExecutor",
+ "DiffusionRequest",
+ "DiffusionResponse",
+ "MediaOutput",
+ # Pipelines
+ "AutoPipeline",
+ "BasePipeline",
+ "WanPipeline",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
new file mode 100644
index 0000000000..9cc0d3c272
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Visual Generation Attention Backend
+
+This module provides attention backend infrastructure for visual generation (diffusion) models.
+It reuses existing TRT-LLM attention backends (TrtllmAttention, VanillaAttention) with
+simplified metadata that doesn't require KV caching.
+"""
+
+from .interface import AttentionTensorLayout
+from .parallel import UlyssesAttention
+from .trtllm import TrtllmAttention, TrtllmAttentionMetadata
+from .utils import create_attention, get_visual_gen_attention_backend
+from .vanilla import VanillaAttention
+
+__all__ = [
+ "AttentionTensorLayout",
+ "get_visual_gen_attention_backend",
+ "create_attention",
+ "TrtllmAttention",
+ "TrtllmAttentionMetadata",
+ "UlyssesAttention",
+ "VanillaAttention",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/interface.py b/tensorrt_llm/_torch/visual_gen/attention_backend/interface.py
new file mode 100644
index 0000000000..a32c3712b4
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/interface.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Visual Generation Attention Backend Interface
+
+Defines shared types and enums for attention backends.
+"""
+
+from enum import Enum
+
+
+class AttentionTensorLayout(str, Enum):
+ """
+ Tensor layout for attention backend input/output.
+
+ Backends declare their preferred layout so the attention module
+ can reshape tensors optimally before calling the backend.
+ """
+
+ NHD = "NHD" # [B, S, H, D] - batch, seq, heads, dim
+ HND = "HND" # [B, H, S, D] - batch, heads, seq, dim
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
new file mode 100644
index 0000000000..a7e466423f
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Ulysses Sequence Parallelism Wrapper
+
+Wraps any attention backend with sequence parallelism via all-to-all
+communication. Not a standalone backend; compose around a real backend
+(VANILLA/TRTLLM).
+
+Architecture:
+ Input: [B, S/P, H, D] (sequence sharded across P processes)
+    Step 1: All-to-All -> [B, S, H/P, D] (gather sequence, shard heads)
+    Step 2: Compute attention with wrapped backend (VANILLA or TRTLLM)
+    Step 3: All-to-All -> [B, S/P, H, D] (restore sequence sharding)
+ Output: [B, S/P, H, D] (sequence sharded)
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tensorrt_llm._torch.distributed import all_to_all_4d
+
+from .interface import AttentionTensorLayout
+
+
+class UlyssesAttention(nn.Module):
+ """
+ Ulysses Sequence Parallelism wrapper.
+
+ Wraps any attention backend with sequence parallelism via all-to-all.
+    Not a standalone backend; compose around a real backend (VANILLA/TRTLLM).
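+
+    Example (illustrative; assumes a 2-rank process group ``pg`` and a
+    WAN-like 12-head / 128-dim attention):
+        inner = VanillaAttention(num_heads=6, head_dim=128)  # heads already sharded: 12 / 2
+        attn = UlyssesAttention(inner, process_group=pg)
+        # q, k, v per rank: [B, S/2, 12, 128] (sequence sharded, full heads)
+        out = attn(q, k, v, batch_size=q.shape[0])  # -> [B, S/2, 12, 128]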
+ """
+
+ def __init__(
+ self,
+ inner_backend: nn.Module,
+ process_group: Optional[torch.distributed.ProcessGroup] = None,
+ ):
+ super().__init__()
+ self.inner_backend = inner_backend
+ self.process_group = process_group
+ self._preferred_layout = AttentionTensorLayout.NHD
+
+ # Derive head info from inner backend
+ self.head_dim = inner_backend.head_dim
+ self.sharded_num_heads = inner_backend.num_heads
+ self.sharded_num_kv_heads = getattr(inner_backend, "num_kv_heads", self.sharded_num_heads)
+
+ # Get world size from process group
+ try:
+ self.world_size = torch.distributed.get_world_size(group=process_group)
+ except (RuntimeError, ValueError):
+ self.world_size = 1
+
+ # Full (unsharded) head counts for external interface
+ self.num_heads = self.sharded_num_heads * self.world_size
+ self.num_kv_heads = self.sharded_num_kv_heads * self.world_size
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ batch_size: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Forward pass with Ulysses sequence parallelism.
+
+ Input/Output: [B, S/P, H, D] (sequence sharded)
+
+ Args:
+ q: Query tensor [B, S/P, H, D]
+ k: Key tensor [B, S/P, H, D]
+ v: Value tensor [B, S/P, H, D]
+ batch_size: Batch size
+ attention_mask: Optional attention mask
+
+ Returns:
+ Output tensor [B, S/P, H, D] (sequence sharded)
+
+ Note:
+ seq_len is computed from tensor shape after all-to-all, not passed as parameter.
+ """
+ # Step 1: All-to-All to gather full sequence, shard heads
+ # [B, S/P, H, D] -> [B, S, H/P, D]
+ if self.world_size > 1:
+ q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group)
+ k = all_to_all_4d(k, scatter_dim=2, gather_dim=1, process_group=self.process_group)
+ v = all_to_all_4d(v, scatter_dim=2, gather_dim=1, process_group=self.process_group)
+
+ seq_len_full = q.shape[1]
+ inner_layout = self.inner_backend.preferred_layout
+
+ # Step 2: Call wrapped backend for attention
+ # Transpose only if inner backend expects HND layout
+ if inner_layout == AttentionTensorLayout.HND:
+ # VANILLA expects [B, H/P, S, D]
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ # NHD backends (TRTLLM) keep [B, S, H/P, D] as-is
+
+ inner_kwargs = dict(
+ q=q,
+ k=k,
+ v=v,
+ batch_size=batch_size,
+ seq_len=seq_len_full,
+ )
+ if attention_mask is not None:
+ inner_kwargs["attention_mask"] = attention_mask
+ output = self.inner_backend.forward(**inner_kwargs)
+
+ # Convert output back to [B, S, H/P, D] for the reverse all-to-all
+ if inner_layout == AttentionTensorLayout.HND:
+ # VANILLA returns [B, H/P, S, D] -> transpose to [B, S, H/P, D]
+ output = output.transpose(1, 2).contiguous()
+ else:
+ # TRTLLM returns [B, S, (H/P)*D] (3D) -> reshape to [B, S, H/P, D]
+ if output.dim() == 3:
+ output = output.view(
+ batch_size, seq_len_full, self.sharded_num_heads, self.head_dim
+ )
+ output = output.contiguous()
+
+ # Step 3: All-to-All to restore sequence sharding
+ # [B, S, H/P, D] -> [B, S/P, H, D]
+ if self.world_size > 1:
+ output = all_to_all_4d(
+ output,
+ scatter_dim=1,
+ gather_dim=2,
+ process_group=self.process_group,
+ )
+
+ return output
+
+ @property
+ def preferred_layout(self) -> AttentionTensorLayout:
+ """Preferred tensor layout: [B, S, H, D]"""
+ return self._preferred_layout
+
+ @classmethod
+ def support_fused_qkv(cls) -> bool:
+ """This backend does not support fused QKV."""
+ return False
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py b/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
new file mode 100644
index 0000000000..47b92ca27f
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
@@ -0,0 +1,244 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Diffusion TRTLLM Attention Backend
+
+Wraps TrtllmAttention with simplified metadata for visual generation (diffusion) models.
+Handles the specifics of no-KV-cache operation and fused QKV requirements.
+"""
+
+from typing import Optional, Union
+
+import torch
+
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models.modeling_utils import QuantConfig
+
+from ...attention_backend.interface import AttentionRuntimeFeatures, PredefinedAttentionMask
+from ...attention_backend.trtllm import TrtllmAttention as BaseTrtllmAttention
+from ...attention_backend.trtllm import TrtllmAttentionMetadata as BaseTrtllmAttentionMetadata
+from .interface import AttentionTensorLayout
+
+
+class TrtllmAttentionMetadata:
+ """
+ Simplified metadata adapter for diffusion models using TRTLLM backend.
+
+ Lazy initialization with auto-growing capacity:
+ - Metadata created only when capacity needs increase
+ - prepare() called only when seq_lens actually change
+ - Automatically reallocates when batch_size or seq_len exceeds current capacity
+
+ Args:
+ max_batch_size: Initial batch size hint. Will grow automatically if exceeded.
+ max_seq_len: Initial sequence length hint. Will grow automatically if exceeded.
+ device: Target device for tensors.
+ """
+
+ def __init__(
+ self,
+ max_batch_size: int = 16,
+ max_seq_len: int = 4096,
+ device: Optional[torch.device] = None,
+ ):
+ # These are initial hints, not hard limits - capacity grows as needed
+ self.max_batch_size = max_batch_size
+ self.max_seq_len = max_seq_len
+ self.device = device or torch.device("cuda")
+
+ # Lazily created BaseTrtllmAttentionMetadata
+ self._metadata: Optional[BaseTrtllmAttentionMetadata] = None
+
+ # Track allocated capacity
+ self._allocated_batch_size = 0
+ self._allocated_max_seq_len = 0
+
+ # Track prepared state
+ self._cached_seq_lens: Optional[torch.Tensor] = None
+ self._prepared = False
+
+ def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool:
+ """Check if we need to create new metadata (capacity change)."""
+ return (
+ self._metadata is None
+ or batch_size > self._allocated_batch_size
+ or max_seq_len > self._allocated_max_seq_len
+ )
+
+ def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
+ """Check if we need to call prepare() (seq_lens changed)."""
+ if not self._prepared:
+ return True
+ if self._cached_seq_lens is None:
+ return True
+ if self._cached_seq_lens.shape[0] != batch_size:
+ return True
+ return not torch.equal(self._cached_seq_lens[:batch_size], seq_lens)
+
+ def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
+ """Create new metadata with given capacity."""
+        # Grow monotonically: never shrink below previously allocated capacity
+ alloc_batch = max(batch_size, self._allocated_batch_size)
+ alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len)
+
+ self._metadata = BaseTrtllmAttentionMetadata(
+ max_num_requests=alloc_batch,
+ max_num_tokens=alloc_batch * alloc_seq_len,
+ max_num_sequences=alloc_batch,
+ kv_cache_manager=None, # No KV cache for diffusion
+ mapping=Mapping(),
+ runtime_features=AttentionRuntimeFeatures(),
+ )
+
+ self._allocated_batch_size = alloc_batch
+ self._allocated_max_seq_len = alloc_seq_len
+ self._prepared = False # Reset prepare state on new metadata
+
+ def prepare(
+ self,
+ batch_size: int,
+ seq_lens: Union[int, torch.Tensor],
+ ) -> BaseTrtllmAttentionMetadata:
+ """
+ Prepare metadata for a forward pass.
+
+ Lazy behavior:
+ - Creates metadata only when capacity needs increase
+ - Calls prepare() only when seq_lens actually change
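+
+        Example (illustrative):
+            md = TrtllmAttentionMetadata(max_batch_size=2, max_seq_len=1024)
+            meta = md.prepare(batch_size=2, seq_lens=512)   # allocates and prepares
+            meta = md.prepare(batch_size=2, seq_lens=512)   # unchanged seq_lens: no re-prepare
+            meta = md.prepare(batch_size=2, seq_lens=2048)  # capacity grows, prepares again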
+ """
+ if isinstance(seq_lens, int):
+ seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32)
+ else:
+ seq_lens_tensor = seq_lens.to(dtype=torch.int32)
+
+ max_seq_len = seq_lens_tensor.max().item()
+
+ if self._needs_new_metadata(batch_size, max_seq_len):
+ self._create_metadata(batch_size, max_seq_len)
+
+ if self._needs_prepare(batch_size, seq_lens_tensor):
+ self._metadata.seq_lens = seq_lens_tensor
+ self._metadata.num_contexts = batch_size
+ self._metadata.max_seq_len = max_seq_len
+ self._metadata.request_ids = list(range(batch_size))
+ self._metadata.prepare()
+
+ # Cache for next comparison
+ if self._cached_seq_lens is None or self._cached_seq_lens.shape[0] < batch_size:
+ self._cached_seq_lens = seq_lens_tensor.clone()
+ else:
+ self._cached_seq_lens[:batch_size].copy_(seq_lens_tensor)
+ self._prepared = True
+
+ return self._metadata
+
+
+class TrtllmAttention(BaseTrtllmAttention):
+ """
+ TRTLLM Attention wrapper for diffusion models.
+
+ Handles:
+ - Fused QKV requirement for TRTLLM kernel
+ - Metadata creation and preparation
+ - No KV cache operation
+ """
+
+ def __init__(
+ self,
+ layer_idx: int = 0,
+ num_heads: int = 8,
+ head_dim: int = 64,
+ num_kv_heads: Optional[int] = None,
+ quant_config: Optional[QuantConfig] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_batch_size: int = 16,
+ max_seq_len: int = 4096,
+ ):
+ num_kv_heads = num_kv_heads or num_heads
+
+ super().__init__(
+ layer_idx=layer_idx,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ quant_config=quant_config,
+ dtype=dtype,
+ )
+
+        # Declare NHD layout ([B, S, H, D]); forward() flattens to [B*S, H*D] for the TRTLLM kernel
+ self._preferred_layout = AttentionTensorLayout.NHD
+
+ self.metadata = TrtllmAttentionMetadata(
+ max_batch_size=max_batch_size,
+ max_seq_len=max_seq_len,
+ )
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ batch_size: int,
+ seq_len: int,
+ attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
+ seq_len_kv: Optional[int] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Forward pass with automatic metadata handling.
+
+        For diffusion models, expects separate Q, K, V tensors; they are
+        concatenated into a single fused QKV tensor internally before the
+        TRTLLM kernel is called.
+
+        Args:
+            q: Query tensor, reshapeable to [batch_size * seq_len, q_hidden]
+            k: Key tensor, reshapeable to [batch_size * seq_len_kv, kv_hidden]
+            v: Value tensor, reshapeable to [batch_size * seq_len_kv, kv_hidden]
+            batch_size: Batch size
+            seq_len: Sequence length for Q
+ attention_mask: Attention mask type
+ seq_len_kv: Sequence length for K/V (for cross-attention, defaults to seq_len)
+
+ Returns:
+ Output tensor [num_tokens, q_hidden]
+ """
+ # Handle cross-attention where K/V have different sequence length than Q
+ kv_seq_len = seq_len_kv if seq_len_kv is not None else seq_len
+
+ # Separate Q, K, V provided - fuse them
+ q = q.view(batch_size * seq_len, -1)
+ k = k.view(batch_size * kv_seq_len, -1)
+ v = v.view(batch_size * kv_seq_len, -1)
+ qkv = torch.cat([q, k, v], dim=-1)
+ prepared_metadata = self.metadata.prepare(batch_size, seq_len)
+ output = super().forward(
+ q=qkv,
+ k=None,
+ v=None,
+ metadata=prepared_metadata,
+ attention_mask=attention_mask,
+ )
+ output = output.view(batch_size, seq_len, -1)
+ return output
+
+ @property
+ def preferred_layout(self) -> AttentionTensorLayout:
+ """Return the preferred tensor layout for this backend."""
+ return self._preferred_layout
+
+ @classmethod
+ def support_fused_qkv(cls) -> bool:
+ return True
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
new file mode 100644
index 0000000000..835e113c55
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
@@ -0,0 +1,118 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Visual Generation Attention Backend Utilities
+
+Factory functions for creating attention backends for visual generation models.
+Uses diffusion-specific wrappers (TrtllmAttention, VanillaAttention)
+that handle metadata preparation internally for simplified usage.
+"""
+
+from typing import TYPE_CHECKING, Optional, Type, Union
+
+import torch
+
+from tensorrt_llm.models.modeling_utils import QuantConfig
+
+# Lazy imports to avoid circular dependency
+if TYPE_CHECKING:
+ from .trtllm import TrtllmAttention
+ from .vanilla import VanillaAttention
+
+ # Type alias for diffusion attention backends
+ DiffusionAttentionBackend = Union[TrtllmAttention, VanillaAttention]
+
+
+def get_visual_gen_attention_backend(
+ backend_name: str,
+) -> Type["DiffusionAttentionBackend"]:
+ """
+ Get diffusion attention backend class by name.
+
+ Args:
+ backend_name: Backend identifier ("VANILLA", "TRTLLM")
+
+ Returns:
+ Diffusion attention backend class
+
+ Backend Selection Guide:
+ - "VANILLA": Full support for cross-attention (different Q/KV seq lengths)
+ Uses torch SDPA backend
+ - "TRTLLM": Optimized for self-attention (requires same Q/KV seq lengths)
+ Better performance but requires fused QKV
+ """
+ # Lazy imports to avoid circular dependency
+ from .trtllm import TrtllmAttention
+ from .vanilla import VanillaAttention
+
+ backend_name = backend_name.upper()
+
+ if backend_name == "VANILLA":
+ return VanillaAttention
+ elif backend_name == "TRTLLM":
+ return TrtllmAttention
+ else:
+ # Default to VANILLA for maximum compatibility
+ return VanillaAttention
+
+
+def create_attention(
+ backend: str,
+ layer_idx: int,
+ num_heads: int,
+ head_dim: int,
+ num_kv_heads: Optional[int] = None,
+ quant_config: Optional[QuantConfig] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_batch_size: int = 16,
+ max_seq_len: int = 4096,
+ **kwargs,
+) -> "DiffusionAttentionBackend":
+ """
+ Factory function to create attention backend instance for visual generation.
+
+ Creates diffusion-specific attention backends that handle metadata preparation
+ internally, simplifying the forward() call.
+
+ Args:
+ backend: Backend identifier ("VANILLA", "TRTLLM")
+ layer_idx: Layer index in the model
+ num_heads: Number of attention heads
+ head_dim: Dimension per head
+ num_kv_heads: Number of KV heads (for GQA/MQA, defaults to num_heads)
+ quant_config: Optional quantization configuration
+ dtype: Data type for the attention
+ max_batch_size: Initial batch size for metadata pre-allocation. The backend
+ will automatically reallocate if larger batches are encountered.
+ max_seq_len: Initial sequence length for metadata pre-allocation. The backend
+ will automatically reallocate if longer sequences are encountered.
+ **kwargs: Additional backend-specific arguments
+
+ Returns:
+ Diffusion attention backend instance (TrtllmAttention or VanillaAttention)
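+
+    Example (illustrative, WAN-like dimensions):
+        attn = create_attention("TRTLLM", layer_idx=0, num_heads=12, head_dim=128)
+        # q, k, v reshapeable to [batch * seq, 12 * 128]
+        out = attn(q, k, v, batch_size=1, seq_len=seq)  # -> [1, seq, 12 * 128]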
+ """
+ attn_cls = get_visual_gen_attention_backend(backend)
+
+ return attn_cls(
+ layer_idx=layer_idx,
+ num_heads=num_heads,
+ head_dim=head_dim,
+ num_kv_heads=num_kv_heads,
+ quant_config=quant_config,
+ dtype=dtype,
+ max_batch_size=max_batch_size,
+ max_seq_len=max_seq_len,
+ **kwargs,
+ )
diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py b/tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py
new file mode 100644
index 0000000000..d9eb41ca55
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py
@@ -0,0 +1,126 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Diffusion Vanilla Attention Backend
+
+Simple attention implementation for visual generation (diffusion) models using
+torch.nn.functional.scaled_dot_product_attention (SDPA).
+
+Supports both self-attention and cross-attention (different Q/KV sequence lengths).
+No KV cache - full recompute each diffusion step.
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...attention_backend.interface import PredefinedAttentionMask
+from .interface import AttentionTensorLayout
+
+
+class VanillaAttention(nn.Module):
+ """
+ Vanilla Attention for diffusion models using torch SDPA.
+
+ Uses torch.nn.functional.scaled_dot_product_attention which:
+ - Properly handles cross-attention (different Q/KV sequence lengths)
+ - Uses Flash Attention 2 when available (via SDPA backend selection)
+ - No KV cache needed for diffusion models
+
+ This is simpler than the LLM VanillaAttention which has complex
+ KV cache handling and uses flash_attn_varlen_func.
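+
+    Example (illustrative):
+        attn = VanillaAttention(num_heads=12, head_dim=128)
+        # q, k, v in HND layout: [B, H, S, D]; H may already be sharded under Ulysses
+        out = attn(q, k, v, batch_size=q.shape[0], seq_len=q.shape[2])  # -> [B, H, S, D]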
+ """
+
+ def __init__(
+ self,
+ layer_idx: int = 0,
+ num_heads: int = 8,
+ head_dim: int = 64,
+ num_kv_heads: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.layer_idx = layer_idx
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.num_kv_heads = num_kv_heads or num_heads
+ self.dtype = dtype
+ self.scale = 1.0 / math.sqrt(head_dim)
+
+ # SDPA expects [B, H, S, D] format
+ self._preferred_layout = AttentionTensorLayout.HND
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ batch_size: int,
+ seq_len: int,
+ seq_len_kv: Optional[int] = None,
+ attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Forward pass using torch SDPA.
+
+ Args:
+ q: Query tensor [num_tokens, num_heads * head_dim]
+ k: Key tensor [num_kv_tokens, num_kv_heads * head_dim]
+ v: Value tensor [num_kv_tokens, num_kv_heads * head_dim]
+ batch_size: Batch size
+ seq_len: Query sequence length
+ seq_len_kv: KV sequence length (for cross-attention)
+ attention_mask: Attention mask type (CAUSAL or FULL)
+
+ Returns:
+ Output tensor [num_tokens, num_heads * head_dim]
+ """
+ is_causal = attention_mask == PredefinedAttentionMask.CAUSAL
+
+ # Validate tensor shapes - flexible for Ulysses head sharding
+ # Expected: [batch_size, num_heads, seq_len, head_dim]
+ # Note: num_heads may be sharded (num_heads // ulysses_size) when using Ulysses
+ assert (
+ q.dim() == 4
+ and q.shape[0] == batch_size
+ and q.shape[2] == seq_len
+ and q.shape[3] == self.head_dim
+ ), (
+ f"Invalid q shape: expected [B={batch_size}, H, S={seq_len}, D={self.head_dim}], got {q.shape}"
+ )
+ assert k.dim() == 4 and k.shape[0] == batch_size and k.shape[3] == self.head_dim, (
+ f"Invalid k shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {k.shape}"
+ )
+ assert v.dim() == 4 and v.shape[0] == batch_size and v.shape[3] == self.head_dim, (
+ f"Invalid v shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {v.shape}"
+ )
+
+ # TODO: Maybe we need to enforce cuDNN backend here
+ return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=self.scale)
+
+ @property
+ def preferred_layout(self) -> AttentionTensorLayout:
+ """Return the preferred tensor layout for this backend."""
+ return self._preferred_layout
+
+ @classmethod
+ def support_fused_qkv(cls) -> bool:
+ return False
diff --git a/tensorrt_llm/_torch/visual_gen/checkpoints/__init__.py b/tensorrt_llm/_torch/visual_gen/checkpoints/__init__.py
new file mode 100644
index 0000000000..6d3b138f90
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/checkpoints/__init__.py
@@ -0,0 +1,7 @@
+"""Diffusion model checkpoint loading utilities."""
+
+from .weight_loader import WeightLoader
+
+__all__ = [
+ "WeightLoader",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py b/tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
new file mode 100644
index 0000000000..77067fe9c9
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
@@ -0,0 +1,152 @@
+"""Weight loader for diffusion models."""
+
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Union
+
+import torch
+import tqdm
+
+from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
+from tensorrt_llm._torch.visual_gen.config import PipelineComponent
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+
+class WeightLoader(BaseWeightLoader):
+ """
+ Weight loader for diffusion models.
+
+ Loads weights from safetensors/bin files, similar to HfWeightLoader
+ but simpler (no parallel loading optimization for now).
+
+ Supports loading multiple components (e.g., transformer and transformer_2):
+ loader = WeightLoader(components=["transformer", "transformer_2"])
+ weights = loader.load_weights(ckpt_dir, mapping)
+ # Returns: {"transformer": {...}, "transformer_2": {...}}
+ """
+
+ def __init__(self, components: Union[str, List[str]] = PipelineComponent.TRANSFORMER):
+ """
+ Args:
+ components: Component(s) to load weights for. Can be:
+ - Single string: "transformer" (returns flat dict)
+ - List of strings: ["transformer", "transformer_2"] (returns nested dict)
+ """
+ if isinstance(components, str):
+ self.components = [components]
+ self.single_component = True
+ else:
+ self.components = components
+ self.single_component = False
+
+ def load_weights(
+ self,
+ checkpoint_dir: str,
+ mapping: Mapping,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ """
+ Load weights from checkpoint directory.
+
+ Args:
+ checkpoint_dir: Path to checkpoint (pipeline root or component dir)
+ mapping: Distributed mapping (for future TP/PP support)
+
+ Returns:
+ - If single component: Dict mapping weight names to tensors
+ - If multiple components: Dict mapping component names to weight dicts
+ Example: {"transformer": {...}, "transformer_2": {...}}
+ """
+ checkpoint_path = Path(checkpoint_dir)
+
+ # Check if this is a pipeline (has model_index.json)
+ model_index = checkpoint_path / "model_index.json"
+ is_pipeline = model_index.exists()
+
+ # Load weights for each component
+ all_weights = {}
+ for component in self.components:
+ if is_pipeline:
+ # Pipeline format: load from component subdirectory
+ component_dir = checkpoint_path / component
+ if not component_dir.exists():
+ raise ValueError(f"Component '{component}' not found in {checkpoint_dir}")
+ weight_dir = component_dir
+ else:
+ # Standalone model (only valid for single component)
+ if len(self.components) > 1:
+ raise ValueError(
+ f"Multiple components specified but {checkpoint_dir} is not a pipeline "
+ "(no model_index.json found)"
+ )
+ weight_dir = checkpoint_path
+
+ # Find weight files
+ weight_files = self._find_weight_files(weight_dir)
+ if not weight_files:
+ raise ValueError(f"No weight files found in {weight_dir}")
+
+ # Load all weights with progress bar
+ component_weights = {}
+ desc = f"Loading {component}" if is_pipeline else "Loading checkpoint"
+ for wf in tqdm.tqdm(weight_files, desc=desc):
+ component_weights.update(self._load_file(wf))
+
+ all_weights[component] = component_weights
+
+ # Return flat dict for single component (backward compatibility)
+ if self.single_component:
+ return all_weights[self.components[0]]
+
+ # Return nested dict for multiple components
+ return all_weights
+
+ def _find_weight_files(self, weight_dir) -> List[str]:
+ """Find safetensors or bin weight files.
+
+ Handles:
+ - Single safetensors file
+ - Sharded safetensors with index.json
+ - PyTorch bin/pth files
+ """
+ weight_dir = Path(weight_dir)
+
+ # Check for sharded safetensors index
+ index_file = weight_dir / "diffusion_pytorch_model.safetensors.index.json"
+ if not index_file.exists():
+ index_file = weight_dir / "model.safetensors.index.json"
+
+ if index_file.exists():
+ # Sharded safetensors: read index to get all shard files
+ with open(index_file) as f:
+ index = json.load(f)
+ shard_files = set(index.get("weight_map", {}).values())
+ return sorted([str(weight_dir / f) for f in shard_files])
+
+ # Single safetensors file
+ files = list(weight_dir.glob("*.safetensors"))
+ if files:
+ # Filter out consolidated if multiple files exist
+ if len(files) > 1:
+ files = [f for f in files if "consolidated" not in f.name]
+ return sorted([str(f) for f in files])
+
+ # Fallback to bin
+ files = list(weight_dir.glob("*.bin"))
+ if files:
+ return sorted([str(f) for f in files])
+
+ # Fallback to pth
+ files = list(weight_dir.glob("*.pth"))
+ return sorted([str(f) for f in files])
+
+ def _load_file(self, filepath: str) -> Dict[str, Any]:
+ """Load weights from a single file."""
+ logger.debug(f"Loading {filepath}")
+ if filepath.endswith(".safetensors"):
+ from safetensors.torch import load_file
+
+ return load_file(filepath)
+ else:
+ return torch.load(filepath, map_location="cpu", weights_only=True)
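+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (the checkpoint path is a placeholder): load the
+    # transformer weights of a diffusers pipeline checkpoint and report how
+    # many tensors were found.
+    loader = WeightLoader(components=[PipelineComponent.TRANSFORMER.value])
+    weights = loader.load_weights("/path/to/Wan2.1-T2V-1.3B-Diffusers", Mapping())
+    print(f"Loaded {len(weights['transformer'])} transformer tensors")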
diff --git a/tensorrt_llm/_torch/visual_gen/config.py b/tensorrt_llm/_torch/visual_gen/config.py
new file mode 100644
index 0000000000..c5666ae0b3
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/config.py
@@ -0,0 +1,565 @@
+import json
+import os
+from enum import Enum
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import torch
+from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import Field as PydanticField
+
+from tensorrt_llm.functional import AllReduceStrategy
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
+
+# =============================================================================
+# Pipeline component identifiers
+# =============================================================================
+
+
+class PipelineComponent(str, Enum):
+ """Identifiers for pipeline components that can be loaded or skipped.
+
+ Inherits from str so values compare equal to plain strings,
+ e.g. ``PipelineComponent.VAE == "vae"`` is ``True``.
+ """
+
+ TRANSFORMER = "transformer"
+ VAE = "vae"
+ TEXT_ENCODER = "text_encoder"
+ TOKENIZER = "tokenizer"
+ SCHEDULER = "scheduler"
+ IMAGE_ENCODER = "image_encoder"
+ IMAGE_PROCESSOR = "image_processor"
+
+
+# =============================================================================
+# Sub-configuration classes for DiffusionArgs
+# =============================================================================
+
+
+class AttentionConfig(BaseModel):
+ """Configuration for Attention layers."""
+
+ backend: Literal["VANILLA", "TRTLLM"] = PydanticField(
+ "VANILLA", description="Attention backend: VANILLA (PyTorch SDPA), TRTLLM"
+ )
+
+
+class ParallelConfig(BaseModel):
+ """Configuration for distributed parallelism.
+
+ Currently Supported:
+ - dit_cfg_size: CFG (Classifier-Free Guidance) parallelism
+ - dit_ulysses_size: Ulysses sequence parallelism
+
+ Not Yet Supported:
+ - dit_tp_size: Tensor parallelism (not implemented)
+ - dit_ring_size: Ring attention (not implemented)
+ - dit_cp_size, dit_dp_size, dit_fsdp_size: Other parallelism types
+
+    Total world_size = dit_cfg_size × dit_ulysses_size
+
+ Parallelism Strategy:
+ - CFG Parallelism: Distributes positive/negative prompts across GPUs
+ - Ulysses Parallelism: Distributes sequence within each CFG group
+
+ Example Configurations:
+ 1. cfg_size=1, ulysses_size=2 -> 2 GPUs (Ulysses only)
+ GPU 0-1: Single prompt, sequence parallelism across 2 GPUs
+
+ 2. cfg_size=2, ulysses_size=1 -> 2 GPUs (CFG only)
+ GPU 0: Positive prompt
+ GPU 1: Negative prompt
+
+ 3. cfg_size=2, ulysses_size=2 -> 4 GPUs (CFG + Ulysses)
+ GPU 0-1: CFG group 0 (positive), Ulysses parallel
+ GPU 2-3: CFG group 1 (negative), Ulysses parallel
+
+ 4. cfg_size=2, ulysses_size=4 -> 8 GPUs (CFG + Ulysses)
+ GPU 0-3: CFG group 0 (positive), Ulysses parallel
+ GPU 4-7: CFG group 1 (negative), Ulysses parallel
+ """
+
+ disable_parallel_vae: bool = False
+ parallel_vae_split_dim: Literal["width", "height"] = "width"
+
+ # DiT Parallelism
+ dit_dp_size: int = PydanticField(1, ge=1)
+ dit_tp_size: int = PydanticField(1, ge=1) # Not yet supported
+ dit_ulysses_size: int = PydanticField(1, ge=1) # Supported
+ dit_ring_size: int = PydanticField(1, ge=1) # Not yet supported
+ dit_cp_size: int = PydanticField(1, ge=1)
+ dit_cfg_size: int = PydanticField(1, ge=1) # Supported
+ dit_fsdp_size: int = PydanticField(1, ge=1)
+
+ # Refiner Parallelism (Optional)
+ refiner_dit_dp_size: int = 1
+ refiner_dit_tp_size: int = 1
+ refiner_dit_ulysses_size: int = 1
+ refiner_dit_ring_size: int = 1
+ refiner_dit_cp_size: int = 1
+ refiner_dit_cfg_size: int = 1
+ refiner_dit_fsdp_size: int = 1
+
+ t5_fsdp_size: int = 1
+
+ def to_mapping(self) -> Mapping:
+ """Convert to TRT-LLM Mapping."""
+ world_size = self.dit_tp_size * self.dit_cp_size
+ return Mapping(
+ world_size=world_size,
+ tp_size=self.dit_tp_size,
+ pp_size=1,
+ cp_size=self.dit_cp_size,
+ )
+
+ @model_validator(mode="after")
+ def validate_parallel_sizes(self) -> "ParallelConfig":
+ """Validate configuration against current environment."""
+ if torch.cuda.is_available():
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ total_parallel = (
+ self.dit_tp_size
+ * self.dit_ulysses_size
+ * self.dit_ring_size
+ * self.dit_cp_size
+ * self.dit_dp_size
+ * self.dit_cfg_size
+ )
+ if total_parallel > world_size:
+ raise ValueError(
+ f"Total DiT parallel size ({total_parallel}) exceeds WORLD_SIZE ({world_size})"
+ )
+ return self
+
+
+class TeaCacheConfig(BaseModel):
+ """Configuration for TeaCache runtime optimization.
+
+ TeaCache speeds up diffusion by caching transformer outputs when timestep
+ embeddings change slowly. It monitors embedding distances and reuses cached
+ residuals when changes are below a threshold.
+
+ Attributes:
+ enable_teacache: Enable TeaCache optimization
+        teacache_thresh: Distance threshold for cache decisions (higher = more caching)
+ use_ret_steps: Use aggressive warmup mode (5 steps) vs minimal (1 step)
+ coefficients: Polynomial coefficients for rescaling embedding distances
+ Applied as: rescaled_distance = poly(raw_distance)
+ ret_steps: Number of warmup steps (always compute, initialized at runtime)
+ cutoff_steps: Step to stop caching (always compute after, initialized at runtime)
+ num_steps: Total inference steps (set at runtime)
+ _cnt: Internal step counter (reset per generation)
+ """
+
+ enable_teacache: bool = False
+ teacache_thresh: float = PydanticField(0.2, gt=0.0)
+ use_ret_steps: bool = True
+
+ coefficients: List[float] = PydanticField(default_factory=lambda: [1.0, 0.0])
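+    # Coefficients are interpreted highest-order-first, as in the reference TeaCache
+    # implementation (assumption): rescaled = np.poly1d(coefficients)(raw_distance),
+    # so the default [1.0, 0.0] is the identity mapping.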
+
+ # Runtime state fields (initialized by TeaCacheBackend.refresh)
+ ret_steps: Optional[int] = None
+ cutoff_steps: Optional[int] = None
+ num_steps: Optional[int] = None
+
+ # State tracking (reset per generation)
+ _cnt: int = 0
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ @model_validator(mode="after")
+ def validate_teacache(self) -> "TeaCacheConfig":
+ """Validate TeaCache configuration."""
+ # Validate coefficients
+ if len(self.coefficients) == 0:
+ raise ValueError("TeaCache coefficients list cannot be empty")
+
+ # Validate ret_steps if set
+ if self.ret_steps is not None and self.ret_steps < 0:
+ raise ValueError(f"ret_steps must be non-negative, got {self.ret_steps}")
+
+ # Validate cutoff_steps vs num_steps if both set
+ if self.cutoff_steps is not None and self.num_steps is not None:
+ if self.cutoff_steps > self.num_steps:
+ raise ValueError(
+ f"cutoff_steps ({self.cutoff_steps}) cannot exceed num_steps ({self.num_steps})"
+ )
+
+ return self
+
+
+class PipelineConfig(BaseModel):
+ """General pipeline configuration."""
+
+ enable_torch_compile: bool = True
+ torch_compile_models: str = PipelineComponent.TRANSFORMER
+ torch_compile_mode: str = "default"
+ fuse_qkv: bool = True
+
+ # Offloading Config
+ enable_offloading: bool = False
+ offload_device: Literal["cpu", "cuda"] = "cpu"
+ offload_param_pin_memory: bool = True
+
+
+# =============================================================================
+# DiffusionArgs - User-facing configuration (CLI / YAML)
+# =============================================================================
+
+
+class DiffusionArgs(BaseModel):
+ """User-facing configuration for diffusion model loading and inference.
+
+ This is the main config class used in CLI args and YAML config files.
+ PipelineLoader converts this to DiffusionModelConfig internally.
+
+ Example:
+ args = DiffusionArgs(
+ checkpoint_path="/path/to/model",
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ parallel=ParallelConfig(dit_tp_size=2),
+ )
+ loader = PipelineLoader()
+ pipeline = loader.load(args)
+ """
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ # Required: Path to checkpoint or HuggingFace Hub model ID
+ checkpoint_path: str = PydanticField(
+ "",
+ description=(
+ "Local directory path or HuggingFace Hub model ID "
+ "(e.g., 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'). "
+ "Hub models are downloaded and cached automatically."
+ ),
+ )
+
+ # HuggingFace Hub options
+ revision: Optional[str] = PydanticField(
+ None,
+ description="HuggingFace Hub revision (branch, tag, or commit SHA) to download.",
+ )
+
+ # Device/dtype options
+ device: str = "cuda"
+ dtype: str = "bfloat16"
+
+ # Component loading options (use PipelineComponent enum values or plain strings)
+ skip_components: List[PipelineComponent] = PydanticField(
+ default_factory=list,
+ description=(
+ "Components to skip loading. "
+ "Accepts PipelineComponent enum values or equivalent strings "
+ "(e.g., [PipelineComponent.TEXT_ENCODER, PipelineComponent.VAE])"
+ ),
+ )
+
+ # Sub-configs (dict input for quant_config is coerced to QuantConfig in model_validator)
+ quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
+ pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
+ attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
+ parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
+ teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
+
+ # Set by model_validator when quant_config is provided as a dict (ModelOpt format)
+ dynamic_weight_quant: bool = False
+ force_dynamic_quantization: bool = False
+
+ @model_validator(mode="before")
+ @classmethod
+ def _parse_quant_config_dict(cls, data: Any) -> Any:
+ """Parse user-facing DiffusionArgs.quant_config (dict or None) into QuantConfig and dynamic flags.
+
+ User input is ModelOpt-format dict (e.g. {"quant_algo": "FP8", "dynamic": True}).
+ We coerce it to QuantConfig + dynamic_weight_quant + force_dynamic_quantization so that
+ from_pretrained() can copy them into DiffusionModelConfig (internal) without parsing again.
+ """
+ if not isinstance(data, dict):
+ return data
+ raw = data.get("quant_config")
+ if raw is None:
+ data = {**data, "quant_config": QuantConfig()}
+ return data
+ if not isinstance(raw, dict):
+ return data
+ qc, _, dwq, daq = DiffusionModelConfig.load_diffusion_quant_config(raw)
+ data = {
+ **data,
+ "quant_config": qc,
+ "dynamic_weight_quant": dwq,
+ "force_dynamic_quantization": daq,
+ }
+ return data
+
+ def to_mapping(self) -> Mapping:
+ """Derive Mapping from ParallelConfig."""
+ return self.parallel.to_mapping()
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary."""
+ return self.model_dump()
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any]) -> "DiffusionArgs":
+ """Create from dictionary with automatic nested config parsing.
+
+ Pydantic automatically handles nested configs, but we keep this method
+ for backward compatibility and to filter unknown fields.
+ """
+ # Get valid field names for DiffusionArgs
+ valid_fields = set(cls.model_fields.keys())
+
+ # Filter to only include valid fields (ignore unknown fields)
+ filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
+
+ # Pydantic automatically converts nested dicts to their respective config classes
+ return cls(**filtered_dict)
+
+
+# =============================================================================
+# Utilities
+# =============================================================================
+
+
+def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]:
+ """
+ Discover components from diffusers pipeline's model_index.json.
+
+ Returns dict mapping component name to config.json path.
+ """
+ model_index_path = checkpoint_path / "model_index.json"
+ if not model_index_path.exists():
+ return {}
+
+ with open(model_index_path) as f:
+ model_index = json.load(f)
+
+ components = {}
+ for key, value in model_index.items():
+ if key.startswith("_") or value is None:
+ continue
+ config_path = checkpoint_path / key / "config.json"
+ if config_path.exists():
+ components[key] = config_path
+
+ return components
+
+
+# =============================================================================
+# DiffusionModelConfig - Internal configuration (merged/parsed)
+# =============================================================================
+
+
+class DiffusionModelConfig(BaseModel):
+ """Internal ModelConfig for diffusion models.
+
+ This is created by PipelineLoader from DiffusionArgs + checkpoint.
+ Contains merged/parsed config from:
+ - pretrained_config: From checkpoint/config.json
+ - quant_config: From checkpoint or user quant config
+ - Sub-configs: From DiffusionArgs (pipeline, attention, parallel, teacache)
+ """
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ pretrained_config: Optional[Any] = None
+ mapping: Mapping = PydanticField(default_factory=Mapping)
+ skip_create_weights_in_init: bool = False
+ force_dynamic_quantization: bool = False
+ allreduce_strategy: AllReduceStrategy = PydanticField(default=AllReduceStrategy.AUTO)
+ extra_attrs: Dict = PydanticField(default_factory=dict)
+
+ # Distributed process groups
+ ulysses_process_group: Optional[torch.distributed.ProcessGroup] = None
+
+ dynamic_weight_quant: bool = False
+
+ # Sub-configs from DiffusionArgs (merged during from_pretrained)
+ quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
+ # Per-layer quant (from load_diffusion_quant_config layer_quant_config; None until mixed-precision parsing exists)
+ quant_config_dict: Optional[Dict[str, QuantConfig]] = None
+ pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
+ attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
+ parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
+ teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
+
+ @property
+ def torch_dtype(self) -> "torch.dtype":
+ """Get the torch dtype of the model (default: bfloat16)."""
+ return torch.bfloat16
+
+ def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
+ """Get quantization config for a layer or global. Resembles LLM ModelConfig.get_quant_config."""
+ if name is None or self.quant_config_dict is None:
+ return self.quant_config
+ if name in self.quant_config_dict:
+ return self.quant_config_dict[name]
+ return self.quant_config
+
+ @staticmethod
+ def load_diffusion_quant_config(
+ quant_config_dict: dict,
+ ) -> Tuple[QuantConfig, Optional[Dict], bool, bool]:
+ """
+ Parse quantization config in ModelOpt format.
+
+ Returns: (quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant)
+ - quant_config: Global QuantConfig
+ - layer_quant_config: Per-layer config dict (None if not using mixed precision)
+ - dynamic_weight_quant: Whether to quantize weights at load time
+ - dynamic_activation_quant: Whether to quantize activations dynamically
+ """
+ quant_algo_str = quant_config_dict.get("quant_algo")
+ quant_algo = None
+ if quant_algo_str:
+ algo_map = {
+ "FP8": QuantAlgo.FP8,
+ "FP8_BLOCK_SCALES": QuantAlgo.FP8_BLOCK_SCALES,
+ "NVFP4": QuantAlgo.NVFP4,
+ "W4A16_AWQ": QuantAlgo.W4A16_AWQ,
+ "W4A8_AWQ": QuantAlgo.W4A8_AWQ,
+ "W8A8_SQ_PER_CHANNEL": QuantAlgo.W8A8_SQ_PER_CHANNEL,
+ }
+ quant_algo = algo_map.get(quant_algo_str)
+ if quant_algo is None:
+ raise ValueError(f"Unknown quant_algo: {quant_algo_str}")
+
+ # Parse group_size and dynamic flags from config_groups
+ group_size = None
+ dynamic_weight_quant = False
+ dynamic_activation_quant = False
+ for group_config in quant_config_dict.get("config_groups", {}).values():
+ weights_config = group_config.get("weights", {})
+ activations_config = group_config.get("input_activations", {})
+ dynamic_weight_quant = weights_config.get("dynamic", False)
+ dynamic_activation_quant = activations_config.get("dynamic", False)
+ # Extract group_size from weights config (e.g., NVFP4: group_size=16)
+ if group_size is None:
+ group_size = weights_config.get("group_size")
+ break
+
+ # Set defaults based on quant_algo if group_size not specified
+ if group_size is None:
+ if quant_algo in (QuantAlgo.NVFP4,):
+ group_size = 16 # NVFP4 default
+ elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
+ group_size = 128 # FP8 blockwise default
+
+ # Auto-enable dynamic weight quantization if quant_algo is specified
+ # but no explicit config_groups setting is present.
+ # This allows simple configs like {"quant_algo": "FP8"} to work.
+ if quant_algo is not None and not quant_config_dict.get("config_groups"):
+ dynamic_weight_quant = quant_config_dict.get("dynamic", True)
+
+ quant_config = QuantConfig(
+ quant_algo=quant_algo,
+ group_size=group_size,
+ exclude_modules=quant_config_dict.get("ignore"),
+ )
+
+ # TODO: Per-layer config (None for now - future: parse mixed precision settings)
+ layer_quant_config = None
+
+ return quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ checkpoint_dir: str,
+ args: Optional["DiffusionArgs"] = None,
+ **kwargs,
+ ) -> "DiffusionModelConfig":
+ """
+ Load config from pretrained checkpoint.
+
+ Called by PipelineLoader with DiffusionArgs:
+ config = DiffusionModelConfig.from_pretrained(
+ checkpoint_dir=args.checkpoint_path,
+ args=args,
+ )
+
+ Args:
+ checkpoint_dir: Path to checkpoint
+ args: DiffusionArgs containing user config (quant, pipeline, attention, parallel, teacache)
+ **kwargs: Additional config options (e.g., mapping)
+ """
+ kwargs.pop("trust_remote_code", None)
+
+ # Extract sub-configs from args or use defaults
+ pipeline_cfg = args.pipeline if args else PipelineConfig()
+ attention_cfg = args.attention if args else AttentionConfig()
+ parallel_cfg = args.parallel if args else ParallelConfig()
+ teacache_cfg = args.teacache if args else TeaCacheConfig()
+
+ component = PipelineComponent.TRANSFORMER
+ checkpoint_path = Path(checkpoint_dir)
+
+ # Discover pipeline components
+ components = discover_pipeline_components(checkpoint_path)
+
+ # Determine config path
+ if components:
+ if component not in components:
+ raise ValueError(
+ f"Component '{component}' not found. Available: {list(components.keys())}"
+ )
+ config_path = components[component]
+ else:
+ config_path = checkpoint_path / "config.json"
+
+ if not config_path.exists():
+ raise ValueError(f"Config not found at {config_path}")
+
+ # Load pretrained_config from checkpoint
+ with open(config_path) as f:
+ config_dict = json.load(f)
+ pretrained_config = SimpleNamespace(**config_dict)
+
+ model_index_path = checkpoint_path / "model_index.json"
+ if model_index_path.exists():
+ with open(model_index_path) as f:
+ model_index = json.load(f)
+ if "boundary_ratio" in model_index and "transformer_2" in model_index:
+ transformer_2_spec = model_index.get("transformer_2")
+ if transformer_2_spec and transformer_2_spec[0] is not None:
+ pretrained_config.boundary_ratio = model_index["boundary_ratio"]
+
+ # Resolve quant config: use args if user set quant (QuantConfig from dict), else checkpoint
+ if args and args.quant_config.quant_algo is not None:
+ quant_config = args.quant_config
+ quant_config_dict = (
+                None  # DiffusionArgs has no per-layer dict; that comes only from parsing a checkpoint
+ )
+ dynamic_weight_quant = args.dynamic_weight_quant
+ dynamic_activation_quant = args.force_dynamic_quantization
+ else:
+ quant_config = QuantConfig()
+ quant_config_dict = None
+ dynamic_weight_quant = False
+ dynamic_activation_quant = False
+ quant_dict = getattr(pretrained_config, "quantization_config", None)
+ if isinstance(quant_dict, dict):
+ quant_config, quant_config_dict, dynamic_weight_quant, dynamic_activation_quant = (
+ cls.load_diffusion_quant_config(quant_dict)
+ )
+
+ return cls(
+ pretrained_config=pretrained_config,
+ quant_config=quant_config,
+ quant_config_dict=quant_config_dict,
+ dynamic_weight_quant=dynamic_weight_quant,
+ force_dynamic_quantization=dynamic_activation_quant,
+ # Sub-configs from DiffusionArgs
+ pipeline=pipeline_cfg,
+ attention=attention_cfg,
+ parallel=parallel_cfg,
+ teacache=teacache_cfg,
+            # Delay weight creation until after apply_quant_config_exclude_modules() in __post_init__
+ skip_create_weights_in_init=True,
+ **kwargs,
+ )
diff --git a/tensorrt_llm/_torch/visual_gen/executor.py b/tensorrt_llm/_torch/visual_gen/executor.py
new file mode 100644
index 0000000000..d6e03bdcfe
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/executor.py
@@ -0,0 +1,246 @@
+import os
+import queue
+import threading
+import traceback
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import torch
+import torch.distributed as dist
+import zmq
+
+from tensorrt_llm._torch.visual_gen.config import DiffusionArgs
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+from tensorrt_llm.executor.ipc import ZeroMqQueue
+from tensorrt_llm.logger import logger
+
+
+@dataclass
+class DiffusionRequest:
+ """Request for diffusion inference with explicit model-specific parameters."""
+
+ request_id: int
+ prompt: str
+ negative_prompt: Optional[str] = None
+ height: int = 720
+ width: int = 1280
+ num_inference_steps: int = 50
+ guidance_scale: float = 5.0
+ max_sequence_length: int = 512
+ seed: int = 42
+
+ # Video-specific parameters
+ num_frames: int = 81
+ frame_rate: float = 24.0
+
+ # Image-specific parameters
+ num_images_per_prompt: int = 1
+
+ # Advanced parameters
+ guidance_rescale: float = 0.0
+ output_type: str = "pt"
+
+ # Wan-specific parameters
+ image: Optional[Union[str, List[str]]] = None
+ guidance_scale_2: Optional[float] = None
+ boundary_ratio: Optional[float] = None
+ last_image: Optional[Union[str, List[str]]] = None
+
+
+@dataclass
+class DiffusionResponse:
+ """Response with model-specific output.
+
+ Attributes:
+ request_id: Unique identifier for the request
+ output: Generated media as MediaOutput with model-specific fields populated
+ error_msg: Error message if generation failed
+ """
+
+ request_id: int
+ output: Optional[MediaOutput] = None
+ error_msg: Optional[str] = None
+
+
+class DiffusionExecutor:
+ """Execution engine for diffusion models running in worker processes."""
+
+ def __init__(
+ self,
+ model_path: str,
+ request_queue_addr: str,
+ response_queue_addr: str,
+ device_id: int,
+ diffusion_config: Optional[dict] = None,
+ ):
+ self.model_path = model_path
+ self.request_queue_addr = request_queue_addr
+ self.response_queue_addr = response_queue_addr
+ self.device_id = device_id
+ self.diffusion_config = diffusion_config
+
+ self.requests_ipc = None
+ self.rank = dist.get_rank()
+ self.response_queue = queue.Queue()
+ self.sender_thread = None
+
+ # Only rank 0 handles IPC
+ if self.rank == 0:
+ logger.info(f"Worker {device_id}: Connecting to request queue")
+ self.requests_ipc = ZeroMqQueue(
+ (request_queue_addr, None),
+ is_server=False,
+ socket_type=zmq.PULL,
+ use_hmac_encryption=False,
+ )
+ self.sender_thread = threading.Thread(target=self._sender_loop, daemon=True)
+ self.sender_thread.start()
+
+ self._load_pipeline()
+
+ def _sender_loop(self):
+ """Background thread for sending responses."""
+ logger.info(f"Worker {self.device_id}: Connecting to response queue")
+ responses_ipc = ZeroMqQueue(
+ (self.response_queue_addr, None),
+ is_server=False,
+ socket_type=zmq.PUSH,
+ use_hmac_encryption=False,
+ )
+
+ while True:
+ try:
+ resp = self.response_queue.get()
+ if resp is None:
+ break
+ responses_ipc.put(resp)
+ except Exception as e:
+ logger.error(f"Worker {self.device_id}: Sender error: {e}")
+
+ if responses_ipc.socket:
+ responses_ipc.socket.setsockopt(zmq.LINGER, 0)
+ responses_ipc.close()
+
+ def _load_pipeline(self):
+ """
+        Load the pipeline via the standard flow:
+        DiffusionArgs -> PipelineLoader -> DiffusionModelConfig -> AutoPipeline -> BasePipeline
+ """
+ logger.info(f"Worker {self.device_id}: Loading pipeline")
+
+ try:
+ # Convert diffusion_config dict to DiffusionArgs
+ config_dict = self.diffusion_config.copy()
+ config_dict["checkpoint_path"] = self.model_path
+ config_dict["device"] = f"cuda:{self.device_id}"
+
+ # Create DiffusionArgs from dict (handles nested configs)
+ args = DiffusionArgs.from_dict(config_dict)
+
+ # Use PipelineLoader for proper pipeline creation flow:
+ # PipelineLoader.load() internally:
+ # 1. Creates DiffusionModelConfig.from_pretrained()
+ # 2. Creates pipeline via AutoPipeline.from_config()
+ # 3. Loads weights with quantization support
+ # 4. Calls post_load_weights()
+ loader = PipelineLoader(args)
+ self.pipeline = loader.load()
+
+ except Exception as e:
+ logger.error(f"Worker {self.device_id}: Failed to load pipeline: {e}")
+ raise
+
+ logger.info(f"Worker {self.device_id}: Pipeline ready")
+
+ # Sync all workers
+ dist.barrier()
+
+ # Send READY signal
+ if self.rank == 0:
+ logger.info(f"Worker {self.device_id}: Sending READY")
+ self.response_queue.put(DiffusionResponse(request_id=-1, output="READY"))
+
+ def serve_forever(self):
+ """Main execution loop."""
+ while True:
+ req = None
+ if self.rank == 0:
+ req = self.requests_ipc.get()
+ logger.info(f"Worker {self.device_id}: Request available")
+
+ # Broadcast to all ranks
+ obj_list = [req]
+ dist.broadcast_object_list(obj_list, src=0)
+ req = obj_list[0]
+
+ if req is None:
+ logger.info(f"Worker {self.device_id}: Shutdown signal received")
+ if self.rank == 0 and self.sender_thread:
+ self.response_queue.put(None)
+ self.sender_thread.join()
+ break
+
+ logger.info(f"Worker {self.device_id}: Processing request {req.request_id}")
+ self.process_request(req)
+
+ def process_request(self, req: DiffusionRequest):
+ """Process a single request."""
+ try:
+ output = self.pipeline.infer(req)
+ if self.rank == 0:
+ self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
+ except Exception as e:
+ logger.error(f"Worker {self.device_id}: Error: {e}")
+ logger.error(traceback.format_exc())
+ if self.rank == 0:
+ self.response_queue.put(
+ DiffusionResponse(request_id=req.request_id, error_msg=str(e))
+ )
+
+
+def run_diffusion_worker(
+ rank: int,
+ world_size: int,
+ master_addr: str,
+ master_port: int,
+ model_path: str,
+ request_queue_addr: str,
+ response_queue_addr: str,
+ diffusion_config: Optional[dict] = None,
+):
+ """Entry point for worker process."""
+ try:
+        # Set up the distributed env: use PyTorch distributed, not MPI
+ os.environ["TLLM_DISABLE_MPI"] = "1"
+ os.environ["MASTER_ADDR"] = master_addr
+ os.environ["MASTER_PORT"] = str(master_port)
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+
+ # Calculate device_id before init_process_group
+ device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0
+ if torch.cuda.is_available():
+ torch.cuda.set_device(device_id)
+
+ dist.init_process_group(
+ backend="nccl" if torch.cuda.is_available() else "gloo",
+ init_method="env://",
+ world_size=world_size,
+ rank=rank,
+ device_id=torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else None,
+ )
+
+ executor = DiffusionExecutor(
+ model_path=model_path,
+ request_queue_addr=request_queue_addr,
+ response_queue_addr=response_queue_addr,
+ device_id=device_id,
+ diffusion_config=diffusion_config,
+ )
+ executor.serve_forever()
+ dist.destroy_process_group()
+
+ except Exception as e:
+ logger.error(f"Worker failed: {e}")
+ traceback.print_exc()
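+
+
+if __name__ == "__main__":
+    # Illustrative launch sketch (not the production launcher): spawn two worker
+    # processes with torch.multiprocessing. The model path, port, and ZeroMQ
+    # queue addresses below are placeholders.
+    import torch.multiprocessing as mp
+
+    mp.spawn(
+        run_diffusion_worker,
+        args=(
+            2,  # world_size
+            "127.0.0.1",  # master_addr
+            29511,  # master_port (placeholder)
+            "/path/to/Wan2.1-T2V-1.3B-Diffusers",  # model_path (placeholder)
+            "ipc:///tmp/visual_gen_requests",  # request_queue_addr (placeholder)
+            "ipc:///tmp/visual_gen_responses",  # response_queue_addr (placeholder)
+            {"parallel": {"dit_cfg_size": 2}},  # diffusion_config
+        ),
+        nprocs=2,
+    )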
diff --git a/tensorrt_llm/_torch/visual_gen/models/__init__.py b/tensorrt_llm/_torch/visual_gen/models/__init__.py
new file mode 100644
index 0000000000..5f726b84ec
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/models/__init__.py
@@ -0,0 +1,30 @@
+"""
+Visual generation model pipelines.
+
+Each model subdirectory contains:
+- pipeline_*.py: Main pipeline implementation inheriting from BasePipeline
+- __init__.py: Exports the pipeline class
+
+TeaCache extractors are registered inline in each pipeline's load() method
+using register_extractor_from_config().
+
+Pipelines are registered in pipeline_registry.py's PipelineRegistry._REGISTRY dict.
+
+Example structure:
+ models/
+ my_model/
+ pipeline_my_model.py # Pipeline class with inline extractor registration
+ __init__.py # Exports: __all__ = ["MyModelPipeline"]
+"""
+
+from ..pipeline import BasePipeline
+from ..pipeline_registry import AutoPipeline, register_pipeline
+from .wan import WanImageToVideoPipeline, WanPipeline
+
+__all__ = [
+ "AutoPipeline",
+ "BasePipeline",
+ "WanPipeline",
+ "WanImageToVideoPipeline",
+ "register_pipeline",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/__init__.py b/tensorrt_llm/_torch/visual_gen/models/wan/__init__.py
new file mode 100644
index 0000000000..f177740809
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/__init__.py
@@ -0,0 +1,5 @@
+from .pipeline_wan import WanPipeline
+from .pipeline_wan_i2v import WanImageToVideoPipeline
+from .transformer_wan import WanTransformer3DModel
+
+__all__ = ["WanPipeline", "WanImageToVideoPipeline", "WanTransformer3DModel"]
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
new file mode 100644
index 0000000000..f5d0f4fce5
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
@@ -0,0 +1,521 @@
+import time
+from typing import Optional
+
+import torch
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from tensorrt_llm._torch.visual_gen.config import PipelineComponent
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
+from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
+from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
+from tensorrt_llm.logger import logger
+
+from .transformer_wan import WanTransformer3DModel
+
+# Supported Wan T2V models:
+# - Wan2.1-T2V-14B: Single-stage text-to-video (14B parameters)
+# - Wan2.1-T2V-1.3B: Single-stage text-to-video (1.3B parameters)
+# - Wan2.2-T2V-A14B: Two-stage text-to-video (14B, boundary_ratio for high/low-noise stages; supports 480P & 720P)
+
+WAN_TEACACHE_COEFFICIENTS = {
+ "1.3B": {
+ "ret_steps": [
+ -5.21862437e04,
+ 9.23041404e03,
+ -5.28275948e02,
+ 1.36987616e01,
+ -4.99875664e-02,
+ ],
+ "standard": [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01],
+ },
+ "14B": {
+ "ret_steps": [
+ -3.03318725e05,
+ 4.90537029e04,
+ -2.65530556e03,
+ 5.87365115e01,
+ -3.15583525e-01,
+ ],
+ "standard": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
+ },
+}
+
+
+# Default negative prompt for Wan T2V models
+WAN_DEFAULT_NEGATIVE_PROMPT = (
+ "Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
+ "still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
+ "incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
+ "fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
+)
+
+
+@register_pipeline("WanPipeline")
+class WanPipeline(BasePipeline):
+ def __init__(self, model_config):
+ # Wan2.2 two-stage denoising parameters
+ self.transformer_2 = None
+ self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
+ self.is_wan22 = self.boundary_ratio is not None
+
+ super().__init__(model_config)
+
+ @staticmethod
+ def _compute_wan_timestep_embedding(module, timestep, guidance=None):
+ """Compute timestep embedding for WAN transformer.
+
+ WAN uses a condition_embedder with timesteps_proj and time_embedder layers.
+ Handles dtype casting to match the embedder's dtype.
+
+ Args:
+ module: WanTransformer3DModel instance
+ timestep: Timestep tensor (shape: [batch_size])
+ guidance: Unused for WAN (no guidance embedding)
+
+ Returns:
+ Timestep embedding tensor used by TeaCache for distance calculation
+ """
+ ce = module.condition_embedder
+ t_freq = ce.timesteps_proj(timestep)
+
+ # Cast to embedder's dtype (avoid int8 quantized layers)
+ te_dtype = next(iter(ce.time_embedder.parameters())).dtype
+ if t_freq.dtype != te_dtype and te_dtype != torch.int8:
+ t_freq = t_freq.to(te_dtype)
+
+ return ce.time_embedder(t_freq)
+
+ @property
+ def dtype(self):
+ return self.model_config.torch_dtype
+
+ @property
+ def device(self):
+ return self.transformer.device
+
+ @property
+ def transformer_components(self) -> list:
+ """Return list of transformer components this pipeline needs."""
+ if self.transformer_2 is not None:
+ return ["transformer", "transformer_2"]
+ return ["transformer"]
+
+ def _init_transformer(self) -> None:
+ logger.info("Creating WAN transformer with quantization support...")
+ self.transformer = WanTransformer3DModel(model_config=self.model_config)
+
+ # Wan2.2: create second transformer for two-stage denoising
+ if self.boundary_ratio is not None:
+ logger.info("Creating second transformer for Wan2.2 two-stage denoising...")
+ self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
+
+ def load_standard_components(
+ self,
+ checkpoint_dir: str,
+ device: torch.device,
+ skip_components: Optional[list] = None,
+ ) -> None:
+ """Load VAE, text encoder, tokenizer, and scheduler from checkpoint."""
+ skip_components = skip_components or []
+
+ if self.transformer_2 is not None and self.boundary_ratio is None:
+ raise RuntimeError(
+ "transformer_2 exists but boundary_ratio is not set. "
+ "This indicates an inconsistent pipeline configuration."
+ )
+
+ # Detect model version
+ if self.is_wan22:
+ logger.info("Detected Wan 2.2 T2V (two-stage denoising)")
+ else:
+ logger.info("Detected Wan 2.1 T2V (single-stage denoising)")
+
+ # Set default VAE scale factors (will be overridden if VAE is loaded)
+ self.vae_scale_factor_temporal = 4
+ self.vae_scale_factor_spatial = 8
+
+ if PipelineComponent.TOKENIZER not in skip_components:
+ logger.info("Loading tokenizer...")
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.TOKENIZER,
+ )
+
+ if PipelineComponent.TEXT_ENCODER not in skip_components:
+ logger.info("Loading text encoder...")
+ self.text_encoder = UMT5EncoderModel.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.TEXT_ENCODER,
+ torch_dtype=self.model_config.torch_dtype,
+ ).to(device)
+
+ if PipelineComponent.VAE not in skip_components:
+ logger.info("Loading VAE...")
+ self.vae = AutoencoderKLWan.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.VAE,
+ torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
+ ).to(device)
+
+ self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
+ self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
+
+ if PipelineComponent.SCHEDULER not in skip_components:
+ logger.info("Loading scheduler...")
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.SCHEDULER,
+ )
+ if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+ self.scheduler.config,
+ shift=5.0,
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def load_weights(self, weights: dict) -> None:
+ # Store weights for later use (in case transformer_2 is created after this call)
+ self._weights_dict = weights
+
+ has_separate_weights = "transformer" in weights and "transformer_2" in weights
+
+ if self.transformer is not None and hasattr(self.transformer, "load_weights"):
+ logger.info("Loading transformer weights...")
+ transformer_weights = weights.get("transformer", weights)
+ self.transformer.load_weights(transformer_weights)
+ logger.info("Transformer weights loaded successfully.")
+
+ # Wan2.2: Load weights for second transformer if it exists
+ if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
+ logger.info("Loading transformer_2 weights for Wan2.2...")
+ if not has_separate_weights:
+ raise ValueError(
+ "Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
+ f"but only found: {list(weights.keys())}. "
+ "Two-stage denoising requires distinct weights for high-noise and low-noise transformers."
+ )
+ transformer_2_weights = weights["transformer_2"]
+ self.transformer_2.load_weights(transformer_2_weights)
+ logger.info("Transformer_2 weights loaded successfully.")
+
+ # Cache the target dtype from model config (default: bfloat16)
+ self._target_dtype = self.model_config.torch_dtype
+
+ # Set model to eval mode
+ if self.transformer is not None:
+ self.transformer.eval()
+ if self.transformer_2 is not None:
+ self.transformer_2.eval()
+
+ def post_load_weights(self) -> None:
+ super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
+ if self.transformer is not None:
+ # Register TeaCache extractor for this model type
+ # Tells TeaCache how to compute timestep embeddings for Wan
+ register_extractor_from_config(
+ ExtractorConfig(
+ model_class_name="WanTransformer3DModel",
+ timestep_embed_fn=self._compute_wan_timestep_embedding,
+ return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
+ )
+ )
+
+ # Enable TeaCache optimization with WAN-specific coefficients
+ self._setup_teacache(self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS)
+ # Save transformer backend before it gets overwritten
+ self.transformer_cache_backend = self.cache_backend
+
+ # Wan2.2: Setup TeaCache for second transformer (low-noise stage)
+ if self.transformer_2 is not None:
+ if hasattr(self.transformer_2, "post_load_weights"):
+ self.transformer_2.post_load_weights()
+
+ # Enable TeaCache for low-noise stage with same coefficients
+ self._setup_teacache(self.transformer_2, coefficients=WAN_TEACACHE_COEFFICIENTS)
+ # Save transformer_2 backend
+ self.transformer_2_cache_backend = self.cache_backend
+
+ def infer(self, req):
+ """Run inference with request parameters."""
+ return self.forward(
+ prompt=req.prompt,
+ negative_prompt=req.negative_prompt,
+ height=req.height,
+ width=req.width,
+ num_frames=req.num_frames,
+ num_inference_steps=req.num_inference_steps,
+ guidance_scale=req.guidance_scale,
+ guidance_scale_2=req.guidance_scale_2,
+ boundary_ratio=req.boundary_ratio,
+ seed=req.seed,
+ max_sequence_length=req.max_sequence_length,
+ )
+
+ @torch.no_grad()
+ def forward(
+ self,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 81,
+ num_inference_steps: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ guidance_scale_2: Optional[float] = None,
+ boundary_ratio: Optional[float] = None,
+ seed: int = 42,
+ max_sequence_length: int = 226,
+ ):
+ pipeline_start = time.time()
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ # Use user-provided boundary_ratio if given, otherwise fall back to model config
+ boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
+
+ # Validate that Wan 2.2 models have boundary_ratio set
+ if self.transformer_2 is not None and boundary_ratio is None:
+ raise ValueError(
+ "Wan 2.2 models require boundary_ratio to be set. "
+ "boundary_ratio was not found in model config. "
+ "Please pass boundary_ratio as a parameter."
+ )
+
+ # Set default negative prompt if not provided
+ if negative_prompt is None:
+ negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
+
+ # Set model-specific defaults based on Wan version
+ logger.info(
+ f"Running {'Wan 2.2' if self.is_wan22 else 'Wan 2.1'} T2V inference"
+ f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
+ )
+
+ if num_inference_steps is None:
+ num_inference_steps = 40 if self.is_wan22 else 50
+
+ if guidance_scale is None:
+ guidance_scale = 4.0 if self.is_wan22 else 5.0
+
+ if self.is_wan22 and guidance_scale_2 is None:
+ guidance_scale_2 = 3.0
+
+ # Validate two-stage denoising configuration
+ if guidance_scale_2 is not None and boundary_ratio is None:
+ logger.warning(
+ "guidance_scale_2 is specified but boundary_ratio is None. "
+ "guidance_scale_2 will be ignored."
+ "Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
+ )
+ guidance_scale_2 = None
+
+ # Encode Prompt
+ logger.info("Encoding prompts...")
+ encode_start = time.time()
+ prompt_embeds, neg_prompt_embeds = self._encode_prompt(
+ prompt, negative_prompt, max_sequence_length
+ )
+ logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
+
+ # Prepare Latents
+ latents = self._prepare_latents(height, width, num_frames, generator)
+ logger.info(f"Latents shape: {latents.shape}")
+
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ # Wan2.2: Calculate boundary timestep for two-stage denoising
+ boundary_timestep = None
+ if boundary_ratio is not None and self.transformer_2 is not None:
+ boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
+ logger.info(
+ f"Wan2.2 two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
+ f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
+ )
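+            # Example: with boundary_ratio=0.9 and num_train_timesteps=1000 (the
+            # usual scheduler setting), steps with t >= 900 run the high-noise
+            # transformer and the remaining steps run transformer_2 (low-noise).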
+
+ # Denoising with two-stage support
+ # Track which model was used in last step (for logging model transitions)
+ last_model_used = [None]
+
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ """Forward function for Wan transformer with two-stage support.
+
+ extra_stream_latents and extra_tensors are unused for WAN (single stream, no additional embeddings).
+ """
+ # Select model based on timestep (if two-stage denoising is enabled)
+ if boundary_timestep is not None and self.transformer_2 is not None:
+ # Extract scalar timestep for comparison
+ current_t = timestep if timestep.dim() == 0 else timestep[0]
+ if current_t >= boundary_timestep:
+ current_model = self.transformer
+ model_name = "transformer (high-noise)"
+ else:
+ current_model = self.transformer_2
+ model_name = "transformer_2 (low-noise)"
+
+ # Log when switching models
+ if last_model_used[0] != model_name:
+ if self.rank == 0:
+ logger.info(f"Switched to {model_name} at timestep {current_t:.1f}")
+ last_model_used[0] = model_name
+ else:
+ current_model = self.transformer
+
+ return current_model(
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
+ latents = self.denoise(
+ latents=latents,
+ scheduler=self.scheduler,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ guidance_scale=guidance_scale,
+ forward_fn=forward_fn,
+ guidance_scale_2=guidance_scale_2,
+ boundary_timestep=boundary_timestep,
+ )
+
+ # Log TeaCache statistics - show stats for each transformer separately
+ if self.rank == 0 and self.model_config.teacache.enable_teacache:
+ logger.info("=" * 80)
+ logger.info("TeaCache Statistics:")
+
+ # Stats for transformer (high-noise)
+ if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
+ stats = self.transformer_cache_backend.get_stats()
+ total_steps = stats.get("total_steps", 0)
+ cache_hits = stats.get("cached_steps", 0)
+ cache_misses = stats.get("compute_steps", 0)
+ hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
+
+ logger.info(" Transformer (High-Noise):")
+ logger.info(f" Total steps: {total_steps}")
+ logger.info(f" Cache hits: {cache_hits}")
+ logger.info(f" Cache misses: {cache_misses}")
+ logger.info(f" Hit rate: {hit_rate:.1f}%")
+
+ # Stats for transformer_2 (low-noise)
+ if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
+ stats = self.transformer_2_cache_backend.get_stats()
+ total_steps = stats.get("total_steps", 0)
+ cache_hits = stats.get("cached_steps", 0)
+ cache_misses = stats.get("compute_steps", 0)
+ hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
+
+ logger.info(" Transformer_2 (Low-Noise):")
+ logger.info(f" Total steps: {total_steps}")
+ logger.info(f" Cache hits: {cache_hits}")
+ logger.info(f" Cache misses: {cache_misses}")
+ logger.info(f" Hit rate: {hit_rate:.1f}%")
+
+ logger.info("=" * 80)
+
+ # Decode
+ logger.info("Decoding video...")
+ decode_start = time.time()
+ video = self.decode_latents(latents, self._decode_latents)
+
+ if self.rank == 0:
+ logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
+ logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
+
+ return MediaOutput(video=video)
+
+ def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ def get_embeds(texts):
+ text_inputs = self.tokenizer(
+ texts,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ input_ids = text_inputs.input_ids.to(self.device)
+ attention_mask = text_inputs.attention_mask.to(self.device)
+
+ embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
+ embeds = embeds.to(self.dtype)
+
+ # Zero-out padded tokens based on mask
+ seq_lens = attention_mask.gt(0).sum(dim=1).long()
+ cleaned_embeds = []
+ for u, v in zip(embeds, seq_lens):
+ real_content = u[:v]
+ pad_len = max_sequence_length - real_content.size(0)
+ if pad_len > 0:
+ padded = torch.cat(
+ [real_content, real_content.new_zeros(pad_len, real_content.size(1))]
+ )
+ else:
+ padded = real_content
+ cleaned_embeds.append(padded)
+
+ return torch.stack(cleaned_embeds, dim=0)
+
+ prompt_embeds = get_embeds(prompt)
+
+ if negative_prompt is None:
+ negative_prompt = ""
+
+ neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ if len(neg_prompt) == 1 and len(prompt) > 1:
+ neg_prompt = neg_prompt * len(prompt)
+
+ neg_embeds = get_embeds(neg_prompt)
+
+ return prompt_embeds, neg_embeds
+
+ def _prepare_latents(self, height, width, num_frames, generator):
+ num_channels_latents = 16
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
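+        # e.g. with the request defaults (height=720, width=1280, num_frames=81)
+        # and the default VAE factors (spatial 8, temporal 4), the latent shape
+        # is (1, 16, 21, 90, 160).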
+
+ shape = (
+ 1,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ return randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
+
+ def _decode_latents(self, latents):
+ latents = latents.to(self.vae.dtype)
+
+ # Denormalization
+ if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
+ if not hasattr(self, "_latents_mean"):
+ self._latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, -1, 1, 1, 1)
+ .to(self.device, self.vae.dtype)
+ )
+ self._latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, -1, 1, 1, 1)
+ .to(self.device, self.vae.dtype)
+ )
+ latents = (latents * self._latents_std) + self._latents_mean
+ else:
+ scaling_factor = self.vae.config.get("scaling_factor", 1.0)
+ latents = latents / scaling_factor
+
+ # VAE decode: returns (B, C, T, H, W)
+ video = self.vae.decode(latents, return_dict=False)[0]
+
+ # Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
+ video = postprocess_video_tensor(video, remove_batch_dim=True)
+
+ return video
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
new file mode 100644
index 0000000000..d2a3fd629f
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
@@ -0,0 +1,736 @@
+import json
+import os
+import time
+from typing import Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from tensorrt_llm._torch.visual_gen.config import PipelineComponent
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
+from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
+from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
+from tensorrt_llm.logger import logger
+
+# Supported Wan I2V 14B models:
+# - Wan2.1-I2V-14B-480P: Single-stage image-to-video
+# - Wan2.1-I2V-14B-720P: Single-stage image-to-video
+# - Wan2.2-I2V-14B: Two-stage image-to-video (no CLIP, boundary_ratio for two-stage denoising)
+# Note: Wan2.2-I2V-5B (expand_timesteps mode) is NOT supported by this pipeline
+# Import shared coefficients from T2V pipeline
+from .pipeline_wan import WAN_TEACACHE_COEFFICIENTS
+from .transformer_wan import WanTransformer3DModel
+
+# Use same coefficients
+WAN_I2V_TEACACHE_COEFFICIENTS = WAN_TEACACHE_COEFFICIENTS
+
+# Default negative prompt for Wan I2V models
+WAN_DEFAULT_NEGATIVE_PROMPT = (
+ "Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
+ "still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
+ "incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
+ "fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
+)
+
+
+def retrieve_latents(
+ encoder_output: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ sample_mode: str = "argmax",
+):
+ """Extract latents from VAE encoder output.
+
+ For I2V, we use argmax mode to get deterministic encoding of the input image.
+ """
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+@register_pipeline("WanImageToVideoPipeline")
+class WanImageToVideoPipeline(BasePipeline):
+ def __init__(self, model_config):
+ # Wan2.2 14B two-stage denoising parameters
+ self.transformer_2 = None
+ self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
+ self.is_wan22 = self.boundary_ratio is not None
+
+ super().__init__(model_config)
+
+ @staticmethod
+ def _compute_wan_timestep_embedding(module, timestep, guidance=None):
+ """Compute timestep embedding for Wan I2V transformer."""
+ ce = module.condition_embedder
+ t_freq = ce.timesteps_proj(timestep)
+
+ # Cast to embedder's dtype (avoid int8 quantized layers)
+ te_dtype = next(iter(ce.time_embedder.parameters())).dtype
+ if t_freq.dtype != te_dtype and te_dtype != torch.int8:
+ t_freq = t_freq.to(te_dtype)
+
+ return ce.time_embedder(t_freq)
+
+ @property
+ def dtype(self):
+ return self.model_config.torch_dtype
+
+ @property
+ def device(self):
+ return self.transformer.device
+
+ @property
+ def transformer_components(self) -> list:
+ if self.transformer_2 is not None:
+ return ["transformer", "transformer_2"]
+ return ["transformer"]
+
+ def _init_transformer(self) -> None:
+ logger.info("Creating WAN I2V transformer with quantization support...")
+ self.transformer = WanTransformer3DModel(model_config=self.model_config)
+
+ # Wan2.2: Optionally create second transformer for two-stage denoising
+ if self.boundary_ratio is not None:
+ logger.info("Creating second transformer for Wan2.2 I2V two-stage denoising...")
+ self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
+
+ def load_standard_components(
+ self,
+ checkpoint_dir: str,
+ device: torch.device,
+ skip_components: Optional[list] = None,
+ ) -> None:
+ """Load VAE, text encoder, tokenizer, scheduler, and I2V-specific components from checkpoint."""
+ skip_components = skip_components or []
+
+ # Load boundary_ratio and transformer_2 info from model_index.json (pipeline-level config)
+ # Wan 2.2 has both transformer_2 and boundary_ratio, Wan 2.1 doesn't
+ model_index_path = os.path.join(checkpoint_dir, "model_index.json")
+ has_transformer_2 = False
+ if os.path.exists(model_index_path):
+ with open(model_index_path) as f:
+ model_index = json.load(f)
+ # Check for boundary_ratio in model_index
+ if "boundary_ratio" in model_index:
+ self.boundary_ratio = model_index["boundary_ratio"]
+ logger.info(f"Found boundary_ratio in model_index.json: {self.boundary_ratio}")
+ else:
+ logger.info("No boundary_ratio found in model_index.json")
+ # Check for transformer_2 component
+ transformer_2_spec = model_index.get("transformer_2", None)
+ has_transformer_2 = (
+ transformer_2_spec is not None and transformer_2_spec[0] is not None
+ )
+ logger.info(f"transformer_2 in model_index.json: {has_transformer_2}")
+
+ # Set default VAE scale factors (will be overridden if VAE is loaded)
+ self.vae_scale_factor_temporal = 4
+ self.vae_scale_factor_spatial = 8
+
+ if PipelineComponent.TOKENIZER not in skip_components:
+ logger.info("Loading tokenizer...")
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.TOKENIZER,
+ )
+
+ if PipelineComponent.TEXT_ENCODER not in skip_components:
+ logger.info("Loading text encoder...")
+ self.text_encoder = UMT5EncoderModel.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.TEXT_ENCODER,
+ torch_dtype=self.model_config.torch_dtype,
+ ).to(device)
+
+ if PipelineComponent.VAE not in skip_components:
+ logger.info("Loading VAE...")
+ self.vae = AutoencoderKLWan.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.VAE,
+ torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
+ ).to(device)
+
+ self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
+ self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
+
+ if PipelineComponent.SCHEDULER not in skip_components:
+ logger.info("Loading scheduler...")
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.SCHEDULER,
+ )
+ if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+ self.scheduler.config,
+ shift=5.0,
+ )
+
+ if self.transformer_2 is not None and self.boundary_ratio is None:
+ raise RuntimeError(
+ "transformer_2 exists but boundary_ratio is not set. "
+ "This indicates an inconsistent pipeline configuration."
+ )
+
+ # Load image encoder and processor (only for Wan 2.1)
+ # Wan 2.2: Has both transformer_2 and boundary_ratio (two-stage denoising)
+ if self.is_wan22:
+ logger.info("Detected Wan 2.2 I2V (two-stage, no CLIP)")
+ else:
+ logger.info("Detected Wan 2.1 I2V (single-stage, uses CLIP)")
+
+ if PipelineComponent.IMAGE_ENCODER not in skip_components and not self.is_wan22:
+ logger.info("Loading CLIP image encoder for I2V conditioning (Wan 2.1 only)...")
+ self.image_encoder = CLIPVisionModel.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.IMAGE_ENCODER,
+ torch_dtype=torch.float32, # Keep CLIP in FP32 for stability
+ ).to(device)
+
+ if PipelineComponent.IMAGE_PROCESSOR not in skip_components and not self.is_wan22:
+ logger.info("Loading CLIP image processor...")
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ checkpoint_dir,
+ subfolder=PipelineComponent.IMAGE_PROCESSOR,
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def load_weights(self, weights: dict) -> None:
+ # Store weights for later use
+ self._weights_dict = weights
+
+ # Check if weights dict has separate transformer/transformer_2 keys (Wan2.2)
+ has_separate_weights = "transformer" in weights and "transformer_2" in weights
+
+ if self.transformer is not None and hasattr(self.transformer, "load_weights"):
+ logger.info("Loading transformer weights...")
+ transformer_weights = weights.get("transformer", weights)
+ self.transformer.load_weights(transformer_weights)
+ logger.info("Transformer weights loaded successfully.")
+
+ # Wan2.2: Load weights for second transformer if it exists
+ if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
+ logger.info("Loading transformer_2 weights for Wan2.2 I2V...")
+ if has_separate_weights:
+ transformer_2_weights = weights["transformer_2"]
+ logger.info("Using separate transformer_2 weights from checkpoint")
+ else:
+ # For Wan 2.2, transformer_2 weights must exist
+ raise ValueError(
+ "Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
+ f"but only found: {list(weights.keys())}"
+ )
+ self.transformer_2.load_weights(transformer_2_weights)
+ logger.info("Transformer_2 weights loaded successfully.")
+
+ # Cache the target dtype from model config (default: bfloat16)
+ self._target_dtype = self.model_config.torch_dtype
+
+ # Set model to eval mode
+ if self.transformer is not None:
+ self.transformer.eval()
+ if self.transformer_2 is not None:
+ self.transformer_2.eval()
+ if hasattr(self, "image_encoder") and self.image_encoder is not None:
+ self.image_encoder.eval()
+
+ def post_load_weights(self) -> None:
+ super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
+ if self.transformer is not None:
+ # Register TeaCache extractor for this model type
+ register_extractor_from_config(
+ ExtractorConfig(
+ model_class_name="WanTransformer3DModel",
+ timestep_embed_fn=self._compute_wan_timestep_embedding,
+ return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
+ )
+ )
+
+ # Enable TeaCache optimization with Wan I2V-specific coefficients
+ self._setup_teacache(self.transformer, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
+ # Save transformer backend before it gets overwritten
+ self.transformer_cache_backend = self.cache_backend
+
+ # Wan2.2: Setup TeaCache for second transformer (low-noise stage)
+ if self.transformer_2 is not None:
+ if hasattr(self.transformer_2, "post_load_weights"):
+ self.transformer_2.post_load_weights()
+
+ # Enable TeaCache for low-noise stage with same coefficients
+ self._setup_teacache(self.transformer_2, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
+ # Save transformer_2 backend
+ self.transformer_2_cache_backend = self.cache_backend
+
+ def infer(self, req):
+ """Run inference with request parameters."""
+ # Extract image from request (can be path, PIL Image, or torch.Tensor)
+ if req.image is None:
+ raise ValueError("I2V pipeline requires 'image' parameter")
+
+ image = req.image[0] if isinstance(req.image, list) else req.image
+ last_image = req.last_image
+
+ if last_image is not None and isinstance(last_image, list):
+ last_image = last_image[0] if last_image else None
+
+ return self.forward(
+ image=image,
+ prompt=req.prompt,
+ negative_prompt=req.negative_prompt,
+ height=req.height,
+ width=req.width,
+ num_frames=req.num_frames,
+ num_inference_steps=req.num_inference_steps,
+ guidance_scale=req.guidance_scale,
+ guidance_scale_2=req.guidance_scale_2,
+ boundary_ratio=req.boundary_ratio,
+ seed=req.seed,
+ max_sequence_length=req.max_sequence_length,
+ last_image=last_image,
+ )
+
+ @torch.no_grad()
+ def forward(
+ self,
+ image: Union[PIL.Image.Image, torch.Tensor, str],
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ guidance_scale_2: Optional[float] = None,
+ boundary_ratio: Optional[float] = None,
+ seed: int = 42,
+ max_sequence_length: int = 512,
+ last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
+ ):
+ pipeline_start = time.time()
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ # Use user-provided boundary_ratio if given, otherwise fall back to model config
+ boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
+
+ if self.transformer_2 is not None and boundary_ratio is None:
+ raise ValueError(
+ "Wan 2.2 models require boundary_ratio to be set. "
+ "boundary_ratio was not found in model config. "
+ "Please pass boundary_ratio as a parameter."
+ )
+
+ # Set default negative prompt if not provided
+ if negative_prompt is None:
+ negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
+
+ # Set model-specific defaults based on Wan version
+ if num_inference_steps is None:
+ num_inference_steps = 40 if self.is_wan22 else 50
+
+ if guidance_scale is None:
+ guidance_scale = 4.0 if self.is_wan22 else 5.0
+
+ if self.is_wan22 and guidance_scale_2 is None:
+ guidance_scale_2 = 3.0 # Wan2.2 recommended default
+
+ # Validate two-stage denoising configuration
+ if guidance_scale_2 is not None and boundary_ratio is None:
+ logger.warning(
+ "guidance_scale_2 is specified but boundary_ratio is None. "
+ "guidance_scale_2 will be ignored."
+ "Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
+ )
+ guidance_scale_2 = None
+
+ # Validate and adjust frame count for VAE compatibility
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` must be divisible by {self.vae_scale_factor_temporal}. "
+ f"Rounding {num_frames} to nearest valid value."
+ )
+ num_frames = (
+ num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ )
+ num_frames = max(num_frames, 1)
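+        # e.g. with vae_scale_factor_temporal=4, num_frames=30 is rounded to 29 (30 // 4 * 4 + 1),
+        # while 33 already satisfies the (num_frames - 1) % 4 == 0 constraint.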
+
+ # Validate and adjust resolution for transformer patchification
+ patch_size = (
+ self.transformer.config.patch_size
+ if self.transformer is not None
+ else self.transformer_2.config.patch_size
+ )
+ h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
+ w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
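+        # e.g. vae_scale_factor_spatial=8 with patch_size (1, 2, 2) makes both multiples 16,
+        # so 480x832 passes unchanged while 540x960 would be adjusted to 528x960.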
+ calc_height = height // h_multiple_of * h_multiple_of
+ calc_width = width // w_multiple_of * w_multiple_of
+ if height != calc_height or width != calc_width:
+ logger.warning(
+ f"Height and width must be multiples of ({h_multiple_of}, {w_multiple_of}) for patchification. "
+ f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
+ )
+ height, width = calc_height, calc_width
+
+ # Encode Prompt
+ logger.info("Encoding prompts...")
+ encode_start = time.time()
+ prompt_embeds, neg_prompt_embeds = self._encode_prompt(
+ prompt, negative_prompt, max_sequence_length
+ )
+ logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
+
+ # Encode Image (I2V-specific)
+ logger.info("Encoding input image...")
+ image_encode_start = time.time()
+
+ # Determine model version
+ model_version = "Wan 2.2" if self.is_wan22 else "Wan 2.1"
+ logger.info(
+ f"Running {model_version} I2V inference "
+ f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
+ )
+
+ if not self.is_wan22:
+ # Wan 2.1 I2V: Compute CLIP image embeddings
+ image_embeds = self._encode_image(image, last_image)
+ image_embeds = image_embeds.to(self.dtype)
+ else:
+ # Wan 2.2 I2V: No image embeddings needed
+ image_embeds = None
+
+ logger.info(f"Image encoding completed in {time.time() - image_encode_start:.2f}s")
+
+ # Prepare Latents with image conditioning (I2V-specific)
+ latents, condition_data = self._prepare_latents(
+ image, height, width, num_frames, generator, last_image
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ # Wan2.2: Calculate boundary timestep for two-stage denoising
+ boundary_timestep = None
+ if boundary_ratio is not None and self.transformer_2 is not None:
+ boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
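+            # e.g. boundary_ratio=0.9 with 1000 training timesteps gives boundary_timestep=900:
+            # steps with t >= 900 run `transformer` (high noise), the rest run `transformer_2`.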
+ logger.info(
+ f"Wan2.2 I2V two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
+ f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
+ )
+
+ # Denoising with two-stage support
+        # Track which model was used in the previous step (to log model transitions)
+ last_model_used = [None]
+
+ def forward_fn(
+ latents_input, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ """Forward function for WAN I2V transformer with two-stage support.
+
+            Both Wan 2.1 and Wan 2.2 14B use the same concatenation approach: [latents, condition].
+ Difference: Wan 2.1 passes image_embeds, Wan 2.2 passes None.
+ """
+ # Select model based on timestep (if two-stage denoising is enabled)
+ if boundary_timestep is not None and self.transformer_2 is not None:
+ # Extract scalar timestep for comparison
+ current_t = timestep if timestep.dim() == 0 else timestep[0]
+ if current_t >= boundary_timestep:
+ current_model = self.transformer
+ model_name = "transformer"
+ else:
+ current_model = self.transformer_2
+ model_name = "transformer_2"
+
+ # Log when switching models
+ if last_model_used[0] != model_name:
+ if self.rank == 0:
+ logger.info(
+ f"[TRTLLM] Switched to {model_name} at timestep {current_t:.1f}"
+ )
+ last_model_used[0] = model_name
+ else:
+ current_model = self.transformer
+
+ # Wan 2.1 & Wan 2.2 14B: concatenate latents and condition
+ # Handle CFG: duplicate condition if batch dimension doubled
+ if latents_input.shape[0] != condition_data.shape[0]:
+ condition_to_use = torch.cat([condition_data] * 2)
+ else:
+ condition_to_use = condition_data
+
+ latent_model_input = torch.cat([latents_input, condition_to_use], dim=1).to(self.dtype)
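+            # Resulting channel dim: 16 noise + 20 conditioning (4 mask + 16 VAE latent) = 36 with the default VAE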
+ timestep_input = timestep.expand(latents_input.shape[0])
+
+ # Forward pass with I2V conditioning
+ # Wan 2.1: image_embeds is not None (CLIP embeddings)
+ # Wan 2.2 14B: image_embeds is None (no CLIP)
+ return current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep_input,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_image=image_embeds,
+ )
+
+ # Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
+ latents = self.denoise(
+ latents=latents,
+ scheduler=self.scheduler,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ guidance_scale=guidance_scale,
+ forward_fn=forward_fn,
+ guidance_scale_2=guidance_scale_2,
+ boundary_timestep=boundary_timestep,
+ )
+
+ # Log TeaCache statistics - show stats for each transformer separately
+ if self.rank == 0 and self.model_config.teacache.enable_teacache:
+ logger.info("=" * 80)
+ logger.info("TeaCache Statistics:")
+
+ # Stats for transformer (high-noise)
+ if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
+ stats = self.transformer_cache_backend.get_stats()
+ total_steps = stats.get("total_steps", 0)
+ cache_hits = stats.get("cached_steps", 0)
+ cache_misses = stats.get("compute_steps", 0)
+ hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
+
+ logger.info(" Transformer (High-Noise):")
+ logger.info(f" Total steps: {total_steps}")
+ logger.info(f" Cache hits: {cache_hits}")
+ logger.info(f" Cache misses: {cache_misses}")
+ logger.info(f" Hit rate: {hit_rate:.1f}%")
+
+ # Stats for transformer_2 (low-noise)
+ if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
+ stats = self.transformer_2_cache_backend.get_stats()
+ total_steps = stats.get("total_steps", 0)
+ cache_hits = stats.get("cached_steps", 0)
+ cache_misses = stats.get("compute_steps", 0)
+ hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
+
+ logger.info(" Transformer_2 (Low-Noise):")
+ logger.info(f" Total steps: {total_steps}")
+ logger.info(f" Cache hits: {cache_hits}")
+ logger.info(f" Cache misses: {cache_misses}")
+ logger.info(f" Hit rate: {hit_rate:.1f}%")
+
+ logger.info("=" * 80)
+
+ # Decode
+ logger.info("Decoding video...")
+ decode_start = time.time()
+ video = self.decode_latents(latents, self._decode_latents)
+
+ if self.rank == 0:
+ logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
+ logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
+
+ return MediaOutput(video=video)
+
+ def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
+ """Encode text prompts to embeddings (same as T2V)."""
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ def get_embeds(texts):
+ text_inputs = self.tokenizer(
+ texts,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ input_ids = text_inputs.input_ids.to(self.device)
+ attention_mask = text_inputs.attention_mask.to(self.device)
+
+ embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
+ embeds = embeds.to(self.dtype)
+
+ # Zero-out padded tokens based on mask
+ seq_lens = attention_mask.gt(0).sum(dim=1).long()
+ cleaned_embeds = []
+ for u, v in zip(embeds, seq_lens):
+ real_content = u[:v]
+ pad_len = max_sequence_length - real_content.size(0)
+ if pad_len > 0:
+ padded = torch.cat(
+ [real_content, real_content.new_zeros(pad_len, real_content.size(1))]
+ )
+ else:
+ padded = real_content
+ cleaned_embeds.append(padded)
+
+ return torch.stack(cleaned_embeds, dim=0)
+
+ prompt_embeds = get_embeds(prompt)
+
+ if negative_prompt is None:
+ negative_prompt = ""
+
+ neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ if len(neg_prompt) == 1 and len(prompt) > 1:
+ neg_prompt = neg_prompt * len(prompt)
+
+ neg_embeds = get_embeds(neg_prompt)
+
+ return prompt_embeds, neg_embeds
+
+ def _encode_image(
+ self,
+ image: Union[PIL.Image.Image, torch.Tensor, str],
+ last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
+ ) -> torch.Tensor:
+ """Encode image(s) using CLIP image encoder (Wan 2.1 I2V only)."""
+ if isinstance(image, str):
+ image = PIL.Image.open(image).convert("RGB")
+ if isinstance(last_image, str):
+ last_image = PIL.Image.open(last_image).convert("RGB")
+
+ images_to_encode = [image] if last_image is None else [image, last_image]
+
+ image_inputs = self.image_processor(images=images_to_encode, return_tensors="pt").to(
+ self.device
+ )
+ image_embeds = self.image_encoder(**image_inputs, output_hidden_states=True)
+
+ return image_embeds.hidden_states[-2]
+
+ def _prepare_latents(
+ self,
+ image: Union[PIL.Image.Image, torch.Tensor, str],
+ height: int,
+ width: int,
+ num_frames: int,
+ generator: torch.Generator,
+ last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Prepare latents with image conditioning for I2V generation."""
+ num_channels_latents = 16
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
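+        # e.g. a 480x832x33 request yields latents of shape (1, 16, 9, 60, 104) with the default scale factors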
+
+ # Create random noise latents
+ shape = (1, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
+
+ # Load and preprocess image(s)
+ if isinstance(image, str):
+ image = PIL.Image.open(image).convert("RGB")
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
+ self.device, dtype=torch.float32
+ )
+
+ if last_image is not None:
+ if isinstance(last_image, str):
+ last_image = PIL.Image.open(last_image).convert("RGB")
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ self.device, dtype=torch.float32
+ )
+
+ image = image.unsqueeze(2)
+
+ # Create video conditioning tensor (same for both Wan 2.1 and Wan 2.2 14B)
+ if last_image is None:
+ # First frame + zeros
+ video_condition = torch.cat(
+ [
+ image,
+ image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width),
+ ],
+ dim=2,
+ )
+ else:
+ # First frame + zeros + last frame (interpolation)
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat(
+ [
+ image,
+ image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width),
+ last_image,
+ ],
+ dim=2,
+ )
+
+ # Encode video condition through VAE
+ video_condition = video_condition.to(device=self.device, dtype=self.vae.dtype)
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.to(self.dtype)
+
+ # Normalize latents to match diffusion model's latent space
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ # Create mask in video frame space
+ # Reshaping is required to match the transformer's expected input format
+ mask_lat_size = torch.ones(1, 1, num_frames, latent_height, latent_width)
+
+ if last_image is None:
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ else:
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(
+ first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal
+ )
+
+ mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+
+ mask_lat_size = mask_lat_size.view(
+ 1, -1, self.vae_scale_factor_temporal, latent_height, latent_width
+ )
+
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+
+ mask_lat_size = mask_lat_size.to(self.device, dtype=self.dtype)
+
+ # Concatenate mask and condition along channel dimension
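+        # (with the default temporal scale factor of 4 and z_dim=16 this gives 4 + 16 = 20 conditioning channels)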
+ condition = torch.cat([mask_lat_size, latent_condition], dim=1)
+ return latents, condition
+
+ def _decode_latents(self, latents):
+ """Decode latents to video (same as T2V)."""
+ latents = latents.to(self.vae.dtype)
+
+ # Denormalization
+ if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
+ if not hasattr(self, "_latents_mean"):
+ self._latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, -1, 1, 1, 1)
+ .to(self.device, self.vae.dtype)
+ )
+ self._latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, -1, 1, 1, 1)
+ .to(self.device, self.vae.dtype)
+ )
+ latents = (latents * self._latents_std) + self._latents_mean
+ else:
+ scaling_factor = self.vae.config.get("scaling_factor", 1.0)
+ latents = latents / scaling_factor
+
+ # VAE decode: returns (B, C, T, H, W)
+ video = self.vae.decode(latents, return_dict=False)[0]
+
+ # Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
+ video = postprocess_video_tensor(video, remove_batch_dim=True)
+
+ return video
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
new file mode 100644
index 0000000000..bb19a78541
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
@@ -0,0 +1,756 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
+from tqdm import tqdm
+from transformers.modeling_utils import get_parameter_device
+
+from tensorrt_llm._torch.modules.layer_norm import LayerNorm
+from tensorrt_llm._torch.modules.linear import Linear
+from tensorrt_llm._torch.modules.mlp import MLP
+from tensorrt_llm._torch.modules.rms_norm import RMSNorm
+from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
+from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
+from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism
+from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader
+from tensorrt_llm.logger import logger
+from tensorrt_llm.models.modeling_utils import QuantConfig
+
+# =========================================================================
+# 1. Rotary Positional Embeddings
+# =========================================================================
+
+
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ # Split logic matches Hugging Face exactly
+ self.h_dim = 2 * (attention_head_dim // 6)
+ self.w_dim = 2 * (attention_head_dim // 6)
+ self.t_dim = attention_head_dim - self.h_dim - self.w_dim
+
+ freqs_cos, freqs_sin = [], []
+
+ # Order: Time, Height, Width
+ for dim in [self.t_dim, self.h_dim, self.w_dim]:
+ # High precision generation
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
+ t = torch.arange(max_seq_len, dtype=torch.float64)
+ freqs = torch.outer(t, freqs)
+
+ # Interleaved Pattern [c0, c0, c1, c1]
+ freqs_cos.append(freqs.cos().repeat_interleave(2, dim=-1).float())
+ freqs_sin.append(freqs.sin().repeat_interleave(2, dim=-1).float())
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Robust shape unpacking
+ b, c, f, h, w = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = f // p_t, h // p_h, w // p_w
+
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ # Broadcast frequencies to 3D grid: [Time, Height, Width]
+ f_cos_t = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ f_sin_t = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+
+ f_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ f_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+
+ f_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ f_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ # Concatenate and flatten for Attention [1, SeqLen, 1, Dim] (SHD format)
+ # New Attention module applies RoPE in [B, S, H, D] layout before reshaping to [B, H, S, D]
+ return (
+ torch.cat([f_cos_t, f_cos_h, f_cos_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
+ torch.cat([f_sin_t, f_sin_h, f_sin_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
+ )
+
+
+# =========================================================================
+# 2. Embeddings & Attention
+# =========================================================================
+
+
+class WanImageEmbedding(nn.Module):
+ """Image embedding for I2V models (Wan 2.1/2.2)."""
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ pos_embed_seq_len: int = None,
+ model_config: DiffusionModelConfig = None,
+ ):
+ super().__init__()
+ dtype = model_config.torch_dtype if model_config else None
+ # LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
+ self.norm1 = LayerNorm(
+ hidden_size=in_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
+ )
+
+        # Match HF FeedForward structure: Linear(in, in) -> GELU -> Linear(in, out)
+ self.ff_in = Linear(
+ in_features,
+ in_features,
+ bias=True,
+ dtype=dtype,
+ mapping=model_config.mapping if model_config else None,
+ quant_config=model_config.quant_config if model_config else None,
+ skip_create_weights_in_init=model_config.skip_create_weights_in_init
+ if model_config
+ else False,
+ force_dynamic_quantization=model_config.force_dynamic_quantization
+ if model_config
+ else False,
+ )
+ self.ff_out = Linear(
+ in_features,
+ out_features,
+ bias=True,
+ dtype=dtype,
+ mapping=model_config.mapping if model_config else None,
+ quant_config=model_config.quant_config if model_config else None,
+ skip_create_weights_in_init=model_config.skip_create_weights_in_init
+ if model_config
+ else False,
+ force_dynamic_quantization=model_config.force_dynamic_quantization
+ if model_config
+ else False,
+ )
+
+ self.norm2 = LayerNorm(
+ hidden_size=out_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
+ )
+
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(
+ -1, 2 * seq_len, embed_dim
+ )
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff_in(hidden_states)
+ hidden_states = F.gelu(hidden_states)
+ hidden_states = self.ff_out(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ time_freq_dim,
+ time_proj_dim,
+ text_embed_dim,
+ model_config: DiffusionModelConfig,
+ image_embed_dim: int = None,
+ pos_embed_seq_len: int = None,
+ ):
+ super().__init__()
+ dtype = model_config.torch_dtype
+ quant_config = model_config.quant_config
+ skip_create_weights = model_config.skip_create_weights_in_init
+ force_dynamic_quant = model_config.force_dynamic_quantization
+
+ self.timesteps_proj = Timesteps(
+ num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
+ )
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+
+ self.time_proj = Linear(
+ dim,
+ time_proj_dim,
+ dtype=dtype,
+ mapping=model_config.mapping,
+ quant_config=quant_config,
+ skip_create_weights_in_init=skip_create_weights,
+ force_dynamic_quantization=force_dynamic_quant,
+ )
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(
+ image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len, model_config=model_config
+ )
+
+ def forward(self, timestep, encoder_hidden_states, encoder_hidden_states_image=None):
+ timestep = self.timesteps_proj(timestep)
+
+ # Get time_embedder dtype
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ ]:
+ timestep = timestep.to(time_embedder_dtype)
+
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+
+ temb_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+
+ if encoder_hidden_states_image is not None and self.image_embedder is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class WanBlock(nn.Module):
+ def __init__(
+ self,
+ model_config: DiffusionModelConfig,
+ _layer_idx: int,
+ added_kv_proj_dim: int = None,
+ ):
+ super().__init__()
+ config = model_config.pretrained_config
+
+ if hasattr(config, "hidden_size"):
+ hidden_size = config.hidden_size
+ elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
+ hidden_size = config.attention_head_dim * config.num_attention_heads
+ else:
+ hidden_size = 1536
+
+ # Wan 2.1 1.3B defaults
+ num_heads = getattr(config, "num_attention_heads", 12)
+ head_dim = getattr(config, "attention_head_dim", 128)
+ ffn_dim = getattr(config, "ffn_dim", 8960)
+ eps = getattr(config, "eps", 1e-6)
+
+ dtype = model_config.torch_dtype
+ quant_config = model_config.quant_config
+ skip_create_weights = model_config.skip_create_weights_in_init
+ force_dynamic_quant = model_config.force_dynamic_quantization
+
+ # Store for I2V reshaping logic
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+
+ # LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
+ self.norm1 = LayerNorm(
+ hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
+ )
+
+ # Self-attention with fused QKV
+ self.attn1 = Attention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ head_dim=head_dim,
+ qkv_mode=QKVMode.FUSE_QKV,
+ qk_norm=True,
+ eps=eps,
+ config=model_config,
+ layer_idx=_layer_idx,
+ )
+
+ # Cross-attention with separate Q, K, V
+ self.attn2 = Attention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ head_dim=head_dim,
+ qkv_mode=QKVMode.SEPARATE_QKV,
+ qk_norm=True,
+ eps=eps,
+ config=model_config,
+ layer_idx=_layer_idx,
+ )
+
+ self.norm2 = LayerNorm(
+ hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=True, has_bias=True
+ )
+ self.norm3 = LayerNorm(
+ hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
+ )
+
+ self.ffn = MLP(
+ hidden_size=hidden_size,
+ intermediate_size=ffn_dim,
+ bias=True,
+ activation=lambda x: F.gelu(x, approximate="tanh"),
+ dtype=dtype,
+ config=model_config,
+ layer_idx=_layer_idx,
+ reduce_output=False,
+ )
+
+ # I2V: Additional K/V projections for image embeddings
+ self.add_k_proj = self.add_v_proj = None
+ self.norm_added_k = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = Linear(
+ added_kv_proj_dim,
+ hidden_size,
+ dtype=dtype,
+ mapping=model_config.mapping,
+ quant_config=quant_config,
+ skip_create_weights_in_init=skip_create_weights,
+ force_dynamic_quantization=force_dynamic_quant,
+ )
+ self.add_v_proj = Linear(
+ added_kv_proj_dim,
+ hidden_size,
+ dtype=dtype,
+ mapping=model_config.mapping,
+ quant_config=quant_config,
+ skip_create_weights_in_init=skip_create_weights,
+ force_dynamic_quantization=force_dynamic_quant,
+ )
+ self.norm_added_k = RMSNorm(
+ hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True
+ )
+
+ # Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
+ self.scale_shift_table = nn.Parameter(
+ torch.empty(1, 6, hidden_size).normal_(std=hidden_size**-0.5)
+ )
+
+ def forward(
+ self,
+ x,
+ encoder_hidden_states,
+ temb,
+ freqs_cos,
+ freqs_sin,
+ ):
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.float() + temb.float()
+ ).chunk(6, dim=1)
+
+ normed = self.norm1(x.float()) * (1 + scale_msa) + shift_msa
+ normed = normed.to(x.dtype)
+
+ # Prepare frequencies for Attention
+ freqs = (freqs_cos, freqs_sin) if freqs_cos is not None and freqs_sin is not None else None
+
+ # Self-attention with RoPE
+ x = (
+ x.float()
+ + self.attn1(
+ normed,
+ freqs=freqs,
+ ).float()
+ * gate_msa
+ ).to(x.dtype)
+
+ norm_x = self.norm2(x.float()).to(x.dtype)
+
+ # I2V: Split encoder_hidden_states into image and text parts if needed
+ encoder_hidden_states_img = None
+ encoder_hidden_states_text = encoder_hidden_states
+ if self.add_k_proj is not None:
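+            # Text embeddings are padded to a fixed 512 tokens upstream; anything beyond that length
+            # is the image context prepended by WanTransformer3DModel.forward.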
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states_text = encoder_hidden_states[:, image_context_length:]
+
+ # Text cross-attention
+ attn2_output = self.attn2(norm_x, encoder_hidden_states=encoder_hidden_states_text)
+
+ # I2V: Additional image cross-attention if image embeddings are present
+ if encoder_hidden_states_img is not None:
+ batch_size, seq_len = norm_x.shape[:2]
+
+ query = self.attn2.get_qkv(norm_x, None)[0] # Q only
+ query, _ = self.attn2.apply_qk_norm(query, query)
+
+ key_img = self.add_k_proj(encoder_hidden_states_img)
+ value_img = self.add_v_proj(encoder_hidden_states_img)
+ key_img = self.norm_added_k(key_img)
+
+ query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
+ key_img = key_img.view(
+ batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
+ )
+ value_img = value_img.view(
+ batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
+ )
+
+ attn_img_output = self.attn2._attn_impl(
+ query,
+ key_img,
+ value_img,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ kv_seq_len=encoder_hidden_states_img.shape[1],
+ )
+
+ attn2_output = attn2_output + attn_img_output
+
+ x = x + attn2_output
+
+        # Feed-forward
+ normed = self.norm3(x.float()) * (1 + c_scale_msa) + c_shift_msa
+ normed = normed.to(x.dtype)
+
+ x = (x.float() + self.ffn(normed).float() * c_gate_msa).to(x.dtype)
+
+ return x
+
+
+class WanTransformer3DModel(nn.Module):
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ model_config: DiffusionModelConfig,
+ ):
+ super().__init__()
+
+ self.model_config = model_config
+
+ # Validate no tensor parallelism
+ if model_config.parallel.dit_tp_size > 1:
+ raise ValueError(
+ f"WAN does not support tensor parallelism. "
+ f"Got dit_tp_size={model_config.parallel.dit_tp_size}"
+ )
+
+ # Setup sequence parallelism (Ulysses)
+ num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 12)
+ self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = (
+ setup_sequence_parallelism(
+ model_config=model_config,
+ num_attention_heads=num_heads,
+ )
+ )
+
+ config = model_config.pretrained_config
+
+ dtype = model_config.torch_dtype
+ quant_config = model_config.quant_config
+ skip_create_weights = model_config.skip_create_weights_in_init
+ force_dynamic_quant = model_config.force_dynamic_quantization
+
+ if hasattr(config, "hidden_size"):
+ hidden_size = config.hidden_size
+ elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
+ hidden_size = config.attention_head_dim * config.num_attention_heads
+ else:
+ hidden_size = 1536 # Wan 1.3B default
+
+ num_layers = getattr(config, "num_layers", 30)
+ attention_head_dim = getattr(config, "attention_head_dim", 128)
+ in_channels = getattr(config, "in_channels", 16)
+ out_channels = getattr(config, "out_channels", 16)
+ text_dim = getattr(config, "text_dim", 4096)
+ freq_dim = getattr(config, "freq_dim", 256)
+ patch_size = getattr(config, "patch_size", [1, 2, 2])
+ image_embed_dim = getattr(config, "image_dim", None) # e.g., 1280 for I2V
+ added_kv_proj_dim = getattr(config, "added_kv_proj_dim", None)
+ pos_embed_seq_len = getattr(config, "pos_embed_seq_len", None)
+
+ # Calculate FFN Dim
+ ffn_dim = getattr(config, "ffn_dim", None)
+ if ffn_dim is None:
+ ffn_dim = (
+ 13824
+ if hidden_size == 5120
+ else (8960 if hidden_size == 1536 else int(hidden_size * 4))
+ )
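+            # (the fallbacks above: 13824 matches the 14B checkpoints with hidden 5120, 8960 the 1.3B with hidden 1536)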
+
+ # Store config for unpatchify and pipeline compatibility
+ self.config = type(
+ "Config",
+ (),
+ {
+ "patch_size": patch_size,
+ "hidden_size": hidden_size,
+ "image_dim": image_embed_dim,
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "num_layers": num_layers,
+ },
+ )()
+
+ self.patch_embedding = nn.Conv3d(
+ in_channels,
+ hidden_size,
+ kernel_size=patch_size,
+ stride=patch_size,
+ dtype=dtype, # use model's target dtype (bf16)
+ )
+
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=hidden_size,
+ time_freq_dim=freq_dim,
+ time_proj_dim=hidden_size * 6,
+ text_embed_dim=text_dim,
+ model_config=model_config,
+ image_embed_dim=image_embed_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ self.blocks = nn.ModuleList(
+ [
+ WanBlock(
+ model_config=model_config,
+ _layer_idx=i,
+ added_kv_proj_dim=added_kv_proj_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len=1024)
+
+ # LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
+ self.norm_out = LayerNorm(
+ hidden_size=hidden_size,
+ eps=1e-6,
+ dtype=torch.float32,
+ has_weights=False,
+ has_bias=False,
+ )
+
+ self.proj_out = Linear(
+ hidden_size,
+ out_channels * math.prod(patch_size),
+ dtype=dtype,
+ mapping=model_config.mapping,
+ quant_config=quant_config,
+ skip_create_weights_in_init=skip_create_weights,
+ force_dynamic_quantization=force_dynamic_quant,
+ )
+ # Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
+ self.scale_shift_table = nn.Parameter(
+ torch.empty(1, 2, hidden_size).normal_(std=hidden_size**-0.5)
+ )
+
+ self.__post_init__()
+
+ @property
+ def device(self):
+ return get_parameter_device(self)
+
+ def __post_init__(self):
+ self.apply_quant_config_exclude_modules()
+
+ for _, module in self.named_modules():
+ if callable(getattr(module, "create_weights", None)):
+ module.create_weights()
+
+ def apply_quant_config_exclude_modules(self):
+ quant_config = self.model_config.quant_config
+ if quant_config is None or quant_config.exclude_modules is None:
+ return
+
+ kv_cache_quant_algo = quant_config.kv_cache_quant_algo if quant_config else None
+ no_quant_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
+
+ for name, module in self.named_modules():
+ if isinstance(module, Linear):
+ is_excluded = quant_config.is_module_excluded_from_quantization(name)
+ if is_excluded and getattr(module, "quant_config", None) is not None:
+ module.quant_config = no_quant_config
+
+ def unpatchify(self, x, original_shape):
+ N, C, T, H, W = original_shape
+ pt, ph, pw = self.config.patch_size
+ gt, gh, gw = T // pt, H // ph, W // pw
+ # Use output channels instead of input channels for unpatchifying
+ out_channels = self.proj_out.out_features // (pt * ph * pw)
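+        # [N, gt*gh*gw, out_channels*pt*ph*pw] -> [N, out_channels, T, H, W]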
+ return (
+ x.view(N, gt, gh, gw, pt, ph, pw, out_channels)
+ .permute(0, 7, 1, 4, 2, 5, 3, 6)
+ .reshape(N, out_channels, T, H, W)
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ encoder_hidden_states_image=None,
+ **kwargs,
+ ):
+ """
+ Forward pass with optional Ulysses sequence parallelism.
+
+ With Ulysses enabled (ulysses_size > 1):
+ 1. Shard input sequence across ranks: [B, S] -> [B, S/P]
+ 2. Each block's attention does internal all-to-all for full sequence
+ 3. Gather output sequence: [B, S/P] -> [B, S]
+
+ When TeaCache is enabled, TeaCacheHook intercepts and replaces this call.
+ """
+ original_shape = hidden_states.shape
+ B, C, T, H, W = original_shape
+ pt, ph, pw = self.config.patch_size
+
+ # Generate WAN RoPE frequencies
+ freqs_cos, freqs_sin = self.rope(hidden_states)
+
+ # Patchify and flatten: [B, C, T, H, W] -> [B, S, hidden_size]
+ x = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2)
+
+ # Shard sequence for Ulysses parallelism: [B, S] -> [B, S/P]
+ if self.use_ulysses:
+ seq_len = x.shape[1]
+ if seq_len % self.ulysses_size != 0:
+ raise ValueError(
+ f"Sequence length ({seq_len}) is not divisible by ulysses_size ({self.ulysses_size}). "
+ f"Adjust video dimensions or use a different ulysses_size."
+ )
+
+ chunk_size = seq_len // self.ulysses_size
+ x = x[:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size, :]
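+            # e.g. a 480x832x33 request produces 14040 tokens; with ulysses_size=2 each rank keeps 7020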
+
+ # Shard RoPE frequencies to match sequence sharding
+ # RoPE freqs shape: [B, S, ...], so shard along dim 1 (sequence dimension)
+ if freqs_cos is not None and freqs_sin is not None:
+ freqs_cos = freqs_cos[
+ :, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
+ ]
+ freqs_sin = freqs_sin[
+ :, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
+ ]
+
+ # Time and text/image embeddings
+ temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image = (
+ self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
+ )
+ temb_proj = temb_proj.view(-1, 6, self.config.hidden_size)
+
+ # I2V: Concatenate image and text embeddings if image embeddings are provided
+ if encoder_hidden_states_image is not None:
+ # Handle CFG: duplicate image embeddings if batch dimension is doubled
+ if encoder_hidden_states_image.shape[0] != encoder_hidden_states.shape[0]:
+ batch_multiplier = (
+ encoder_hidden_states.shape[0] // encoder_hidden_states_image.shape[0]
+ )
+ encoder_hidden_states_image = encoder_hidden_states_image.repeat(
+ batch_multiplier, 1, 1
+ )
+ encoder_hidden_states = torch.cat(
+ [encoder_hidden_states_image, encoder_hidden_states], dim=1
+ )
+
+ # Transformer blocks (attention handles all-to-all internally for Ulysses)
+ for block in self.blocks:
+ x = block(
+ x,
+ encoder_hidden_states,
+ temb_proj,
+ freqs_cos,
+ freqs_sin,
+ )
+
+ # Gather sequence from all ranks: [B, S/P] -> [B, S]
+ if self.use_ulysses:
+ # Ensure tensor is contiguous before all_gather
+ x = x.contiguous()
+ x_list = [torch.zeros_like(x) for _ in range(self.ulysses_size)]
+ torch.distributed.all_gather(x_list, x, group=self.ulysses_pg)
+ x = torch.cat(x_list, dim=1)
+
+ # Output projection and unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ x = self.norm_out(x) * (1 + scale) + shift
+ x = x.to(hidden_states.dtype)
+
+ return self.unpatchify(self.proj_out(x), original_shape)
+
+ def load_weights(self, weights: dict) -> None:
+ # Remap checkpoint keys to match model structure
+ remapped_weights = {}
+ for key, value in weights.items():
+ # Remap transformer block FFN keys
+ if ".ffn.net.0.proj." in key:
+ new_key = key.replace(".ffn.net.0.proj.", ".ffn.up_proj.")
+ remapped_weights[new_key] = value
+ elif ".ffn.net.2." in key:
+ new_key = key.replace(".ffn.net.2.", ".ffn.down_proj.")
+ remapped_weights[new_key] = value
+ # Remap image embedder FF keys
+ elif ".image_embedder.ff.net.0.proj." in key:
+ new_key = key.replace(".image_embedder.ff.net.0.proj.", ".image_embedder.ff_in.")
+ remapped_weights[new_key] = value
+ elif ".image_embedder.ff.net.2." in key:
+ new_key = key.replace(".image_embedder.ff.net.2.", ".image_embedder.ff_out.")
+ remapped_weights[new_key] = value
+ # Remap I2V attention keys
+ elif ".attn2.add_k_proj." in key:
+ new_key = key.replace(".attn2.add_k_proj.", ".add_k_proj.")
+ remapped_weights[new_key] = value
+ elif ".attn2.add_v_proj." in key:
+ new_key = key.replace(".attn2.add_v_proj.", ".add_v_proj.")
+ remapped_weights[new_key] = value
+ elif ".attn2.norm_added_k." in key:
+ new_key = key.replace(".attn2.norm_added_k.", ".norm_added_k.")
+ remapped_weights[new_key] = value
+ else:
+ remapped_weights[key] = value
+
+ weights = remapped_weights
+
+ # Handle root-level parameters (filter_weights doesn't work for empty prefix)
+ for param_name, param in self._parameters.items():
+ if param is not None and param_name in weights:
+ param.data.copy_(weights[param_name].to(self.model_config.torch_dtype))
+
+ params_map = {
+ "qkv_proj": ["to_q", "to_k", "to_v"],
+ }
+ loader = DynamicLinearWeightLoader(self.model_config, params_map=params_map)
+
+ for name, module in tqdm(self.named_modules(), desc="Loading weights"):
+ if len(module._parameters) == 0:
+ continue
+
+ if isinstance(module, Linear):
+ weight_dicts = loader.get_linear_weights(module, name, weights)
+
+ if weight_dicts:
+ loader.load_linear_weights(module, name, weight_dicts)
+ elif "add_k_proj" in name or "add_v_proj" in name:
+ logger.info(f"[Weight Loading] No weights found for I2V module: {name}")
+ else:
+ module_weights = loader.filter_weights(name, weights)
+ for param_name, param in module._parameters.items():
+ if param is not None and param_name in module_weights:
+ param.data.copy_(
+ module_weights[param_name].to(self.model_config.torch_dtype)
+ )
+
+ def post_load_weights(self) -> None:
+ """Call post_load_weights on all Linear modules and convert embedders to target dtype."""
+ # Convert condition_embedder components to target dtype
+ target_dtype = self.model_config.torch_dtype
+ if hasattr(self.condition_embedder, "time_embedder"):
+ self.condition_embedder.time_embedder.to(target_dtype)
+ if hasattr(self.condition_embedder, "text_embedder"):
+ self.condition_embedder.text_embedder.to(target_dtype)
+
+ # Call post_load_weights on all Linear modules
+ for _, module in self.named_modules():
+ if isinstance(module, Linear):
+ module.post_load_weights()
diff --git a/tensorrt_llm/_torch/visual_gen/modules/__init__.py b/tensorrt_llm/_torch/visual_gen/modules/__init__.py
new file mode 100644
index 0000000000..93636ddb53
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/modules/__init__.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Visual Generation Modules
+
+This module provides modular neural network components for visual generation models.
+"""
+
+from .attention import Attention, QKVMode
+
+__all__ = [
+ "Attention",
+ "QKVMode",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py
new file mode 100644
index 0000000000..0c83bf5e28
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py
@@ -0,0 +1,284 @@
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ...modules.linear import Linear, WeightMode, WeightsLoadingConfig
+from ...modules.rms_norm import RMSNorm
+from ..attention_backend.interface import AttentionTensorLayout
+from ..attention_backend.utils import create_attention
+
+if TYPE_CHECKING:
+ from ..config import DiffusionModelConfig
+
+
+class QKVMode(str, Enum):
+ FUSE_QKV = "fuse_qkv"
+ FUSE_KV = "fuse_kv"
+ SEPARATE_QKV = "separate"
+
+
+# TODO: torch compile
+def apply_rotary_emb(
+ x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+ freqs_cos = freqs_cos.to(x.dtype)
+ freqs_sin = freqs_sin.to(x.dtype)
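+    # Treat the last dim as interleaved (D/2, 2) pairs and rotate each pair by its frequency:
+    # [x1, x2] -> [x1*cos - x2*sin, x1*sin + x2*cos]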
+ x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) # [B, S, H, D/2]
+
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+
+ return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
+
+
+class Attention(nn.Module):
+ """Attention module for visual generation models."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: Optional[int] = None,
+ head_dim: Optional[int] = None,
+ qkv_mode: QKVMode = QKVMode.FUSE_QKV,
+ qk_norm: bool = True,
+        eps: float = 1e-6,  # TODO: remove this; it should come from the config
+ config: Optional["DiffusionModelConfig"] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+
+ config = config or DiffusionModelConfig()
+ self.dtype = config.torch_dtype
+ self.quant_config = config.quant_config
+ self.skip_create_weights_in_init = config.skip_create_weights_in_init
+ self.force_dynamic_quantization = config.force_dynamic_quantization
+ self.mapping = getattr(config, "mapping", None)
+
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = head_dim or (hidden_size // num_attention_heads)
+ self.qkv_mode = QKVMode(qkv_mode) if isinstance(qkv_mode, str) else qkv_mode
+
+ # Select compute backend (orthogonal to parallelism)
+ ulysses_size = config.parallel.dit_ulysses_size
+ base_backend = config.attention.backend
+
+ if self.qkv_mode == QKVMode.SEPARATE_QKV:
+ backend_name = "VANILLA" # Cross-attention requires VANILLA
+ else:
+ backend_name = base_backend
+ self.attn_backend = backend_name
+ self.qk_norm = qk_norm
+ self.layer_idx = layer_idx if layer_idx is not None else 0
+ self.eps = eps
+
+ self.q_dim = self.num_attention_heads * self.head_dim
+ self.kv_dim = self.num_key_value_heads * self.head_dim
+
+ self._init_qkv_proj()
+
+ if self.qk_norm:
+ self.norm_q = RMSNorm(
+ hidden_size=self.q_dim, eps=self.eps, dtype=self.dtype, has_weights=True
+ )
+ self.norm_k = RMSNorm(
+ hidden_size=self.kv_dim, eps=self.eps, dtype=self.dtype, has_weights=True
+ )
+
+ # TODO: Use weight mapper to create just a Linear module
+ self.to_out = nn.ModuleList(
+ [
+ Linear(
+ self.q_dim,
+ self.hidden_size,
+ dtype=self.dtype,
+ mapping=self.mapping,
+ quant_config=self.quant_config,
+ skip_create_weights_in_init=self.skip_create_weights_in_init,
+ force_dynamic_quantization=self.force_dynamic_quantization,
+ )
+ ]
+ )
+
+ # Compute head counts for the backend
+ # Ulysses shards heads across workers; inner backend sees sharded count
+ if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
+ backend_num_heads = self.num_attention_heads // ulysses_size
+ backend_num_kv_heads = self.num_key_value_heads // ulysses_size
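+            # e.g. 12 attention heads with ulysses_size=2 leave 6 heads per rank for the inner backend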
+ else:
+ backend_num_heads = self.num_attention_heads
+ backend_num_kv_heads = self.num_key_value_heads
+
+ # Create compute backend
+ self.attn = create_attention(
+ backend=backend_name,
+ layer_idx=self.layer_idx,
+ num_heads=backend_num_heads,
+ head_dim=self.head_dim,
+ num_kv_heads=backend_num_kv_heads,
+ quant_config=self.quant_config,
+ dtype=self.dtype,
+ )
+
+ # Wrap with parallelism strategy (orthogonal to backend choice)
+ if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
+ from ..attention_backend.parallel import UlyssesAttention
+
+ process_group = getattr(config, "ulysses_process_group", None)
+ self.attn = UlyssesAttention(
+ inner_backend=self.attn,
+ process_group=process_group,
+ )
+
+ def _init_qkv_proj(self) -> None:
+ if self.qkv_mode == QKVMode.FUSE_QKV:
+ qkv_out_dim = self.q_dim + 2 * self.kv_dim
+ self.qkv_proj = Linear(
+ self.hidden_size,
+ qkv_out_dim,
+ dtype=self.dtype,
+ mapping=self.mapping,
+ quant_config=self.quant_config,
+ skip_create_weights_in_init=self.skip_create_weights_in_init,
+ force_dynamic_quantization=self.force_dynamic_quantization,
+ weights_loading_config=WeightsLoadingConfig(
+ weight_mode=WeightMode.FUSED_QKV_LINEAR
+ ),
+ fused_weight_shard_indices_mapping={
+ "q": (0, self.q_dim),
+ "k": (self.q_dim, self.kv_dim),
+ "v": (self.q_dim + self.kv_dim, self.kv_dim),
+ },
+ )
+ else:
+ self.to_q = Linear(
+ self.hidden_size,
+ self.q_dim,
+ dtype=self.dtype,
+ mapping=self.mapping,
+ quant_config=self.quant_config,
+ skip_create_weights_in_init=self.skip_create_weights_in_init,
+ force_dynamic_quantization=self.force_dynamic_quantization,
+ )
+ self.to_k = Linear(
+ self.hidden_size,
+ self.kv_dim,
+ dtype=self.dtype,
+ mapping=self.mapping,
+ quant_config=self.quant_config,
+ skip_create_weights_in_init=self.skip_create_weights_in_init,
+ force_dynamic_quantization=self.force_dynamic_quantization,
+ )
+ self.to_v = Linear(
+ self.hidden_size,
+ self.kv_dim,
+ dtype=self.dtype,
+ mapping=self.mapping,
+ quant_config=self.quant_config,
+ skip_create_weights_in_init=self.skip_create_weights_in_init,
+ force_dynamic_quantization=self.force_dynamic_quantization,
+ )
+
+ def get_qkv(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.qkv_mode == QKVMode.FUSE_QKV:
+ qkv = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
+ else:
+ kv_source = (
+ encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ )
+ q = self.to_q(hidden_states)
+ k = self.to_k(kv_source)
+ v = self.to_v(kv_source)
+ return q, k, v
+
+ def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.qk_norm:
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+ return q, k
+
+ def _attn_impl(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ kv_seq_len: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Call attention backend with appropriate tensor layout.
+
+ Two layout paths:
+ 1. HND backends (VANILLA): [B, S, H*D] -> [B, H, S, D]
+ 2. NHD backends (TRTLLM, UlyssesAttention): [B, S, H*D] -> [B, S, H, D]
+ """
+ backend_layout = getattr(self.attn, "preferred_layout", AttentionTensorLayout.NHD)
+
+ batch_size = batch_size or q.shape[0]
+ seq_len = seq_len or q.shape[1]
+ kv_seq_len = kv_seq_len or k.shape[1]
+
+ # Reshape inputs: [B, S, H*D] -> backend's preferred 4D layout
+ if backend_layout == AttentionTensorLayout.HND:
+ q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ else:
+ q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim)
+ k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
+ v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
+
+ # Call backend
+ out = self.attn.forward(
+ q=q,
+ k=k,
+ v=v,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ seq_len_kv=kv_seq_len if kv_seq_len != seq_len else None,
+ )
+
+ # Flatten back to [B, S, H*D]
+ if backend_layout == AttentionTensorLayout.HND:
+ return out.transpose(1, 2).flatten(2)
+ else:
+ return out.flatten(2)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ assert hidden_states.ndim == 3, "hidden_states must be a 3D tensor"
+ batch_size, seq_len = hidden_states.shape[:2]
+ kv_seq_len = (
+ encoder_hidden_states.shape[1] if encoder_hidden_states is not None else seq_len
+ )
+
+ q, k, v = self.get_qkv(hidden_states, encoder_hidden_states)
+ q, k = self.apply_qk_norm(q, k)
+
+ # Apply RoPE if provided (model handles RoPE, not attention backend)
+ if freqs is not None:
+ freqs_cos, freqs_sin = freqs
+ q = q.view(batch_size, seq_len, self.num_attention_heads, self.head_dim) # [B, S, H, D]
+ k = k.view(batch_size, kv_seq_len, self.num_key_value_heads, self.head_dim)
+ q = apply_rotary_emb(q, freqs_cos, freqs_sin)
+ k = apply_rotary_emb(k, freqs_cos, freqs_sin)
+ q = q.flatten(2)
+ k = k.flatten(2)
+
+ out = self._attn_impl(q, k, v, batch_size, seq_len, kv_seq_len)
+ out = self.to_out[0](out)
+ return out
diff --git a/tensorrt_llm/_torch/visual_gen/output.py b/tensorrt_llm/_torch/visual_gen/output.py
new file mode 100644
index 0000000000..d39c15b0c0
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/output.py
@@ -0,0 +1,29 @@
+"""Output dataclass for visual generation models."""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class MediaOutput:
+ """Unified output for all visual generation models.
+
+ Different models populate different fields:
+ - FLUX2: image only
+ - WAN: video only
+ - LTX2: video + audio
+
+ Attributes:
+ image: Generated image as torch tensor with shape (height, width, channels) and dtype uint8.
+ Populated by FLUX2 for text-to-image generation.
+ video: Generated video frames as torch tensor with shape (num_frames, height, width, channels) and dtype uint8.
+ Populated by WAN and LTX2 for text-to-video generation.
+ audio: Generated audio as torch tensor with dtype float32.
+ Populated by LTX2 for text-to-video-with-audio generation.
+ """
+
+ image: Optional[torch.Tensor] = None
+ video: Optional[torch.Tensor] = None
+ audio: Optional[torch.Tensor] = None
diff --git a/tensorrt_llm/_torch/visual_gen/parallelism.py b/tensorrt_llm/_torch/visual_gen/parallelism.py
new file mode 100644
index 0000000000..1bda600fa0
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/parallelism.py
@@ -0,0 +1,100 @@
+"""Utilities for distributed parallelism setup in diffusion models."""
+
+from typing import Optional, Tuple
+
+import torch.distributed as dist
+
+from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
+
+
+def setup_sequence_parallelism(
+ model_config: DiffusionModelConfig,
+ num_attention_heads: int,
+) -> Tuple[bool, int, Optional[dist.ProcessGroup], int]:
+ """
+    Set up sequence parallelism (currently Ulysses only) with CFG support.
+
+ Creates nested process groups where each CFG group has its own Ulysses group.
+ Example with cfg_size=2, ulysses_size=2, world_size=4:
+ GPU 0-1: CFG group 0, Ulysses group 0
+ GPU 2-3: CFG group 1, Ulysses group 1
+
+ Args:
+ model_config: Model configuration containing parallel settings
+ num_attention_heads: Number of attention heads in the model
+
+ Returns:
+ Tuple of (use_parallelism, parallelism_size, parallelism_pg, parallelism_rank):
+ - use_parallelism: Whether sequence parallelism is enabled
+ - parallelism_size: The sequence parallelism degree
+ - parallelism_pg: The process group for this rank (or None)
+ - parallelism_rank: This rank's position within its parallelism group
+
+ Raises:
+ RuntimeError: If torch.distributed is not initialized
+ ValueError: If configuration is invalid (incompatible sizes, head count not divisible, etc.)
+ NotImplementedError: If Ring attention is requested (not yet implemented)
+
+ Side Effects:
+ - Sets model_config.ulysses_process_group to the created process group
+
+ Note:
+ Both num_attention_heads and sequence length must be divisible by ulysses_size.
+ Head count is validated here; sequence length is validated at runtime during forward pass.
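+
+    Example (illustrative sketch; assumes dist.init_process_group() has already been
+    called and model_config.parallel is populated from DiffusionArgs):
+        use_sp, sp_size, sp_pg, sp_rank = setup_sequence_parallelism(
+            model_config, num_attention_heads=12
+        )
+        # sp_pg is also stored on model_config.ulysses_process_group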
+ """
+ ulysses_size = model_config.parallel.dit_ulysses_size
+ ring_size = model_config.parallel.dit_ring_size
+ cfg_size = model_config.parallel.dit_cfg_size
+
+ # Check for ring attention (not yet implemented)
+ if ring_size > 1:
+ raise NotImplementedError("Ring attention parallelism is not yet implemented")
+
+ # Early exit if not using sequence parallelism
+ if ulysses_size <= 1:
+ model_config.ulysses_process_group = None
+ return False, 1, None, 0
+
+ # Validate distributed initialization
+ if not dist.is_initialized():
+ raise RuntimeError(
+ "torch.distributed.init_process_group() must be called before "
+ "setting up sequence parallelism"
+ )
+
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ # Validate total parallelism capacity
+ total_parallel = cfg_size * ulysses_size
+ if total_parallel > world_size:
+ raise ValueError(
+ f"cfg_size ({cfg_size}) * ulysses_size ({ulysses_size}) = "
+ f"{total_parallel} exceeds world_size ({world_size})"
+ )
+
+ # Validate head count divisibility
+ if num_attention_heads % ulysses_size != 0:
+ raise ValueError(
+ f"num_attention_heads ({num_attention_heads}) must be divisible by "
+ f"ulysses_size ({ulysses_size})"
+ )
+
+ # Create nested process groups
+ # Each CFG group has its own Ulysses group
+ ulysses_pg = None
+ ulysses_rank = 0
+
+ for cfg_id in range(cfg_size):
+ ulysses_ranks = list(range(cfg_id * ulysses_size, (cfg_id + 1) * ulysses_size))
+ pg = dist.new_group(ulysses_ranks, use_local_synchronization=True)
+
+ # Store if this rank belongs to this group
+ if rank in ulysses_ranks:
+ ulysses_pg = pg
+ ulysses_rank = rank - cfg_id * ulysses_size
+
+ # Store in config for Attention modules
+ model_config.ulysses_process_group = ulysses_pg
+
+ return True, ulysses_size, ulysses_pg, ulysses_rank
diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py
new file mode 100644
index 0000000000..7876031df7
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/pipeline.py
@@ -0,0 +1,544 @@
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+from .teacache import TeaCacheBackend
+
+if TYPE_CHECKING:
+ from .config import DiffusionModelConfig
+
+
+class BasePipeline(nn.Module):
+ """
+ Base class for diffusion pipelines.
+ """
+
+ def __init__(self, model_config: "DiffusionModelConfig"):
+ super().__init__()
+ self.model_config = model_config
+ self.config = model_config.pretrained_config
+ self.mapping: Mapping = getattr(model_config, "mapping", None) or Mapping()
+
+ # Components
+ self.transformer: Optional[nn.Module] = None
+ self.vae: Optional[nn.Module] = None
+ self.text_encoder: Optional[nn.Module] = None
+ self.tokenizer: Optional[Any] = None
+ self.scheduler: Optional[Any] = None
+
+ # Initialize transformer
+ self._init_transformer()
+
+ @property
+ def rank(self):
+ return dist.get_rank() if dist.is_initialized() else 0
+
+ @property
+ def world_size(self):
+ return dist.get_world_size() if dist.is_initialized() else 1
+
+ @property
+ def dtype(self):
+ if hasattr(self, "transformer"):
+ return next(self.transformer.parameters()).dtype
+ return torch.float32
+
+ @property
+ def device(self):
+ return self.transformer.device
+
+ def infer(self, req: Any):
+ raise NotImplementedError
+
+ def _init_transformer(self) -> None:
+ raise NotImplementedError
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def load_standard_components(
+ self,
+ checkpoint_dir: str,
+ device: torch.device,
+ skip_components: Optional[list] = None,
+ ) -> None:
+ raise NotImplementedError
+
+ def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
+ if self.transformer is not None and hasattr(self.transformer, "load_weights"):
+ self.transformer.load_weights(weights)
+
+ def post_load_weights(self) -> None:
+ if self.transformer is not None and hasattr(self.transformer, "post_load_weights"):
+ self.transformer.post_load_weights()
+
+ def _setup_teacache(self, model, coefficients: Optional[Dict] = None):
+ """Setup TeaCache optimization for the transformer model.
+
+ TeaCache caches transformer block outputs when timestep embeddings change slowly,
+ reducing computation during the denoising loop.
+
+ Args:
+ model: The transformer model to optimize
+ coefficients: Optional dict of model-specific polynomial coefficients for cache decisions
+ Format: {model_size: {"ret_steps": [...], "standard": [...]}}
+ """
+ self.cache_backend = None
+
+ # Get teacache config from model_config (always present now)
+ teacache_cfg = self.model_config.teacache
+ if not teacache_cfg.enable_teacache:
+ return
+
+ # Apply model-specific polynomial coefficients
+ # Coefficients are used to rescale embedding distances for cache decisions
+ if coefficients:
+ checkpoint_path = (
+ getattr(self.model_config.pretrained_config, "_name_or_path", "") or ""
+ )
+ for model_size, coeff_data in coefficients.items():
+ # Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev")
+ if model_size.lower() in checkpoint_path.lower():
+ if isinstance(coeff_data, dict):
+ # Select coefficient set based on warmup mode
+ mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard"
+ if mode in coeff_data:
+ teacache_cfg.coefficients = coeff_data[mode]
+ logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)")
+ else:
+ # Single coefficient list (no mode distinction)
+ teacache_cfg.coefficients = coeff_data
+ logger.info(f"TeaCache: Using {model_size} coefficients")
+ break
+
+ # Initialize and enable TeaCache backend
+ logger.info("TeaCache: Initializing...")
+ self.cache_backend = TeaCacheBackend(teacache_cfg)
+ self.cache_backend.enable(model)
+
+ def decode_latents(
+ self,
+ latents: torch.Tensor,
+ decode_fn: Callable[[torch.Tensor], Any],
+ extra_latents: Optional[Dict[str, Tuple[torch.Tensor, Callable]]] = None,
+ ):
+ """Execute VAE decoding. Only rank 0 performs decoding.
+
+ Args:
+ latents: Primary latents to decode (e.g., video)
+ decode_fn: Decoder function for primary latents
+ extra_latents: Optional dict of additional latents to decode.
+ Format: {name: (latents_tensor, decode_fn)}
+ Example: {"audio": (audio_latents, audio_decode_fn)}
+
+ Returns:
+ Single result if no extra_latents, tuple of results if extra_latents provided.
+ Non-rank-0 processes return None placeholders.
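+
+        Example (illustrative sketch; `decode_video` and `decode_audio` are hypothetical
+        decoder callables provided by the concrete pipeline):
+            video = self.decode_latents(latents, decode_video)
+            video, audio = self.decode_latents(
+                latents, decode_video,
+                extra_latents={"audio": (audio_latents, decode_audio)},
+            )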
+ """
+ if self.rank == 0:
+ primary_result = decode_fn(latents)
+
+ if extra_latents:
+ extra_results = []
+ for name, (extra_latent, extra_decode_fn) in extra_latents.items():
+ extra_results.append(extra_decode_fn(extra_latent))
+ return (primary_result,) + tuple(extra_results)
+
+ return primary_result
+
+ # Return None placeholders for non-rank-0 processes
+ n_results = 1 + (len(extra_latents) if extra_latents else 0)
+ return (None,) * n_results if n_results > 1 else None
+
+ @staticmethod
+ def _rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """Rescale noise to fix overexposure (https://huggingface.co/papers/2305.08891)."""
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+
+ def _setup_cfg_config(
+ self, guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors=None
+ ):
+ """Setup CFG parallel configuration.
+
+ Args:
+ guidance_scale: CFG guidance scale
+ prompt_embeds: Positive prompt embeddings
+ neg_prompt_embeds: Negative prompt embeddings (None if already concatenated)
+ extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel.
+ Format: {name: (positive_tensor, negative_tensor)}
+ Example: {"audio_embeds": (pos_audio, neg_audio),
+ "attention_mask": (pos_mask, neg_mask)}
+
+ Returns:
+ Dict with CFG configuration including split tensors
+ """
+ # Access parallel config directly (always present now)
+ cfg_size = self.model_config.parallel.dit_cfg_size
+ ulysses_size = self.model_config.parallel.dit_ulysses_size
+
+ cfg_group = self.rank // ulysses_size
+ is_split_embeds = neg_prompt_embeds is not None
+ do_cfg_parallel = cfg_size >= 2 and guidance_scale > 1.0
+
+ local_extras = {}
+
+ if do_cfg_parallel:
+ if self.rank == 0:
+ logger.info(f"CFG Parallel: cfg_size={cfg_size}, ulysses_size={ulysses_size}")
+
+ # Split main embeddings
+ if is_split_embeds:
+ pos_embeds, neg_embeds = prompt_embeds, neg_prompt_embeds
+ else:
+ neg_embeds, pos_embeds = prompt_embeds.chunk(2)
+
+ local_embeds = pos_embeds if cfg_group == 0 else neg_embeds
+
+ # Split extra tensors if provided
+ if extra_cfg_tensors:
+ for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
+ if pos_tensor is not None and neg_tensor is not None:
+ local_extras[name] = pos_tensor if cfg_group == 0 else neg_tensor
+ elif pos_tensor is not None:
+ # Only positive provided, use it for both
+ local_extras[name] = pos_tensor
+ else:
+ local_embeds = None
+ if is_split_embeds and guidance_scale > 1.0:
+ prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds])
+
+ # For standard CFG, concatenate extra tensors
+ if extra_cfg_tensors:
+ for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
+ if pos_tensor is not None and neg_tensor is not None and guidance_scale > 1.0:
+ local_extras[name] = torch.cat([neg_tensor, pos_tensor], dim=0)
+ elif pos_tensor is not None:
+ local_extras[name] = pos_tensor
+
+ return {
+ "enabled": do_cfg_parallel,
+ "cfg_size": cfg_size,
+ "ulysses_size": ulysses_size,
+ "cfg_group": cfg_group,
+ "local_embeds": local_embeds,
+ "prompt_embeds": prompt_embeds,
+ "local_extras": local_extras,
+ }
+
+ def _denoise_step_cfg_parallel(
+ self,
+ latents,
+ extra_stream_latents,
+ timestep,
+ local_embeds,
+ forward_fn,
+ guidance_scale,
+ guidance_rescale,
+ ulysses_size,
+ local_extras,
+ ):
+ """Execute single denoising step with CFG parallel."""
+ t_start = time.time()
+ result = forward_fn(latents, extra_stream_latents, timestep, local_embeds, local_extras)
+
+ # Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
+ if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
+ noise_pred_local, extra_noise_locals = result
+ else:
+ noise_pred_local = result
+ extra_noise_locals = {}
+
+ t_transformer = time.time() - t_start
+
+ c_start = time.time()
+
+ # All-gather primary noise
+ gather_list = [torch.empty_like(noise_pred_local) for _ in range(self.world_size)]
+ dist.all_gather(gather_list, noise_pred_local)
+ noise_cond = gather_list[0]
+ noise_uncond = gather_list[ulysses_size]
+ noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
+
+ # All-gather extra stream noises
+ extra_noise_preds = {}
+ for name, noise_local in extra_noise_locals.items():
+ gather_list_extra = [torch.empty_like(noise_local) for _ in range(self.world_size)]
+ dist.all_gather(gather_list_extra, noise_local)
+ noise_cond_extra = gather_list_extra[0]
+ noise_uncond_extra = gather_list_extra[ulysses_size]
+ extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
+ noise_cond_extra - noise_uncond_extra
+ )
+
+ if guidance_rescale > 0.0:
+ extra_noise_preds[name] = self._rescale_noise_cfg(
+ extra_noise_preds[name], noise_cond_extra, guidance_rescale
+ )
+
+ if guidance_rescale > 0.0:
+ noise_pred = self._rescale_noise_cfg(noise_pred, noise_cond, guidance_rescale)
+
+ t_cfg = time.time() - c_start
+ return noise_pred, extra_noise_preds, t_transformer, t_cfg
+
+ def _denoise_step_standard(
+ self,
+ latents,
+ extra_stream_latents,
+ timestep,
+ prompt_embeds,
+ forward_fn,
+ guidance_scale,
+ guidance_rescale,
+ local_extras,
+ ):
+ """Execute single denoising step without CFG parallel."""
+ if guidance_scale > 1.0:
+ latent_input = torch.cat([latents] * 2)
+ # Duplicate extra stream latents for CFG
+ extra_stream_input = {
+ name: torch.cat([stream_latents] * 2)
+ for name, stream_latents in extra_stream_latents.items()
+ }
+ else:
+ latent_input = latents
+ extra_stream_input = extra_stream_latents
+
+ timestep_expanded = timestep.expand(latent_input.shape[0])
+
+ t_start = time.time()
+ result = forward_fn(
+ latent_input, extra_stream_input, timestep_expanded, prompt_embeds, local_extras
+ )
+
+ # Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
+ if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
+ noise_pred, extra_noise_preds = result
+ else:
+ noise_pred = result
+ extra_noise_preds = {}
+
+ t_transformer = time.time() - t_start
+
+ c_start = time.time()
+ if guidance_scale > 1.0:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # Apply CFG to extra streams
+ for name, noise_extra in extra_noise_preds.items():
+ noise_uncond_extra, noise_text_extra = noise_extra.chunk(2)
+ extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
+ noise_text_extra - noise_uncond_extra
+ )
+
+ if guidance_rescale > 0.0:
+ extra_noise_preds[name] = self._rescale_noise_cfg(
+ extra_noise_preds[name], noise_text_extra, guidance_rescale
+ )
+
+ if guidance_rescale > 0.0:
+ noise_pred = self._rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
+
+ t_cfg = time.time() - c_start
+ else:
+ t_cfg = 0.0
+
+ return noise_pred, extra_noise_preds, t_transformer, t_cfg
+
+ def _scheduler_step(
+ self,
+ latents,
+ extra_stream_latents,
+ noise_pred,
+ extra_noise_preds,
+ timestep,
+ scheduler,
+ extra_stream_schedulers,
+ ):
+ """Execute scheduler step for all streams."""
+ t_start = time.time()
+ latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+
+ # Step schedulers for extra streams
+ for name, noise_extra in extra_noise_preds.items():
+ if name in extra_stream_schedulers:
+ extra_stream_latents[name] = extra_stream_schedulers[name].step(
+ noise_extra, timestep, extra_stream_latents[name], return_dict=False
+ )[0]
+
+ t_sched = time.time() - t_start
+ return latents, extra_stream_latents, t_sched
+
+ def denoise(
+ self,
+ latents: torch.Tensor,
+ scheduler: Any,
+ prompt_embeds: torch.Tensor,
+ guidance_scale: float,
+ forward_fn: Callable,
+ timesteps: Optional[torch.Tensor] = None,
+ neg_prompt_embeds: Optional[torch.Tensor] = None,
+ guidance_rescale: float = 0.0,
+ extra_cfg_tensors: Optional[Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
+ extra_streams: Optional[Dict[str, Tuple[torch.Tensor, Any]]] = None,
+ guidance_scale_2: Optional[float] = None,
+ boundary_timestep: Optional[float] = None,
+ ):
+ """Execute denoising loop with optional CFG parallel and TeaCache support.
+
+ Args:
+ latents: Initial noise latents (primary stream, e.g., video)
+ scheduler: Diffusion scheduler for primary stream
+ prompt_embeds: Text embeddings (positive)
+ guidance_scale: CFG strength (1.0 = no guidance)
+ forward_fn: Transformer forward function
+ Signature: forward_fn(latents, extra_stream_latents, timestep,
+ encoder_hidden_states, extra_tensors_dict)
+ Returns: (primary_noise, extra_stream_noises_dict) or just primary_noise
+ timesteps: Optional custom timesteps (defaults to scheduler.timesteps)
+ neg_prompt_embeds: Optional negative text embeddings for CFG
+ guidance_rescale: CFG rescale factor to prevent overexposure
+ extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel
+ Format: {name: (positive_tensor, negative_tensor)}
+ Example: {"audio_embeds": (pos_audio, neg_audio)}
+ extra_streams: Optional dict of additional streams to denoise in parallel
+ Format: {name: (stream_latents, stream_scheduler)}
+ Example: {"audio": (audio_latents, audio_scheduler)}
+ guidance_scale_2: Optional guidance scale for two-stage denoising.
+ When provided with boundary_timestep, switches from guidance_scale
+ to guidance_scale_2 when timestep < boundary_timestep.
+ boundary_timestep: Optional timestep boundary for two-stage denoising.
+ Switches guidance scale when crossing this threshold.
+
+ Returns:
+ Single latents if no extra_streams
+ Tuple (primary_latents, extra_streams_dict) if extra_streams provided
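+
+        Example (illustrative sketch; `transformer_forward` is a hypothetical wrapper
+        matching the forward_fn signature described above):
+            latents = self.denoise(
+                latents=latents,
+                scheduler=self.scheduler,
+                prompt_embeds=pos_embeds,
+                guidance_scale=5.0,
+                forward_fn=transformer_forward,
+                neg_prompt_embeds=neg_embeds,
+            )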
+ """
+ if timesteps is None:
+ timesteps = scheduler.timesteps
+
+ total_steps = len(timesteps)
+ has_extra_streams = extra_streams is not None and len(extra_streams) > 0
+
+ # Reset TeaCache state for new generation
+ # Sets warmup/cutoff steps based on total_steps
+ if (
+ hasattr(self, "cache_backend")
+ and self.cache_backend
+ and self.cache_backend.is_enabled()
+ ):
+ self.cache_backend.refresh(total_steps)
+
+ if self.rank == 0:
+ if has_extra_streams:
+ stream_names = ", ".join(["primary"] + list(extra_streams.keys()))
+ logger.info(
+ f"Denoising [{stream_names}]: {total_steps} steps, guidance={guidance_scale}"
+ )
+ else:
+ logger.info(f"Denoising: {total_steps} steps, guidance={guidance_scale}")
+
+ cfg_config = self._setup_cfg_config(
+ guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors
+ )
+ do_cfg_parallel = cfg_config["enabled"]
+ prompt_embeds = cfg_config["prompt_embeds"]
+ local_extras = cfg_config["local_extras"]
+
+ # Extract extra stream latents and schedulers
+ extra_stream_latents = {}
+ extra_stream_schedulers = {}
+ if extra_streams:
+ for name, (stream_latents, stream_scheduler) in extra_streams.items():
+ extra_stream_latents[name] = stream_latents
+ extra_stream_schedulers[name] = stream_scheduler
+
+ start_time = time.time()
+
+ for i, t in enumerate(timesteps):
+ step_start = time.time()
+
+ # Two-stage denoising: switch guidance scale at boundary
+ current_guidance_scale = guidance_scale
+ if guidance_scale_2 is not None and boundary_timestep is not None:
+ t_scalar = t.item() if t.dim() == 0 else t[0].item()
+ if t_scalar < boundary_timestep:
+ current_guidance_scale = guidance_scale_2
+
+ # Denoise
+ if do_cfg_parallel:
+ timestep = t.expand(latents.shape[0])
+ noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_cfg_parallel(
+ latents,
+ extra_stream_latents,
+ timestep,
+ cfg_config["local_embeds"],
+ forward_fn,
+ current_guidance_scale,
+ guidance_rescale,
+ cfg_config["ulysses_size"],
+ local_extras,
+ )
+ else:
+ noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_standard(
+ latents,
+ extra_stream_latents,
+ t,
+ prompt_embeds,
+ forward_fn,
+ current_guidance_scale,
+ guidance_rescale,
+ local_extras,
+ )
+
+ # Scheduler step for all streams
+ latents, extra_stream_latents, t_sched = self._scheduler_step(
+ latents,
+ extra_stream_latents,
+ noise_pred,
+ extra_noise_preds,
+ t,
+ scheduler,
+ extra_stream_schedulers,
+ )
+
+ # Logging
+ if self.rank == 0:
+ step_time = time.time() - step_start
+ avg_time = (time.time() - start_time) / (i + 1)
+ eta = avg_time * (total_steps - i - 1)
+ logger.info(
+ f"Step {i + 1}/{total_steps} | {step_time:.2f}s "
+ f"(trans={t_trans:.2f}s cfg={t_cfg:.3f}s sched={t_sched:.3f}s) | "
+ f"Avg={avg_time:.2f}s/step ETA={eta:.1f}s"
+ )
+
+ if self.rank == 0:
+ total_time = time.time() - start_time
+ logger.info("=" * 80)
+ logger.info(f"Denoising done: {total_time:.2f}s ({total_time / total_steps:.2f}s/step)")
+
+ # Log TeaCache performance statistics
+ # Shows how many transformer steps were skipped (cache hits) vs computed
+ if (
+ hasattr(self, "cache_backend")
+ and self.cache_backend
+ and self.cache_backend.is_enabled()
+ ):
+ stats = self.cache_backend.get_stats()
+ if stats:
+ logger.info(
+ f"TeaCache: {stats['hit_rate']:.1%} hit rate ({stats['cached']}/{stats['total']} steps)"
+ )
+
+ return (latents, extra_stream_latents) if has_extra_streams else latents
diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py
new file mode 100644
index 0000000000..4cbb05a8e7
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py
@@ -0,0 +1,228 @@
+"""
+Model loader for diffusion pipelines.
+
+Flow:
+1. Load config via DiffusionModelConfig.from_pretrained()
+2. Create pipeline via AutoPipeline.from_config() with MetaInit
+3. Load weights with on-the-fly quantization if dynamic_weight_quant=True
+4. Call pipeline.post_load_weights()
+
+Dynamic Quantization:
+- If quant_config specifies FP8/NVFP4 and dynamic_weight_quant=True:
+ - Model Linear layers are created with FP8/NVFP4 buffers
+ - BF16 checkpoint weights are quantized on-the-fly during loading
+ - Quantized weights are copied into model buffers
+"""
+
+import os
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
+from tensorrt_llm.llmapi.utils import download_hf_model
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+from .checkpoints import WeightLoader
+from .config import DiffusionArgs, DiffusionModelConfig, PipelineComponent
+from .models import AutoPipeline
+
+if TYPE_CHECKING:
+ from .models import BasePipeline
+
+
+class PipelineLoader:
+ """
+ Loader for diffusion pipelines.
+
+ Supports dynamic quantization: when quant_config specifies FP8/NVFP4,
+ model is built with quantized buffers and BF16 weights are quantized
+ on-the-fly during loading.
+
+ Example:
+ args = DiffusionArgs(
+ checkpoint_path="/path/to/model",
+ linear=LinearConfig(type="trtllm-fp8-blockwise"),
+ parallel=ParallelConfig(dit_tp_size=2),
+ )
+ pipeline = PipelineLoader(args).load()
+ """
+
+ def __init__(
+ self,
+ args: Optional[DiffusionArgs] = None,
+ *,
+ mapping: Optional[Mapping] = None,
+ device: str = "cuda",
+ ):
+ """
+ Initialize model loader.
+
+ Args:
+ args: DiffusionArgs containing all configuration (preferred)
+ mapping: Tensor parallel mapping (fallback if args is None)
+ device: Device to load model on (fallback if args is None)
+ """
+ self.args = args
+ if args is not None:
+ self.mapping = args.to_mapping()
+ self.device = torch.device(args.device)
+ else:
+ self.mapping = mapping or Mapping()
+ self.device = torch.device(device)
+
+ def _resolve_checkpoint_dir(self, checkpoint_dir: str) -> str:
+ """Resolve checkpoint_dir to a local directory path.
+
+ If checkpoint_dir is an existing local path, returns it unchanged.
+ Otherwise, attempts to download from HuggingFace Hub using the
+ file-lock-protected ``download_hf_model`` utility (safe for
+ concurrent multi-process access).
+
+ Args:
+ checkpoint_dir: Local path or HuggingFace Hub model ID.
+
+ Returns:
+ Path to local directory containing the model.
+
+ Raises:
+ ValueError: If the path cannot be resolved (invalid repo ID,
+ authentication failure, offline with no cache, etc.)
+ """
+ if os.path.exists(checkpoint_dir):
+ return checkpoint_dir
+
+ revision = self.args.revision if self.args else None
+ logger.info(
+ f"'{checkpoint_dir}' not found locally; "
+ f"attempting HuggingFace Hub download (revision={revision})"
+ )
+ try:
+ local_dir = download_hf_model(checkpoint_dir, revision=revision)
+ except Exception as e:
+ raise ValueError(
+ f"Could not resolve '{checkpoint_dir}' as a local path or "
+ f"HuggingFace Hub model ID: {e}"
+ ) from e
+ return str(local_dir)
+
+ def load(
+ self,
+ checkpoint_dir: Optional[str] = None,
+ ) -> "BasePipeline":
+ """
+ Load a diffusion pipeline with optional dynamic quantization.
+
+ Flow:
+ 1. Resolve checkpoint_dir (local path or HuggingFace Hub model ID)
+ 2. Load config via DiffusionModelConfig.from_pretrained()
+ 3. Create pipeline via AutoPipeline.from_config() with MetaInit
+ 4. Load transformer weights via pipeline.load_weights()
+ 5. Load auxiliary components (VAE, text_encoder) via diffusers
+ 6. Call pipeline.post_load_weights()
+
+ Args:
+ checkpoint_dir: Local path or HF Hub model ID (uses args.checkpoint_path if not provided)
+
+ Returns:
+ Loaded pipeline (WanPipeline, FluxPipeline, etc.) - type auto-detected
+ """
+ # Resolve checkpoint_dir
+ checkpoint_dir = checkpoint_dir or (self.args.checkpoint_path if self.args else None)
+ if not checkpoint_dir:
+ raise ValueError("checkpoint_dir must be provided or set in DiffusionArgs")
+ checkpoint_dir = self._resolve_checkpoint_dir(str(checkpoint_dir))
+
+ # Get loading options from args
+ skip_components = self.args.skip_components if self.args else []
+
+ # =====================================================================
+ # STEP 1: Load Config (includes quant config parsing)
+ # Merge pretrained checkpoint config with user-provided DiffusionArgs
+ # =====================================================================
+ logger.info(f"Loading config from {checkpoint_dir}")
+ config = DiffusionModelConfig.from_pretrained(
+ checkpoint_dir,
+ args=self.args,
+ mapping=self.mapping,
+ )
+
+ # Log quantization settings
+ if config.quant_config and config.quant_config.quant_algo:
+ logger.info(f"Quantization: {config.quant_config.quant_algo.name}")
+ logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}")
+
+ # =====================================================================
+ # STEP 2: Create Pipeline with MetaInit
+ # Pipeline type is auto-detected from model_index.json
+ # - Meta tensors (no GPU memory until materialization)
+ # - If quant_config specifies FP8, Linear layers have FP8 weight buffers
+ # =====================================================================
+ logger.info("Creating pipeline with MetaInitMode")
+ with MetaInitMode():
+ pipeline = AutoPipeline.from_config(config, checkpoint_dir)
+
+ # Convert meta tensors to CUDA tensors
+ self._materialize_meta_tensors(pipeline)
+ pipeline.to(self.device)
+
+ # =====================================================================
+ # STEP 3: Load Transformer Weights
+ # If dynamic_weight_quant=True:
+ # - BF16 checkpoint weights are loaded
+ # - Quantized on-the-fly to FP8/NVFP4 by DynamicLinearWeightLoader
+ # - Copied into model's quantized buffers
+ # =====================================================================
+ if pipeline.transformer is None:
+ raise ValueError("Pipeline has no transformer component")
+
+ transformer_components = getattr(pipeline, "transformer_components", ["transformer"])
+ logger.info(f"Transformer components: {transformer_components}")
+
+ transformer_path = os.path.join(checkpoint_dir, PipelineComponent.TRANSFORMER)
+ if not os.path.exists(transformer_path):
+ raise FileNotFoundError(
+ f"Transformer path does not exist: {transformer_path}. "
+ f"Checkpoint directory must contain a 'transformer' subdirectory."
+ )
+
+ weight_loader = WeightLoader(components=transformer_components)
+ # TODO: accelerate the cpu loading w/ multiprocessing
+ weights = weight_loader.load_weights(checkpoint_dir, self.mapping)
+
+ # Load weights into pipeline
+ pipeline.load_weights(weights)
+
+ # =====================================================================
+ # STEP 4: Load Standard Components (VAE, TextEncoder via diffusers)
+ # These are NOT quantized - loaded as-is from checkpoint
+ # =====================================================================
+ pipeline.load_standard_components(checkpoint_dir, self.device, skip_components)
+
+ # =====================================================================
+ # STEP 5: Post-load Hooks (TeaCache setup, etc.)
+ # =====================================================================
+ if hasattr(pipeline, "post_load_weights"):
+ pipeline.post_load_weights()
+
+ logger.info(f"Pipeline loaded: {pipeline.__class__.__name__}")
+ return pipeline
+
+ def _materialize_meta_tensors(self, module: torch.nn.Module) -> None:
+ """
+ Convert meta tensors to CUDA tensors.
+
+ Meta tensors are placeholders that don't allocate GPU memory.
+ After model structure is defined, we materialize them to real tensors.
+ """
+ memo = {}
+
+ def init_meta_tensor(t: torch.Tensor) -> torch.Tensor:
+ if t.device != torch.device("meta"):
+ return t
+ if t not in memo:
+ memo[t] = torch.empty_like(t, device="cuda")
+ return memo[t]
+
+ module._apply(init_meta_tensor)
diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_registry.py b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py
new file mode 100644
index 0000000000..f4c7fc37da
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py
@@ -0,0 +1,94 @@
+"""Pipeline registry for unified config flow.
+
+Follows: DiffusionArgs -> PipelineLoader -> DiffusionModelConfig -> AutoPipeline -> BasePipeline
+
+All pipelines (Wan, Flux2, LTX2) register via @register_pipeline decorator.
+"""
+
+import json
+import os
+from typing import TYPE_CHECKING, Dict, Type
+
+from tensorrt_llm.logger import logger
+
+if TYPE_CHECKING:
+ from .config import DiffusionModelConfig
+ from .pipeline import BasePipeline
+
+# Global registry: pipeline_name -> pipeline_class
+PIPELINE_REGISTRY: Dict[str, Type["BasePipeline"]] = {}
+
+
+def register_pipeline(name: str):
+ """Register a pipeline class for AutoPipeline.
+
+ Usage:
+ @register_pipeline("WanPipeline")
+ class WanPipeline(BasePipeline):
+ ...
+ """
+
+ def decorator(cls: Type["BasePipeline"]) -> Type["BasePipeline"]:
+ PIPELINE_REGISTRY[name] = cls
+ logger.debug(f"Registered pipeline: {name} -> {cls.__name__}")
+ return cls
+
+ return decorator
+
+
+class AutoPipeline:
+ """Factory for creating pipelines from config."""
+
+ @staticmethod
+ def from_config(
+ config: "DiffusionModelConfig",
+ checkpoint_dir: str,
+ ) -> "BasePipeline":
+ """
+ Create pipeline instance from DiffusionModelConfig.
+ """
+ # Detect pipeline type from model_index.json
+ pipeline_type = AutoPipeline._detect_from_checkpoint(checkpoint_dir)
+
+ if pipeline_type not in PIPELINE_REGISTRY:
+ raise ValueError(
+ f"Unknown pipeline: '{pipeline_type}'. "
+ f"Available: {list(PIPELINE_REGISTRY.keys())}\n"
+ f"Checkpoint: {checkpoint_dir}"
+ )
+
+ pipeline_class = PIPELINE_REGISTRY[pipeline_type]
+ logger.info(f"AutoPipeline: Creating {pipeline_class.__name__} from {checkpoint_dir}")
+
+ # Instantiate pipeline with DiffusionModelConfig
+ return pipeline_class(config)
+
+ @staticmethod
+ def _detect_from_checkpoint(checkpoint_dir: str) -> str:
+ """Detect pipeline type."""
+ index_path = os.path.join(checkpoint_dir, "model_index.json")
+
+ if os.path.exists(index_path):
+ with open(index_path) as f:
+ index = json.load(f)
+
+ class_name = index.get("_class_name", "")
+
+ if class_name in PIPELINE_REGISTRY:
+ return class_name
+
+ if "ImageToVideo" in class_name or "I2V" in class_name:
+ if "Wan" in class_name:
+ return "WanImageToVideoPipeline"
+ # Generic Wan (T2V)
+ if "Wan" in class_name:
+ return "WanPipeline"
+ if "Flux" in class_name:
+ return "FluxPipeline"
+ if "LTX" in class_name or "Ltx" in class_name:
+ return "LTX2Pipeline"
+
+ raise ValueError(
+ f"Cannot detect pipeline type for {checkpoint_dir}\n"
+ f"Expected model_index.json with '_class_name' field at: {index_path}"
+ )
diff --git a/tensorrt_llm/_torch/visual_gen/quantization/__init__.py b/tensorrt_llm/_torch/visual_gen/quantization/__init__.py
new file mode 100644
index 0000000000..909629b1b3
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/quantization/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Quantization support for diffusion models.
+"""
+
+from .loader import DynamicLinearWeightLoader
+from .ops import quantize_fp8_blockwise, quantize_fp8_per_tensor
+
+__all__ = [
+ "DynamicLinearWeightLoader",
+ "quantize_fp8_per_tensor",
+ "quantize_fp8_blockwise",
+]
diff --git a/tensorrt_llm/_torch/visual_gen/quantization/loader.py b/tensorrt_llm/_torch/visual_gen/quantization/loader.py
new file mode 100644
index 0000000000..a4a2a3a11c
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/quantization/loader.py
@@ -0,0 +1,197 @@
+"""
+Dynamic weight quantization loader for Linear modules.
+
+Wraps Linear.load_weights() to perform dynamic quantization before loading to device.
+"""
+
+from typing import Dict, List, Optional
+
+import torch
+
+from tensorrt_llm._torch.modules.linear import Linear, WeightMode
+from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
+from tensorrt_llm._torch.visual_gen.quantization.ops import (
+ quantize_fp8_blockwise,
+ quantize_fp8_per_tensor,
+)
+from tensorrt_llm.quantization.mode import QuantAlgo
+
+
+class DynamicLinearWeightLoader:
+ """
+ Dynamic weight quantization loader for Linear modules.
+
+ Wraps Linear.load_weights() to perform dynamic (load-time) quantization
+ from BF16/FP16 to FP8 before loading weights to device.
+
+ Example:
+ params_map = {'qkv_proj': ['to_q', 'to_k', 'to_v']}
+ loader = DynamicLinearWeightLoader(model_config, params_map=params_map)
+
+ for name, module in model.named_modules():
+ if isinstance(module, Linear):
+ weight_dicts = loader.get_linear_weights(module, name, weights)
+ loader.load_linear_weights(module, name, weight_dicts)
+ """
+
+ def __init__(
+ self,
+ model_config: DiffusionModelConfig,
+ params_map: Optional[Dict[str, List[str]]] = None,
+ ):
+ self.model_config = model_config
+ self.quant_config = model_config.quant_config
+ self.quant_config_dict = model_config.quant_config_dict
+ self.dynamic_weight_quant = model_config.dynamic_weight_quant
+ self.params_map = params_map or {}
+
+ # =========================================================================
+ # Weight gathering methods
+ # =========================================================================
+
+ def get_linear_weights(
+ self,
+ module: Linear,
+ full_name: str,
+ weights: Dict[str, torch.Tensor],
+ ) -> List[Dict[str, torch.Tensor]]:
+ """Get weights for a Linear module, auto-detecting fused weights."""
+ weights_config = getattr(module, "weights_loading_config", None)
+ if weights_config is not None:
+ weight_mode = getattr(weights_config, "weight_mode", None)
+ if weight_mode == WeightMode.FUSED_QKV_LINEAR:
+ fused_names = self._get_fused_names(full_name)
+ return self._get_fused_weights(full_name, weights, fused_names)
+
+ return self._get_vanilla_weights(full_name, weights)
+
+ def filter_weights(
+ self, prefix: str, weights: Dict[str, torch.Tensor]
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Filter weights by prefix and strip the prefix.
+
+ Example:
+ prefix = 'blocks.0.attn1.to_q'
+ weights = {'blocks.0.attn1.to_q.weight': ..., 'blocks.0.attn1.to_q.bias': ...}
+ Returns: {'weight': ..., 'bias': ...}
+ """
+ result = {}
+ prefix_dot = prefix + "."
+ for k, v in weights.items():
+ if k.startswith(prefix_dot):
+ result[k[len(prefix_dot) :]] = v
+ return result
+
+ def _get_fused_names(self, full_name: str) -> List[str]:
+ """Get checkpoint names for a fused module from params_map."""
+ for suffix, names in self.params_map.items():
+ if full_name.endswith(suffix):
+ return names
+ raise ValueError(
+ f"No params_map entry for fused module '{full_name}'. "
+ f"Add mapping like {{'qkv_proj': ['to_q', 'to_k', 'to_v']}} to params_map."
+ )
+
+ def _get_fused_weights(
+ self,
+ full_name: str,
+ weights: Dict[str, torch.Tensor],
+ fused_names: List[str],
+ ) -> List[Dict[str, torch.Tensor]]:
+ """Get weights for a fused module from checkpoint."""
+ parent_path = ".".join(full_name.split(".")[:-1])
+ module_weights = []
+ for ckpt_name in fused_names:
+ ckpt_path = f"{parent_path}.{ckpt_name}" if parent_path else ckpt_name
+ filtered = self.filter_weights(ckpt_path, weights)
+ module_weights.append(filtered)
+ return module_weights
+
+ def _get_vanilla_weights(
+ self,
+ full_name: str,
+ weights: Dict[str, torch.Tensor],
+ ) -> List[Dict[str, torch.Tensor]]:
+ """Get weights for a standard (non-fused) Linear module."""
+ fw = self.filter_weights(full_name, weights)
+ return [fw] if fw else []
+
+ # =========================================================================
+ # Quantization methods
+ # =========================================================================
+
+ def _get_quant_algo_for_layer(self, name: str) -> Optional[QuantAlgo]:
+ """Get quantization algorithm for a specific layer."""
+ if self.quant_config_dict is not None:
+ layer_config = self.quant_config_dict.get(name)
+ if layer_config is not None:
+ return layer_config.quant_algo
+
+ if self.quant_config is not None:
+ return self.quant_config.quant_algo
+
+ return None
+
+ def _should_dynamic_quantize(
+ self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
+ ) -> bool:
+ """Decide if weight should be dynamically quantized at load time."""
+ if not self.dynamic_weight_quant or quant_algo is None:
+ return False
+
+ # Check if module is excluded
+ if self.quant_config is not None:
+ if self.quant_config.is_module_excluded_from_quantization(name):
+ return False
+
+ weight = weight_dict.get("weight")
+ if weight is None:
+ return False
+
+ # For FP8 algorithms: quantize if weight is high precision
+ if quant_algo in (QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES):
+ if weight.dtype == torch.float8_e4m3fn and "weight_scale" in weight_dict:
+ return False # Already quantized
+ return weight.dtype in (torch.bfloat16, torch.float16, torch.float32)
+
+ return False
+
+ def _maybe_dynamic_quantize(
+ self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
+ ) -> Dict[str, torch.Tensor]:
+ """Conditionally quantize weight at load time on GPU."""
+ if not self._should_dynamic_quantize(weight_dict, quant_algo, name):
+ return weight_dict
+
+ weight = weight_dict["weight"]
+
+ # Move to GPU only if needed
+ if weight.device.type != "cuda":
+ weight = weight.cuda()
+
+ if quant_algo == QuantAlgo.FP8:
+ qweight, scale = quantize_fp8_per_tensor(weight)
+ elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
+ block_size = self.quant_config.group_size if self.quant_config else 128
+ qweight, scale = quantize_fp8_blockwise(weight, block_size=block_size)
+ else:
+ return weight_dict
+
+ return {**weight_dict, "weight": qweight, "weight_scale": scale}
+
+ def load_linear_weights(
+ self, module: Linear, name: str, weight_dicts: List[Dict[str, torch.Tensor]]
+ ) -> None:
+ """Load weights into Linear module with optional quantization."""
+ module_quant_config = getattr(module, "quant_config", None)
+ if module_quant_config is not None:
+ quant_algo = module_quant_config.quant_algo
+ else:
+ quant_algo = self._get_quant_algo_for_layer(name)
+
+ quantized_weight_dicts = [
+ self._maybe_dynamic_quantize(wd, quant_algo, name) for wd in weight_dicts
+ ]
+
+ module.load_weights(quantized_weight_dicts)
diff --git a/tensorrt_llm/_torch/visual_gen/quantization/ops.py b/tensorrt_llm/_torch/visual_gen/quantization/ops.py
new file mode 100644
index 0000000000..db550ced8a
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/quantization/ops.py
@@ -0,0 +1,98 @@
+"""
+Quantization operations for diffusion models.
+
+Provides on-the-fly quantization functions for dynamic (load-time) quantization.
+"""
+
+from typing import Tuple
+
+import torch
+
+# FP8 E4M3 max value
+FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+
+def quantize_fp8_per_tensor(weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantize weight to FP8 E4M3 with per-tensor scale.
+
+ Uses torch.ops.tensorrt_llm.quantize_e4m3_per_tensor CUDA kernel.
+
+ Args:
+ weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)
+
+ Returns:
+ Tuple of:
+ - qweight: Quantized weight (FP8 E4M3), same shape as input
+ - weight_scale: Dequantization scale (FP32), shape (1, 1)
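+
+    Example (illustrative sketch; requires the tensorrt_llm CUDA extension and a GPU):
+        w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
+        qweight, scale = quantize_fp8_per_tensor(w)
+        w_dequant = qweight.to(torch.float32) * scale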
+ """
+ qweight, scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(weight)
+    # Cast the dequantization scale to float32 for consistency
+ return qweight, scale.to(torch.float32)
+
+
+def quantize_fp8_blockwise(
+ weight: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+    Quantize weight to FP8 E4M3 with blockwise scales (128x128 blocks by default).
+
+ This function converts BF16/FP16/FP32 weights to FP8 E4M3 format using
+ per-block scale factors. The weight is divided into blocks of size
+ (block_size, block_size) and each block has its own scale.
+
+ Args:
+ weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)
+ block_size: Block size for blockwise quantization (default: 128)
+
+ Returns:
+ Tuple of:
+ - qweight: Quantized weight (FP8 E4M3), shape (out_features, in_features)
+ - block_scales: Block-wise dequantization scales (FP32),
+ shape (num_blocks_out, num_blocks_in)
+
+ Note:
+ - If dimensions are not divisible by block_size, the last block may be smaller
+ - block_scales are dequantization scales (multiply to get back original scale)
+    - With the default block_size=128, the scale layout matches the Linear module's FP8_BLOCK_SCALES format
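+
+    Example (illustrative sketch):
+        w = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
+        qweight, block_scales = quantize_fp8_blockwise(w, block_size=128)
+        # qweight.shape == (256, 512); block_scales.shape == (2, 4)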
+ """
+ out_features, in_features = weight.shape
+ weight_fp32 = weight.float()
+
+ # Calculate number of blocks
+ num_blocks_out = (out_features + block_size - 1) // block_size
+ num_blocks_in = (in_features + block_size - 1) // block_size
+
+ # Initialize outputs
+ qweight = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
+ block_scales = torch.empty(
+ (num_blocks_out, num_blocks_in), dtype=torch.float32, device=weight.device
+ )
+
+ # Quantize each block
+ for i in range(num_blocks_out):
+ row_start = i * block_size
+ row_end = min((i + 1) * block_size, out_features)
+
+ for j in range(num_blocks_in):
+ col_start = j * block_size
+ col_end = min((j + 1) * block_size, in_features)
+
+ # Extract block
+ block = weight_fp32[row_start:row_end, col_start:col_end]
+
+ # Compute block scale
+ max_val = block.abs().max()
+ scale = (
+ max_val / FP8_E4M3_MAX if max_val > 0 else torch.tensor(1.0, device=weight.device)
+ )
+
+ # Quantize block
+ inv_scale = scale.reciprocal() if scale > 0 else torch.tensor(1.0, device=weight.device)
+ qblock = (block * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
+
+ # Store results
+ qweight[row_start:row_end, col_start:col_end] = qblock
+ block_scales[i, j] = scale.to(torch.float32)
+
+ return qweight, block_scales
diff --git a/tensorrt_llm/_torch/visual_gen/teacache.py b/tensorrt_llm/_torch/visual_gen/teacache.py
new file mode 100644
index 0000000000..aa99c1655e
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/teacache.py
@@ -0,0 +1,409 @@
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+from tensorrt_llm.logger import logger
+
+# =============================================================================
+# Core Data Structures
+# =============================================================================
+
+
+@dataclass
+class CacheContext:
+ """Context returned by model extractors for TeaCache.
+
+ Attributes:
+ modulated_input: Timestep embedding used for cache distance calculation
+ hidden_states: Input hidden states for the transformer
+ encoder_hidden_states: Text/prompt embeddings
+ run_transformer_blocks: Callable that executes the transformer forward pass
+ postprocess: Callable that formats the output to the expected return type
+ """
+
+ modulated_input: torch.Tensor
+ hidden_states: torch.Tensor
+ encoder_hidden_states: Any = None
+ run_transformer_blocks: Callable = None
+ postprocess: Callable = None
+
+
+# =============================================================================
+# Extractor Registry
+# =============================================================================
+
+_EXTRACTORS = {}
+
+
+def register_extractor(model_name, extractor_fn):
+ """Register an extractor function for a model class."""
+ _EXTRACTORS[model_name] = extractor_fn
+
+
+def get_extractor(model_type):
+ """Get the registered extractor for a model type."""
+ if model_type not in _EXTRACTORS:
+ raise ValueError(
+ f"TeaCache: Unknown model '{model_type}'. Available: {list(_EXTRACTORS.keys())}"
+ )
+ return _EXTRACTORS[model_type]
+
+
+# =============================================================================
+# Config-Based Extractor System
+# =============================================================================
+
+
+@dataclass
+class ExtractorConfig:
+ """Configuration for model-specific TeaCache extractors.
+
+ Only the timestep embedding logic is model-specific; all other logic is handled generically.
+
+ Attributes:
+ model_class_name: Model class name (e.g., "LTX2VideoTransformer3DModel")
+ timestep_embed_fn: Callable(module, timestep, guidance=None) -> Tensor
+ timestep_param_name: Parameter name for timestep in forward() (default: "timestep")
+ guidance_param_name: Parameter name for guidance if used (default: None)
+ forward_params: List of parameter names (None = auto-introspect from forward signature)
+ return_dict_default: Default value for return_dict parameter (default: True)
+ output_model_class: Output class name for return type (default: "Transformer2DModelOutput")
+ """
+
+ model_class_name: str
+ timestep_embed_fn: Callable
+ timestep_param_name: str = "timestep"
+ guidance_param_name: Optional[str] = None
+ forward_params: Optional[List[str]] = None
+ return_dict_default: bool = True
+ output_model_class: str = "Transformer2DModelOutput"
+
+
+class GenericExtractor:
+ """Handles common TeaCache logic for all diffusion models.
+
+ Extracts forward() arguments, creates run_blocks and postprocess callbacks,
+ and delegates only timestep embedding computation to model-specific logic.
+ """
+
+ def __init__(self, config: ExtractorConfig):
+ self.config = config
+
+ def _extract_forward_args(self, module: torch.nn.Module, *args, **kwargs) -> Dict:
+ """Extract and normalize forward() arguments from *args and **kwargs."""
+ # Get parameter names (auto-introspect or use config)
+ if self.config.forward_params is not None:
+ param_names = self.config.forward_params
+ else:
+ # Auto-introspect forward signature
+ try:
+ sig = inspect.signature(module._original_forward)
+ param_names = [p for p in sig.parameters if p not in ("self", "args", "kwargs")]
+ except Exception as e:
+ logger.warning(f"Could not introspect forward signature: {e}")
+ param_names = []
+
+ # Map positional args to parameter names
+ extracted = {param_names[i]: arg for i, arg in enumerate(args) if i < len(param_names)}
+
+ # Merge kwargs (kwargs take precedence)
+ extracted.update(kwargs)
+ return extracted
+
+ def _compute_timestep_embedding(self, module: torch.nn.Module, params: Dict) -> torch.Tensor:
+ """Compute timestep embedding using configured callable."""
+ timestep = params.get(self.config.timestep_param_name)
+ if timestep is None:
+ raise ValueError(f"Missing required parameter: {self.config.timestep_param_name}")
+
+ # Flatten timestep if needed (common pattern)
+ timestep_flat = timestep.flatten() if timestep.ndim == 2 else timestep
+ guidance = (
+ params.get(self.config.guidance_param_name) if self.config.guidance_param_name else None
+ )
+
+ # Call configured timestep embedding function
+ try:
+ return self.config.timestep_embed_fn(module, timestep_flat, guidance)
+ except Exception as e:
+ logger.error(f"Timestep embedder failed: {e}")
+ # Last resort: use timestep as-is
+ logger.warning("Using timestep fallback")
+ return timestep_flat.unsqueeze(-1) if timestep_flat.ndim == 1 else timestep_flat
+
+ def __call__(self, module: torch.nn.Module, *args, **kwargs) -> CacheContext:
+ """Main extractor logic - called by TeaCacheHook.
+
+ Extracts forward arguments, computes timestep embedding, and creates callbacks
+ for running the transformer and post-processing the output.
+ """
+ # Extract forward arguments from positional and keyword args
+ params = self._extract_forward_args(module, *args, **kwargs)
+
+ # Compute timestep embedding (used for cache distance calculation)
+ t_emb = self._compute_timestep_embedding(module, params)
+ return_dict = params.get("return_dict", self.config.return_dict_default)
+
+ def run_blocks():
+ """Execute the full transformer forward pass with original parameters."""
+ ret = module._original_forward(**params)
+ # Normalize output to tuple format
+ if return_dict and not isinstance(ret, tuple):
+ sample = ret.sample if hasattr(ret, "sample") else ret
+ return (sample,) if not isinstance(sample, tuple) else sample
+ return ret if isinstance(ret, tuple) else (ret,)
+
+ def postprocess(output):
+ """Convert cached/computed output back to expected return format."""
+ if return_dict:
+ if isinstance(output, tuple):
+ return output
+ return Transformer2DModelOutput(sample=output)
+ # For return_dict=False, unwrap single-element tuple to raw tensor
+ if isinstance(output, tuple) and len(output) == 1:
+ return output[0]
+ # Return raw tensor as-is (TeaCacheHook always passes tensors to postprocess)
+ return output
+
+ return CacheContext(
+ modulated_input=t_emb,
+ hidden_states=params.get("hidden_states"),
+ encoder_hidden_states=params.get("encoder_hidden_states"),
+ run_transformer_blocks=run_blocks,
+ postprocess=postprocess,
+ )
+
+
+def register_extractor_from_config(config: ExtractorConfig):
+ """Register a TeaCache extractor for a model. Call this in pipeline's load() method.
+
+ Example:
+ register_extractor_from_config(ExtractorConfig(
+ model_class_name="LTX2VideoTransformer3DModel",
+ timestep_embed_fn=self._compute_ltx2_timestep_embedding,
+ ))
+ """
+ extractor = GenericExtractor(config)
+ register_extractor(config.model_class_name, extractor)
+ logger.debug(f"Registered TeaCache extractor for {config.model_class_name}")
+
+
+# =============================================================================
+# TeaCache Runtime (caching hook and lifecycle management)
+# =============================================================================
+
+
+class TeaCacheHook:
+ """Caches transformer blocks when timestep embeddings change slowly.
+
+ The hook monitors the relative change in timestep embeddings between steps.
+ When the change is small (below threshold), it reuses the cached residual
+ from the previous step instead of running the full transformer.
+
+ Separate cache states are maintained for conditional and unconditional branches
+ when using Classifier-Free Guidance (CFG).
+ """
+
+ def __init__(self, config):
+ self.config = config
+ # Polynomial function to rescale embedding distances
+ self.rescale_func = np.poly1d(config.coefficients)
+ self.extractor_fn = None
+
+ # Separate cache state for conditional (pos) and unconditional (neg) branches
+ self.state_pos = self._new_state()
+ self.state_neg = self._new_state()
+ self.stats = {"total": 0, "cached": 0}
+
+ def _new_state(self):
+ return {"cnt": 0, "acc_dist": 0.0, "prev_input": None, "prev_residual": None}
+
+ def initialize(self, module):
+ self.extractor_fn = get_extractor(module.__class__.__name__)
+
+ def reset_state(self):
+ self.state_pos = self._new_state()
+ self.state_neg = self._new_state()
+ self.stats = {"total": 0, "cached": 0}
+
+ def get_stats(self):
+ total = max(self.stats["total"], 1)
+ cached = self.stats["cached"]
+ return {
+ "hit_rate": cached / total,
+ "total": total,
+ "cached": cached,
+ # Backward compatibility
+ "total_steps": total,
+ "cached_steps": cached,
+ "compute_steps": total - cached,
+ }
+
+ def __call__(self, module, *args, **kwargs):
+ """Main hook called during transformer forward pass.
+
+ Decides whether to run the full transformer or reuse cached residual
+ based on timestep embedding distance.
+ """
+ # Extract context (timestep embedding, hidden states, callbacks)
+ ctx = self.extractor_fn(module, *args, **kwargs)
+
+ # Select cache state (for CFG: separate tracking for conditional/unconditional)
+ cache_branch = getattr(module, "_cache_branch", None)
+ state = self.state_neg if cache_branch == "uncond" else self.state_pos
+
+ # Decide: compute transformer or use cache?
+ should_compute = self._should_compute(state, ctx.modulated_input)
+ self.stats["total"] += 1
+
+ if not should_compute and state["prev_residual"] is not None:
+ # Cache hit: Add cached residual to skip transformer computation
+ logger.debug(f"TeaCache: SKIP step {state['cnt']}")
+ # For I2V: output might have fewer channels than input
+ # Apply residual only to the latent channels
+ if ctx.hidden_states.shape[1] != state["prev_residual"].shape[1]:
+ # Extract latent channels (match output channels)
+ num_output_channels = state["prev_residual"].shape[1]
+ latent_channels = ctx.hidden_states[:, :num_output_channels]
+ output = latent_channels + state["prev_residual"]
+ else:
+ output = ctx.hidden_states + state["prev_residual"]
+ self.stats["cached"] += 1
+ else:
+ # Cache miss: Run full transformer and cache the residual
+ outputs = ctx.run_transformer_blocks()
+ output = outputs[0] if isinstance(outputs, tuple) else outputs
+
+ # Store residual (output - input) for next potential cache hit
+ # For I2V: output may have fewer channels than input
+ # Compute residual only on the latent channels
+ if ctx.hidden_states.shape[1] != output.shape[1]:
+ # Extract latent channels (match output channels)
+ num_output_channels = output.shape[1]
+ latent_channels = ctx.hidden_states[:, :num_output_channels]
+ state["prev_residual"] = (output - latent_channels).detach()
+ else:
+ original = ctx.hidden_states.clone()
+ state["prev_residual"] = (output - original).detach()
+
+ # Update state for next iteration
+ state["prev_input"] = ctx.modulated_input.detach()
+ state["cnt"] += 1
+
+ return ctx.postprocess(output)
+
+ def _should_compute(self, state, modulated_inp):
+ """Decide whether to compute transformer or use cached result.
+
+ Returns True to compute, False to use cache.
+ """
+ # Warmup: Always compute first few steps to build stable cache
+ if self.config.ret_steps and state["cnt"] < self.config.ret_steps:
+ state["acc_dist"] = 0.0
+ return True
+
+ # Cooldown: Always compute last few steps for quality
+ if self.config.cutoff_steps and state["cnt"] >= self.config.cutoff_steps:
+ return True
+
+ # First step: no previous input to compare
+ if state["prev_input"] is None:
+ return True
+
+ # Compute relative change in timestep embedding
+ curr, prev = modulated_inp, state["prev_input"]
+
+ # For CFG (batch_size > 1), only compare conditional branch
+ # Both branches move similarly, so one comparison is sufficient
+ if modulated_inp.shape[0] > 1:
+ curr, prev = modulated_inp.chunk(2)[1], prev.chunk(2)[1]
+
+ # Calculate relative L1 distance (normalized by magnitude)
+ rel_dist = ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).cpu().item()
+
+ # Apply polynomial rescaling to adjust sensitivity
+ # Accumulate distance (capped at 2x threshold to prevent overflow)
+ rescaled = float(self.rescale_func(rel_dist))
+ state["acc_dist"] = min(
+ state["acc_dist"] + abs(rescaled), self.config.teacache_thresh * 2.0
+ )
+
+ logger.debug(
+ f"TeaCache: step {state['cnt']} | dist {rel_dist:.2e} | acc {state['acc_dist']:.4f}"
+ )
+
+ # Cache decision based on accumulated distance
+ if state["acc_dist"] < self.config.teacache_thresh:
+ # Below threshold: use cache, apply decay to distance
+ state["acc_dist"] *= 0.95
+ return False
+ else:
+ # Above threshold: compute, reset accumulated distance
+ state["acc_dist"] = 0.0
+ return True
+
+
+class TeaCacheBackend:
+ """Manages TeaCache lifecycle."""
+
+ def __init__(self, config):
+ self.config = config
+ self.hook = None
+
+ def enable(self, module):
+ if self.hook is None:
+ logger.info(f"TeaCache: Enabling for {module.__class__.__name__}")
+ self.hook = TeaCacheHook(self.config)
+ self.hook.initialize(module)
+ module._original_forward = module.forward
+ module.forward = lambda *args, **kwargs: self.hook(module, *args, **kwargs)
+
+ def disable(self, module):
+ if self.hook and hasattr(module, "_original_forward"):
+ module.forward = module._original_forward
+ self.hook = None
+
+ def refresh(self, num_inference_steps):
+ """Reset TeaCache state for a new generation.
+
+ Sets warmup/cutoff steps based on total inference steps:
+ - Warmup steps: Always compute to build stable cache
+ - Cutoff steps: Always compute for quality at the end
+ - Middle steps: Use caching based on distance threshold
+
+ Args:
+ num_inference_steps: Total number of denoising steps
+ """
+ if not self.hook:
+ return
+
+ # Reset cache state (clears previous residuals and counters)
+ self.hook.reset_state()
+
+ # Configure warmup and cutoff based on mode
+ if self.config.use_ret_steps:
+ # Aggressive warmup: 5 steps to stabilize cache
+ self.config.ret_steps = 5
+ self.config.cutoff_steps = num_inference_steps # No cutoff (cache until end)
+ else:
+ # Minimal warmup: 1 step
+ self.config.ret_steps = 1
+ self.config.cutoff_steps = num_inference_steps - 2 # Compute last 2 steps
+
+ self.config.num_steps = num_inference_steps
+
+ logger.info(
+ f"TeaCache: {num_inference_steps} steps | "
+ f"warmup: {self.config.ret_steps}, cutoff: {self.config.cutoff_steps}, "
+ f"thresh: {self.config.teacache_thresh}"
+ )
+
+ def is_enabled(self):
+ return self.hook is not None
+
+ def get_stats(self):
+ return self.hook.get_stats() if self.hook else {}
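
The skip/compute decision above reduces to a few scalars of state per CFG branch: accumulate the rescaled relative L1 distance between successive modulated inputs and reuse the cached residual while the accumulator stays under `teacache_thresh`. A minimal standalone sketch of that policy (the rescale function, threshold, and warmup/cutoff values below are illustrative placeholders, not the module's tuned defaults):

```python
# Standalone sketch of the accumulate-and-threshold policy used by TeaCache.
import torch


def rescale(x: float) -> float:
    # Placeholder; the real hook uses a model-specific polynomial fit.
    return 2.0 * x


def should_compute(state: dict, modulated: torch.Tensor,
                   thresh: float = 0.08, warmup: int = 1, cutoff: int = 48) -> bool:
    """Return True to run the transformer, False to reuse the cached residual."""
    if state["cnt"] < warmup:            # warmup: always compute, reset accumulator
        state["acc"] = 0.0
        return True
    if state["cnt"] >= cutoff or state["prev"] is None:
        return True                      # cooldown / nothing to compare against yet
    prev = state["prev"]
    rel = ((modulated - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()
    state["acc"] = min(state["acc"] + abs(rescale(rel)), thresh * 2.0)
    if state["acc"] < thresh:
        state["acc"] *= 0.95             # decay while serving from cache
        return False
    state["acc"] = 0.0
    return True


state = {"cnt": 0, "acc": 0.0, "prev": None}
for step in range(50):
    x = torch.randn(1, 256) * (1.0 - step / 50)   # stand-in for the modulated input
    cache_hit = not should_compute(state, x)
    state["prev"], state["cnt"] = x.detach(), state["cnt"] + 1
```
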
diff --git a/tensorrt_llm/_torch/visual_gen/utils.py b/tensorrt_llm/_torch/visual_gen/utils.py
new file mode 100644
index 0000000000..99f8837ceb
--- /dev/null
+++ b/tensorrt_llm/_torch/visual_gen/utils.py
@@ -0,0 +1,39 @@
+"""Utility functions for visual generation pipelines."""
+
+import torch
+
+
+@torch.compile
+def postprocess_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
+ """Post-process video tensor from VAE decoder output to final format.
+
+ This is a more efficient implementation than using VideoProcessor for single-batch cases,
+ as it avoids loop overhead and processes the entire batch with vectorized operations.
+
+ Args:
+ video: Video tensor in (B, C, T, H, W) format from VAE decoder
+ remove_batch_dim: Whether to remove batch dimension. Default True for typical
+ single-batch video generation.
+
+ Returns:
+ Post-processed video tensor:
+ - If remove_batch_dim=True: (T, H, W, C) uint8 tensor
+ - If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor
+
+ Note:
+ Assumes video values are in [-1, 1] range (standard VAE decoder output).
+ """
+ # Convert to (B, T, H, W, C) format
+ video = video.permute(0, 2, 3, 4, 1) # (B, C, T, H, W) -> (B, T, H, W, C)
+
+ # Normalize to [0, 1] range
+ video = (video / 2 + 0.5).clamp(0, 1)
+
+ # Convert to uint8
+ video = (video * 255).round().to(torch.uint8)
+
+ # Remove batch dimension if requested
+ if remove_batch_dim:
+ video = video[0] # (B, T, H, W, C) -> (T, H, W, C)
+
+ return video
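
For reference, a quick shape/dtype check of the helper above on a dummy decoder output (assumes this change is installed; values are taken to be in [-1, 1]):

```python
# Quick contract check on a dummy decoder output (values assumed in [-1, 1]).
import torch

from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor

decoded = torch.rand(1, 3, 8, 64, 64) * 2 - 1       # (B, C, T, H, W)
frames = postprocess_video_tensor(decoded)            # default drops the batch dim
assert frames.shape == (8, 64, 64, 3) and frames.dtype == torch.uint8

batched = postprocess_video_tensor(decoded, remove_batch_dim=False)
assert batched.shape == (1, 8, 64, 64, 3)
```
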
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 76cbde9646..8eb31bd5f1 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -19,11 +19,13 @@ from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._utils import mpi_rank
+from tensorrt_llm.commands.utils import (get_is_diffusion_model,
+ get_visual_gen_model_type)
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, KvCacheConfig,
- SchedulerConfig)
+ SchedulerConfig, VisualGen)
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole,
extract_disagg_cluster_config,
@@ -217,7 +219,7 @@ def launch_server(
f"{backend} is not a known backend, check help for available options.",
param_hint="backend")
- server = OpenAIServer(llm=llm,
+ server = OpenAIServer(generator=llm,
model=model,
tool_parser=tool_parser,
server_role=server_role,
@@ -361,7 +363,7 @@ def launch_mm_encoder_server(
encoder_args.pop("build_config")
mm_encoder = MultimodalEncoder(**encoder_args)
- server = OpenAIServer(llm=mm_encoder,
+ server = OpenAIServer(generator=mm_encoder,
model=model,
server_role=ServerRole.MM_ENCODER,
metadata_server_cfg=metadata_server_cfg,
@@ -369,6 +371,45 @@ def launch_mm_encoder_server(
asyncio.run(server(host, port))
+def launch_visual_gen_server(
+ host: str,
+ port: int,
+ visual_gen_config: dict,
+ metadata_server_cfg: Optional[MetadataServerConfig] = None,
+):
+ """Launch a VISUAL_GEN model server for image/video generation.
+
+ Args:
+ host: Server hostname.
+ port: Server port.
+ visual_gen_config: Arguments for VISUAL_GEN model initialization.
+ metadata_server_cfg: Optional metadata server configuration.
+ """
+ model = visual_gen_config["model"]
+ logger.info(f"Initializing VisualGen ({model})")
+
+ n_workers = 1
+ parallel_config = visual_gen_config.get("parallel", {})
+ if parallel_config:
+ n_workers = parallel_config.get(
+ "dit_cfg_size", 1) * parallel_config.get("dit_ulysses_size", 1)
+ logger.info(f"World size: {n_workers}")
+ logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}")
+ logger.info(
+ f"Ulysses size: {parallel_config.get('dit_ulysses_size', 1)}")
+
+ visual_gen_model = VisualGen(model_path=model,
+ n_workers=n_workers,
+ diffusion_config=visual_gen_config)
+
+ server = OpenAIServer(generator=visual_gen_model,
+ model=model,
+ server_role=ServerRole.VISUAL_GEN,
+ metadata_server_cfg=metadata_server_cfg,
+ tool_parser=None)
+ asyncio.run(server(host, port))
+
+
class ChoiceWithAlias(click.Choice):
def __init__(self,
@@ -600,6 +641,12 @@ class ChoiceWithAlias(click.Choice):
default=False,
help="Run gRPC server instead of OpenAI HTTP server. "
"gRPC server accepts pre-tokenized requests and returns raw token IDs.")
+@click.option("--extra_visual_gen_options",
+ type=str,
+ default=None,
+ help=help_info_with_stability_tag(
+ "Path to a YAML file with extra VISUAL_GEN model options.",
+ "prototype"))
def serve(
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
@@ -616,8 +663,8 @@ def serve(
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
custom_module_dirs: list[Path], chat_template: Optional[str],
- grpc: bool):
- """Running an OpenAI API compatible server (or gRPC server with --grpc flag)
+ grpc: bool, extra_visual_gen_options: Optional[str]):
+ """Running an OpenAI API compatible server
MODEL: model name | HF checkpoint path | TensorRT engine path
"""
@@ -630,93 +677,120 @@ def serve(
logger.error(
f"Failed to import custom module from {custom_module_dir}: {e}")
raise e
- llm_args, _ = get_llm_args(
- model=model,
- tokenizer=tokenizer,
- custom_tokenizer=custom_tokenizer,
- backend=backend,
- max_beam_width=max_beam_width,
- max_batch_size=max_batch_size,
- max_num_tokens=max_num_tokens,
- max_seq_len=max_seq_len,
- tensor_parallel_size=tensor_parallel_size,
- pipeline_parallel_size=pipeline_parallel_size,
- context_parallel_size=context_parallel_size,
- moe_expert_parallel_size=moe_expert_parallel_size,
- moe_cluster_parallel_size=moe_cluster_parallel_size,
- gpus_per_node=gpus_per_node,
- free_gpu_memory_fraction=free_gpu_memory_fraction,
- num_postprocess_workers=num_postprocess_workers,
- trust_remote_code=trust_remote_code,
- revision=revision,
- reasoning_parser=reasoning_parser,
- fail_fast_on_attention_window_too_large=
- fail_fast_on_attention_window_too_large,
- otlp_traces_endpoint=otlp_traces_endpoint,
- enable_chunked_prefill=enable_chunked_prefill)
- llm_args_extra_dict = {}
- if extra_llm_api_options is not None:
- with open(extra_llm_api_options, 'r') as f:
- llm_args_extra_dict = yaml.safe_load(f)
- llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict)
+ def _serve_llm():
+ nonlocal server_role
+ llm_args, _ = get_llm_args(
+ model=model,
+ tokenizer=tokenizer,
+ custom_tokenizer=custom_tokenizer,
+ backend=backend,
+ max_beam_width=max_beam_width,
+ max_batch_size=max_batch_size,
+ max_num_tokens=max_num_tokens,
+ max_seq_len=max_seq_len,
+ tensor_parallel_size=tensor_parallel_size,
+ pipeline_parallel_size=pipeline_parallel_size,
+ context_parallel_size=context_parallel_size,
+ moe_expert_parallel_size=moe_expert_parallel_size,
+ moe_cluster_parallel_size=moe_cluster_parallel_size,
+ gpus_per_node=gpus_per_node,
+ free_gpu_memory_fraction=free_gpu_memory_fraction,
+ num_postprocess_workers=num_postprocess_workers,
+ trust_remote_code=trust_remote_code,
+ revision=revision,
+ reasoning_parser=reasoning_parser,
+ fail_fast_on_attention_window_too_large=
+ fail_fast_on_attention_window_too_large,
+ otlp_traces_endpoint=otlp_traces_endpoint,
+ enable_chunked_prefill=enable_chunked_prefill)
- metadata_server_cfg = parse_metadata_server_config_file(
- metadata_server_config_file)
+ llm_args_extra_dict = {}
+ if extra_llm_api_options is not None:
+ with open(extra_llm_api_options, 'r') as f:
+ llm_args_extra_dict = yaml.safe_load(f)
+ llm_args = update_llm_args_with_extra_dict(llm_args,
+ llm_args_extra_dict)
- # Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
- # but disagg_cluster_uri takes precedence over cluster uri in config file
- disagg_cluster_config = llm_args.pop("disagg_cluster", None)
- if disagg_cluster_config:
- disagg_cluster_config = extract_disagg_cluster_config(
- disagg_cluster_config, disagg_cluster_uri)
- elif disagg_cluster_uri:
- disagg_cluster_config = DisaggClusterConfig(
- cluster_uri=disagg_cluster_uri)
+ metadata_server_cfg = parse_metadata_server_config_file(
+ metadata_server_config_file)
- if metadata_server_cfg is not None or disagg_cluster_config is not None:
- assert (
- server_role is not None
- ), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
- try:
- server_role = ServerRole[server_role.upper()]
- except ValueError:
- raise ValueError(f"Invalid server role: {server_role}. " \
- f"Must be one of: {', '.join([role.name for role in ServerRole])}")
+ # Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
+ # but disagg_cluster_uri takes precedence over cluster uri in config file
+ disagg_cluster_config = llm_args.pop("disagg_cluster", None)
+ if disagg_cluster_config:
+ disagg_cluster_config = extract_disagg_cluster_config(
+ disagg_cluster_config, disagg_cluster_uri)
+ elif disagg_cluster_uri:
+ disagg_cluster_config = DisaggClusterConfig(
+ cluster_uri=disagg_cluster_uri)
- # Parse media_io_kwargs from JSON string to dict if provided
- parsed_media_io_kwargs = None
- if media_io_kwargs is not None:
- try:
- parsed_media_io_kwargs = json.loads(media_io_kwargs)
- except json.JSONDecodeError as e:
- raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
+ if metadata_server_cfg is not None or disagg_cluster_config is not None:
+ assert (
+ server_role is not None
+ ), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
+ try:
+ server_role = ServerRole[server_role.upper()]
+ except ValueError:
+ raise ValueError(f"Invalid server role: {server_role}. " \
+ f"Must be one of: {', '.join([role.name for role in ServerRole])}")
+ # Parse media_io_kwargs from JSON string to dict if provided
+ parsed_media_io_kwargs = None
+ if media_io_kwargs is not None:
+ try:
+ parsed_media_io_kwargs = json.loads(media_io_kwargs)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
- multimodal_server_config = MultimodalServerConfig(
- media_io_kwargs=parsed_media_io_kwargs)
+ multimodal_server_config = MultimodalServerConfig(
+ media_io_kwargs=parsed_media_io_kwargs)
- if grpc:
- # gRPC mode: launch gRPC server instead of OpenAI HTTP server
- # Check for unsupported arguments that are silently ignored in gRPC mode
- unsupported_args = {
- "tool_parser": tool_parser,
- "chat_template": chat_template,
- "metadata_server_config_file": metadata_server_config_file,
- "server_role": server_role,
- "disagg_cluster_config": disagg_cluster_config,
+ if grpc:
+ # gRPC mode: launch gRPC server instead of OpenAI HTTP server
+ # Check for unsupported arguments that are silently ignored in gRPC mode
+ unsupported_args = {
+ "tool_parser": tool_parser,
+ "chat_template": chat_template,
+ "metadata_server_config_file": metadata_server_config_file,
+ "server_role": server_role,
+ "disagg_cluster_config": disagg_cluster_config,
+ }
+ for name, value in unsupported_args.items():
+ if value is not None:
+ raise ValueError(
+ f"Argument '{name}' is not supported when running in gRPC mode. "
+ f"The gRPC server is designed for use with external routers that handle "
+ f"these features (e.g., tool parsing, chat templates).")
+ launch_grpc_server(host, port, llm_args)
+ else:
+ # Default: launch OpenAI HTTP server
+ launch_server(host, port, llm_args, tool_parser, chat_template,
+ metadata_server_cfg, server_role,
+ disagg_cluster_config, multimodal_server_config)
+
+ def _serve_visual_gen():
+ visual_gen_config = {
+ "model": model,
+ "model_type": get_visual_gen_model_type(model),
}
- for name, value in unsupported_args.items():
- if value is not None:
- raise ValueError(
- f"Argument '{name}' is not supported when running in gRPC mode. "
- f"The gRPC server is designed for use with external routers that handle "
- f"these features (e.g., tool parsing, chat templates).")
- launch_grpc_server(host, port, llm_args)
+
+ visual_gen_extra_args = {}
+ if extra_visual_gen_options is not None:
+ with open(extra_visual_gen_options, 'r') as f:
+ visual_gen_extra_args = yaml.safe_load(f)
+
+ visual_gen_config.update(visual_gen_extra_args)
+
+ metadata_server_cfg = parse_metadata_server_config_file(
+ metadata_server_config_file)
+
+ launch_visual_gen_server(host, port, visual_gen_config,
+ metadata_server_cfg)
+
+ if get_is_diffusion_model(model):
+ _serve_visual_gen()
else:
- # Default: launch OpenAI HTTP server
- launch_server(host, port, llm_args, tool_parser, chat_template,
- metadata_server_cfg, server_role, disagg_cluster_config,
- multimodal_server_config)
+ _serve_llm()
@click.command("mm_embedding_serve")
diff --git a/tensorrt_llm/commands/utils.py b/tensorrt_llm/commands/utils.py
new file mode 100644
index 0000000000..df1442c6e7
--- /dev/null
+++ b/tensorrt_llm/commands/utils.py
@@ -0,0 +1,132 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/030496eb06472f76fcb11de53d93f10cefb4604f/python/sglang/cli/utils.py#L27
+import json
+import logging
+import os
+
+from tensorrt_llm.llmapi.utils import download_hf_partial
+
+logger = logging.getLogger(__name__)
+
+
+def _maybe_download_model(
+ model_name_or_path: str, local_dir: str | None = None, download: bool = True
+) -> str:
+ """Resolve a model path. If it's a local directory, return it.
+
+ If it's a Hugging Face Hub ID, download only the config file
+ (`model_index.json` or `config.json`) and return its directory.
+
+ Args:
+ model_name_or_path: Local path or Hugging Face Hub model ID
+ local_dir: Local directory to save the downloaded file (if any)
+ download: Whether to download from Hugging Face Hub when needed
+
+ Returns:
+ Local directory path that contains the downloaded config file, or the original local directory.
+ """
+ if os.path.exists(model_name_or_path):
+ logger.info("Model already exists locally")
+ return model_name_or_path
+
+ if not download:
+ return model_name_or_path
+
+ try:
+ logger.info(
+ "Downloading model_index.json from HF Hub for %s...",
+ model_name_or_path,
+ )
+ file_path = download_hf_partial(
+ model=model_name_or_path,
+ allow_patterns=["model_index.json", "config.json"],
+ )
+ logger.info("Downloaded to %s", file_path)
+ return str(file_path)
+ except Exception as e:
+ raise ValueError(
+ (
+ "Could not find model locally at %s and failed to download "
+ "model_index.json/config.json from HF Hub: %s"
+ )
+ % (model_name_or_path, e)
+ ) from e
+
+
+# Copied and adapted from hf_diffusers_utils.py
+def is_diffusers_model_path(model_path: str) -> bool:
+ """Verify if the model directory contains a valid diffusers configuration.
+
+ Args:
+ model_path: Path to the model directory
+
+ Returns:
+ True if the directory contains a diffusers `model_index.json` with a
+ `_diffusers_version` field, False otherwise
+ """
+ # Prefer model_index.json which indicates a diffusers pipeline
+ config_path = os.path.join(model_path, "model_index.json")
+ if not os.path.exists(config_path):
+ return False
+
+ # Load the config
+ with open(config_path) as f:
+ config = json.load(f)
+
+ # Verify diffusers version exists
+ if "_diffusers_version" not in config:
+ return False
+ return True
+
+
+def get_is_diffusion_model(model_path: str):
+ model_path = _maybe_download_model(model_path)
+ is_diffusion_model = is_diffusers_model_path(model_path)
+ if is_diffusion_model:
+ logger.info("Diffusion model detected")
+ return is_diffusion_model
+
+
+def get_model_path(extra_argv):
+ # Find the model_path argument
+ model_path = None
+ for i, arg in enumerate(extra_argv):
+ if arg == "--model-path":
+ if i + 1 < len(extra_argv):
+ model_path = extra_argv[i + 1]
+ break
+ elif arg.startswith("--model-path="):
+ model_path = arg.split("=", 1)[1]
+ break
+
+ if model_path is None:
+ # Fallback for --help or other cases where model-path is not provided
+ if any(h in extra_argv for h in ["-h", "--help"]):
+ raise Exception(
+ "Usage: sglang serve --model-path [additional-arguments]\n\n"
+ "This command can launch either a standard language model server or a diffusion model server.\n"
+ "The server type is determined by the model path.\n"
+ "For specific arguments, please provide a model_path."
+ )
+ else:
+ raise Exception(
+ "Error: --model-path is required. Please provide the path to the model."
+ )
+ return model_path
+
+
+VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE = {
+ "FLUX.2": "flux2",
+ "LTX-2": "ltx2",
+ "Wan2": "wan2",
+}
+
+
+def get_visual_gen_model_type(model_path: str):
+ for partial_model_name, model_type in VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.items():
+ if partial_model_name.lower() in model_path.lower():
+ return model_type
+
+ raise ValueError(
+ f"Unknown VISUAL_GEN model type for model path: {model_path},"
+ f"available models: {VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.keys()}"
+ )
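
A short illustration of the two detection helpers above; the model-type lookup is a pure case-insensitive substring match, so it needs no files on disk:

```python
# Illustrative checks of the detection helpers; paths are hypothetical.
from tensorrt_llm.commands.utils import (get_is_diffusion_model,
                                         get_visual_gen_model_type)

assert get_visual_gen_model_type("/models/Wan2.1-T2V-1.3B-Diffusers") == "wan2"
assert get_visual_gen_model_type("FLUX.2-dev") == "flux2"

# get_is_diffusion_model() needs a real checkpoint: it resolves the path
# (downloading only model_index.json/config.json for a Hub ID) and checks for a
# diffusers model_index.json containing `_diffusers_version`.
# is_diffusion = get_is_diffusion_model("/path/to/local/diffusers/checkpoint")
```
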
diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py
index f09dd31dc4..4fafc8307a 100644
--- a/tensorrt_llm/executor/ipc.py
+++ b/tensorrt_llm/executor/ipc.py
@@ -83,8 +83,7 @@ class ZeroMqQueue:
"Server and client should not receive HMAC key when encryption is disabled"
)
- if (socket_type == zmq.PAIR and self.is_server
- ) or socket_type == zmq.PULL or socket_type == zmq.ROUTER:
+ if self.should_bind_socket():
self.socket.bind(
self.address_endpoint
) # Binds to the address and occupy a port immediately
@@ -101,6 +100,31 @@ class ZeroMqQueue:
self.address = (self.address_endpoint, self.hmac_key)
+ def should_bind_socket(self) -> bool:
+ """
+ Determine if socket should bind vs connect based on type and role.
+
+ ZMQ binding conventions:
+ - PAIR: server binds, client connects (1-to-1 bidirectional)
+ - PULL: server binds to receive from multiple PUSH sockets
+ - PUSH: server binds when acting as message source
+ - ROUTER: always binds to handle multiple clients
+
+ Returns:
+ True if socket should bind, False if it should connect
+ """
+ # Server binds for PAIR, PULL, PUSH patterns
+ if self.is_server and self.socket_type in (zmq.PAIR, zmq.PULL,
+ zmq.PUSH):
+ return True
+
+ # ROUTER always binds (multi-client pattern)
+ if self.socket_type == zmq.ROUTER:
+ return True
+
+ # Client connects for all other cases
+ return False
+
def setup_lazily(self):
# Early return if setup is already done
if self._setup_done:
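
The bind-versus-connect split encoded by `should_bind_socket()` matches how the diffusion client wires its queues later in this change: the client binds a PUSH socket for requests and a PULL socket for responses, and each worker connects the mirrored pair. A minimal pyzmq sketch of that pairing (endpoints and payload are placeholders):

```python
# Minimal pyzmq sketch of the PUSH(bind)/PULL(connect) pairing that
# should_bind_socket() encodes; endpoints and payload are placeholders.
import zmq

ctx = zmq.Context.instance()

# "Server" side (is_server=True): bind PUSH for requests, PULL for responses.
requests_out = ctx.socket(zmq.PUSH)
requests_out.bind("tcp://0.0.0.0:5555")
responses_in = ctx.socket(zmq.PULL)
responses_in.bind("tcp://0.0.0.0:5556")

# Worker side (is_server=False): connect the mirrored sockets.
requests_in = ctx.socket(zmq.PULL)
requests_in.connect("tcp://127.0.0.1:5555")
responses_out = ctx.socket(zmq.PUSH)
responses_out.connect("tcp://127.0.0.1:5556")

requests_out.send_pyobj({"request_id": 0})
print(requests_in.recv_pyobj())  # {'request_id': 0}
```
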
diff --git a/tensorrt_llm/inputs/data.py b/tensorrt_llm/inputs/data.py
index 615043fe48..48e3441df6 100644
--- a/tensorrt_llm/inputs/data.py
+++ b/tensorrt_llm/inputs/data.py
@@ -1,6 +1,6 @@
# Adapt from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Sequence, Union
from typing_extensions import NotRequired, TypedDict
@@ -85,3 +85,80 @@ def prompt_inputs(inputs: PromptInputs, ) -> Union[TextPrompt, TokensPrompt]:
f"Invalid type of inputs for llm.generate: {type(inputs)}")
return prompt_inputs
+
+
+class VisualGenTextPrompt(TypedDict):
+ prompt: str
+ negative_prompt: NotRequired[str]
+
+
+class VisualGenTokensPrompt(TypedDict):
+ prompt_token_ids: List[int]
+ negative_prompt_token_ids: NotRequired[List[int]]
+
+
+VisualGenPromptInputs = Union[
+ str,
+ List[int],
+ VisualGenTextPrompt,
+ VisualGenTokensPrompt,
+]
+
+VisualGenInputs = Union[
+ VisualGenPromptInputs,
+ Sequence[VisualGenPromptInputs],
+]
+
+
+def visual_gen_inputs(
+ inputs: "VisualGenPromptInputs",
+) -> Union["VisualGenTextPrompt", "VisualGenTokensPrompt"]:
+ # str -> text prompt
+ if isinstance(inputs, str):
+ return VisualGenTextPrompt(prompt=inputs)
+
+ # list[int] -> token prompt
+ if isinstance(inputs, list):
+ if len(inputs) == 0:
+ raise ValueError("`inputs` token list cannot be empty.")
+ if not all(isinstance(t, int) for t in inputs):
+ raise TypeError(
+ "`inputs` list must contain only ints when used as token IDs.")
+ return VisualGenTokensPrompt(prompt_token_ids=inputs)
+
+ # dict form
+ if isinstance(inputs, dict):
+ has_prompt = "prompt" in inputs
+ has_prompt_token_ids = "prompt_token_ids" in inputs
+
+ if has_prompt == has_prompt_token_ids:
+ raise ValueError(
+ "VisualGen prompt dict must contain exactly one of "
+ "`prompt` or `prompt_token_ids`.")
+
+ if has_prompt:
+ prompt = inputs.get("prompt")
+ if not isinstance(prompt, str) or prompt == "":
+ raise TypeError("`prompt` must be a non-empty string.")
+ if "negative_prompt" in inputs and not isinstance(
+ inputs["negative_prompt"], str):
+ raise TypeError("`negative_prompt` must be a string.")
+ return inputs # VisualGenTextPrompt
+
+ token_ids = inputs.get("prompt_token_ids")
+ if not isinstance(token_ids, list) or len(token_ids) == 0:
+ raise TypeError("`prompt_token_ids` must be a non-empty list[int].")
+ if not all(isinstance(t, int) for t in token_ids):
+ raise TypeError("`prompt_token_ids` must contain only ints.")
+ if "negative_prompt_token_ids" in inputs:
+ neg_ids = inputs["negative_prompt_token_ids"]
+ if not isinstance(neg_ids, list) or not all(
+ isinstance(t, int) for t in neg_ids):
+ raise TypeError(
+ "`negative_prompt_token_ids` must be a list[int].")
+ return inputs # VisualGenTokensPrompt
+
+ raise TypeError(
+ "Invalid `inputs` for VisualGen.generate. "
+ "Expected one of: str, list[int], VisualGenTextPrompt, VisualGenTokensPrompt."
+ )
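
The normalization rules above in practice (illustrative calls; invalid combinations raise):

```python
# Illustrative calls to visual_gen_inputs(); behavior follows the checks above.
from tensorrt_llm.inputs.data import visual_gen_inputs

assert visual_gen_inputs("a cute cat playing piano") == {
    "prompt": "a cute cat playing piano"
}
assert visual_gen_inputs([101, 2023, 102]) == {"prompt_token_ids": [101, 2023, 102]}

# Dict inputs are validated and passed through unchanged.
visual_gen_inputs({"prompt": "a cute cat", "negative_prompt": "blurry"})

try:
    visual_gen_inputs({"prompt": "a cat", "prompt_token_ids": [1, 2]})
except ValueError:
    pass  # exactly one of `prompt` / `prompt_token_ids` is allowed
```
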
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index b87b21f9f5..430426786f 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -22,10 +22,13 @@ from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mm_encoder import MultimodalEncoder
from .mpi_session import MpiCommSession
+from .visual_gen import VisualGen, VisualGenParams
__all__ = [
'LLM',
'AsyncLLM',
+ 'VisualGen',
+ 'VisualGenParams',
'MultimodalEncoder',
'CompletionOutput',
'RequestOutput',
diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py
index 8512771e1f..86edbefea7 100644
--- a/tensorrt_llm/llmapi/disagg_utils.py
+++ b/tensorrt_llm/llmapi/disagg_utils.py
@@ -24,6 +24,7 @@ class ServerRole(IntEnum):
CONTEXT = 0
GENERATION = 1
MM_ENCODER = 2
+ VISUAL_GEN = 3
@dataclass
diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py
index f79823d844..78fe3d6298 100644
--- a/tensorrt_llm/llmapi/utils.py
+++ b/tensorrt_llm/llmapi/utils.py
@@ -236,18 +236,34 @@ def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
return Path(hf_folder)
-def download_hf_pretrained_config(model: str,
- revision: Optional[str] = None) -> Path:
+def download_hf_partial(model: str,
+ allow_patterns: List[str],
+ revision: Optional[str] = None) -> Path:
+ """Download a partial model from HuggingFace.
+
+ Args:
+ model: The model name or path.
+ allow_patterns: Glob patterns of the files to download (e.g. ["config.json"]).
+ revision: The revision to use for the model.
+
+ Returns:
+ The path to the downloaded model.
+ """
with get_file_lock(model):
hf_folder = snapshot_download(
model,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
- allow_patterns=["config.json"],
+ allow_patterns=allow_patterns,
tqdm_class=DisabledTqdm)
return Path(hf_folder)
+def download_hf_pretrained_config(model: str,
+ revision: Optional[str] = None) -> Path:
+ return download_hf_partial(model, ["config.json"], revision)
+
+
def append_docstring(docstring: str):
''' A decorator to append a docstring to a function. '''
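
For example, the config-only download used by serve-time model detection can be expressed directly through the new helper (the Hub ID is illustrative, and offline behavior still follows `HF_HUB_OFFLINE`):

```python
# Illustrative config-only download: fetch just the pipeline/config JSON
# instead of the full weights. The Hub ID is a placeholder.
from tensorrt_llm.llmapi.utils import download_hf_partial

local_dir = download_hf_partial(
    model="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    allow_patterns=["model_index.json", "config.json"],
)
print(local_dir)  # snapshot directory containing only the matched files
```
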
diff --git a/tensorrt_llm/llmapi/visual_gen.py b/tensorrt_llm/llmapi/visual_gen.py
new file mode 100644
index 0000000000..2113b43548
--- /dev/null
+++ b/tensorrt_llm/llmapi/visual_gen.py
@@ -0,0 +1,544 @@
+import asyncio
+import queue
+import socket
+import threading
+import time
+import traceback
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import torch.multiprocessing as mp
+import zmq
+
+from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
+from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+
+__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
+from tensorrt_llm.executor.ipc import ZeroMqQueue
+from tensorrt_llm.inputs.data import VisualGenInputs
+from tensorrt_llm.logger import logger
+
+# Timeouts (seconds)
+POLL_TIMEOUT = 0.01
+AWAIT_TIMEOUT = 0.05
+THREAD_TIMEOUT = 5.0
+WORKER_TIMEOUT = 2.0
+READY_TIMEOUT = 1200 # 20 minutes for large models (Wan 2.2 with transformer_2)
+
+
+def find_free_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def get_ip_address() -> str:
+ """Get local IP address."""
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ s.connect(("10.255.255.255", 1))
+ return s.getsockname()[0]
+ except Exception:
+ return "127.0.0.1"
+ finally:
+ s.close()
+
+
+class DiffusionRemoteClient:
+ """Client proxy for remote DiffusionExecutor in worker processes."""
+
+ def __init__(
+ self,
+ model_path: Union[str, Path],
+ n_workers: int = 1,
+ diffusion_config: Optional[dict] = None,
+ ):
+ self.model_path = str(model_path)
+ self.n_workers = n_workers
+ self.diffusion_config = diffusion_config
+
+ # Setup distributed env
+ self.master_addr = "127.0.0.1"
+ self.master_port = find_free_port()
+
+ # Setup IPC addresses
+ self.host_ip = get_ip_address()
+ req_port, resp_port = find_free_port(), find_free_port()
+
+ self.request_queue_addr = f"tcp://0.0.0.0:{req_port}"
+ self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}"
+ self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
+ self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"
+
+ # IPC setup
+ self.requests_ipc = None
+ self.responses_ipc = None
+ self.pending_requests = queue.Queue()
+ self.completed_responses: Dict[int, DiffusionResponse] = {}
+
+ # We'll create asyncio primitives in the background thread's event loop
+ self._event_loop = None
+ self.response_event = None
+ self.lock = None
+ self.shutdown_event = threading.Event()
+ self.event_loop_ready = threading.Event()
+
+ # Start background thread (it will create its own event loop)
+ self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True)
+ self.background_thread.start()
+
+ # Wait for the background thread to initialize the event loop
+ self.event_loop_ready.wait()
+
+ # Launch workers
+ logger.info(f"DiffusionClient: Launching {n_workers} workers")
+ ctx = mp.get_context("spawn")
+ self.worker_processes = []
+ for rank in range(n_workers):
+ p = ctx.Process(
+ target=run_diffusion_worker,
+ kwargs={
+ "rank": rank,
+ "world_size": n_workers,
+ "master_addr": self.master_addr,
+ "master_port": self.master_port,
+ "model_path": self.model_path,
+ "request_queue_addr": self.req_addr_connect,
+ "response_queue_addr": self.resp_addr_connect,
+ "diffusion_config": self.diffusion_config,
+ },
+ )
+ p.start()
+ self.worker_processes.append(p)
+
+ self._wait_ready()
+
+ @staticmethod
+ def _close_socket(ipc_queue):
+ if ipc_queue and ipc_queue.socket:
+ ipc_queue.socket.setsockopt(zmq.LINGER, 0)
+ ipc_queue.close()
+
+ def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]:
+ """Enqueue requests and return their IDs."""
+ req_ids = []
+ for req in requests:
+ self.pending_requests.put(req)
+ req_ids.append(req.request_id)
+ return req_ids
+
+ async def await_responses(
+ self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
+ ) -> Union[DiffusionResponse, List[DiffusionResponse]]:
+ """Wait for responses by request IDs.
+
+ Args:
+ request_ids: Single request ID or list of request IDs to wait for
+ timeout: Maximum total wait time in seconds (None = wait indefinitely)
+
+ Returns:
+ Single response or list of responses (None if request timed out)
+ """
+ is_single = isinstance(request_ids, int)
+ ids = [request_ids] if is_single else request_ids
+
+ start_time = time.time()
+ results = {}
+
+ while len(results) < len(ids):
+ async with self.lock:
+ for req_id in ids:
+ if req_id in self.completed_responses:
+ results[req_id] = self.completed_responses.pop(req_id)
+
+ # All responses collected
+ if len(results) == len(ids):
+ break
+
+ # Check if overall timeout exceeded
+ if timeout is not None:
+ elapsed = time.time() - start_time
+ if elapsed >= timeout:
+ break
+ # Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter
+ wait_time = min(timeout - elapsed, AWAIT_TIMEOUT)
+ else:
+ wait_time = AWAIT_TIMEOUT
+
+ try:
+ await asyncio.wait_for(self.response_event.wait(), timeout=wait_time)
+ except asyncio.TimeoutError:
+ pass
+ self.response_event.clear()
+
+ out = [results.get(rid) for rid in ids]
+ return out[0] if is_single else out
+
+ def await_responses_sync(
+ self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
+ ) -> Union[DiffusionResponse, List[DiffusionResponse]]:
+ """Sync wrapper to await responses from the main thread."""
+ future = asyncio.run_coroutine_threadsafe(
+ self.await_responses(request_ids, timeout), self._event_loop
+ )
+ return future.result(timeout=timeout if timeout else None)
+
+ def _init_ipc(self) -> bool:
+ """Initialize IPC queues."""
+ try:
+ logger.info("DiffusionClient: Initializing IPC")
+ self.requests_ipc = ZeroMqQueue(
+ (self.request_queue_addr, None),
+ is_server=True,
+ socket_type=zmq.PUSH,
+ use_hmac_encryption=False,
+ )
+ self.responses_ipc = ZeroMqQueue(
+ (self.response_queue_addr, None),
+ is_server=True,
+ socket_type=zmq.PULL,
+ use_hmac_encryption=False,
+ )
+ logger.info("DiffusionClient: IPC ready")
+ return True
+ except Exception as e:
+ logger.error(f"DiffusionClient: IPC init failed: {e}")
+ return False
+
+ def _send_shutdown(self):
+ """Send shutdown signal."""
+ logger.info("DiffusionClient: Sending shutdown signal")
+ if self.requests_ipc:
+ self.requests_ipc.put(None)
+ self._close_socket(self.requests_ipc)
+
+ def _process_requests(self):
+ """Process pending requests."""
+ try:
+ req = self.pending_requests.get(timeout=POLL_TIMEOUT)
+ if req is None:
+ self._send_shutdown()
+ self.shutdown_event.set()
+ return
+
+ logger.info(f"DiffusionClient: Sending request {req.request_id}")
+ self.requests_ipc.put(req)
+ except queue.Empty:
+ pass
+ except Exception as e:
+ logger.error(f"DiffusionClient: Error sending request: {e}")
+ logger.error(traceback.format_exc())
+
+ def _process_responses(self):
+ """Poll and process responses."""
+ try:
+ if self.responses_ipc.poll(timeout=POLL_TIMEOUT):
+ response = self.responses_ipc.get()
+ if isinstance(response, DiffusionResponse):
+ if response.request_id == -1:
+ logger.info("DiffusionClient: Received READY signal")
+
+ # Schedule the lock acquisition and event setting in the event loop
+ asyncio.run_coroutine_threadsafe(
+ self._store_response(response), self._event_loop
+ )
+ except Exception as e:
+ logger.error(f"DiffusionClient: Error processing response: {e}")
+
+ async def _store_response(self, response: DiffusionResponse):
+ """Store response in the completed_responses dict (async helper)."""
+ async with self.lock:
+ self.completed_responses[response.request_id] = response
+ self.response_event.set()
+
+ def _cleanup_ipc(self):
+ """Cleanup IPC."""
+ logger.info("DiffusionClient: Cleaning up IPC")
+ self._close_socket(self.requests_ipc)
+ self._close_socket(self.responses_ipc)
+
+ def _serve_forever_thread(self):
+ """Background thread wrapper that creates and runs an event loop."""
+ logger.info("DiffusionClient: Background thread started")
+
+ # Create a new event loop for this thread
+ self._event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self._event_loop)
+
+ # Create async primitives in this thread's event loop
+ self.response_event = asyncio.Event()
+ self.lock = asyncio.Lock()
+
+ # Signal that the event loop is ready
+ self.event_loop_ready.set()
+
+ # Run the async serve_forever
+ try:
+ self._event_loop.run_until_complete(self._serve_forever())
+ finally:
+ self._event_loop.close()
+ logger.info("DiffusionClient: Background thread stopped")
+
+ async def _serve_forever(self):
+ """Background thread main loop (async version)."""
+ if not self._init_ipc():
+ return
+
+ while not self.shutdown_event.is_set():
+ self._process_requests()
+ self._process_responses()
+ await asyncio.sleep(0.001) # Yield control to allow other coroutines to run
+
+ self._cleanup_ipc()
+
+ def shutdown(self):
+ """Shutdown client and workers."""
+ logger.info("DiffusionClient: Shutting down")
+ self.pending_requests.put(None)
+
+ self.background_thread.join(timeout=THREAD_TIMEOUT)
+ if self.background_thread.is_alive():
+ logger.warning("DiffusionClient: Force stopping background thread")
+ self.shutdown_event.set()
+ self.background_thread.join(timeout=1.0)
+
+ # Shutdown workers
+ logger.info("DiffusionClient: Stopping workers")
+ for p in self.worker_processes:
+ p.join(timeout=WORKER_TIMEOUT)
+ if p.is_alive():
+ logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM")
+ p.terminate()
+ p.join(timeout=WORKER_TIMEOUT)
+ if p.is_alive():
+ logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL")
+ p.kill()
+ p.join(timeout=WORKER_TIMEOUT)
+
+ def _wait_ready(self, timeout: float = READY_TIMEOUT):
+ """Wait for workers to be ready (sync wrapper for async operation)."""
+ logger.info("DiffusionClient: Waiting for workers")
+
+ # Run the async wait in the background thread's event loop
+ future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop)
+ return future.result(timeout=timeout)
+
+ async def _wait_ready_async(self, timeout: float = READY_TIMEOUT):
+ """Wait for workers to be ready (async version)."""
+ start_time = time.time()
+
+ while True:
+ async with self.lock:
+ if -1 in self.completed_responses:
+ self.completed_responses.pop(-1)
+ logger.info("DiffusionClient: Workers ready")
+ return
+
+ if time.time() - start_time > timeout:
+ raise RuntimeError("DiffusionClient: Timeout waiting for workers")
+
+ try:
+ await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT)
+ except asyncio.TimeoutError:
+ pass
+ self.response_event.clear()
+
+
+class DiffusionGenerationResult:
+ """Future-like object for async generation."""
+
+ def __init__(self, request_id: int, executor: DiffusionRemoteClient):
+ self.request_id = request_id
+ self.executor = executor
+ self._result = None
+ self._finished = False
+ self._error = None
+
+ async def result(self, timeout: Optional[float] = None) -> Any:
+ """Wait for and return result (async version).
+
+ Can be awaited from any async context (e.g., FastAPI background tasks).
+ """
+ if self._finished:
+ if self._error:
+ raise RuntimeError(self._error)
+ return self._result
+
+ # Use run_coroutine_threadsafe to execute in the background thread's event loop
+ future = asyncio.run_coroutine_threadsafe(
+ self.executor.await_responses(self.request_id, timeout=timeout),
+ self.executor._event_loop,
+ )
+
+ # Await the future in the current event loop
+ response = await asyncio.wrap_future(future)
+
+ if response.error_msg:
+ self._error = response.error_msg
+ self._finished = True
+ raise RuntimeError(f"Generation failed: {response.error_msg}")
+
+ self._result = response.output
+ self._finished = True
+ return self._result
+
+ def cancel(self):
+ raise NotImplementedError("Request cancellation is not yet implemented.")
+
+
+@dataclass
+class VisualGenParams:
+ """Parameters for visual generation.
+
+ Attributes:
+ height: Output height in pixels
+ width: Output width in pixels
+ num_inference_steps: Number of denoising steps
+ guidance_scale: Classifier-free guidance scale
+ max_sequence_length: Maximum sequence length for text encoding
+ seed: Random seed for reproducibility
+
+ # Video-specific parameters
+ num_frames: Number of video frames to generate
+ frame_rate: Frame rate for video output in fps
+
+ # Image-specific parameters
+ num_images_per_prompt: Number of images to generate per prompt (for image models)
+
+ # Advanced parameters
+ guidance_rescale: Guidance rescale factor (for some models)
+ output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images)
+ """
+
+ height: int = 720
+ width: int = 1280
+ num_inference_steps: int = 50
+ guidance_scale: float = 5.0
+ max_sequence_length: int = 512
+ seed: int = 42
+
+ # Video-specific parameters
+ num_frames: int = 81
+ frame_rate: float = 24.0
+ input_reference: Optional[str] = None
+
+ # Image-specific parameters
+ num_images_per_prompt: int = 1
+
+ ## Image edit parameters
+ image: Optional[List[str]] = None
+ mask: Optional[str] = None
+
+ # Advanced parameters
+ guidance_rescale: float = 0.0
+ output_type: str = "pt"
+
+ # Wan-specific parameters
+ guidance_scale_2: Optional[float] = None
+ boundary_ratio: Optional[float] = None
+ last_image: Optional[str] = None
+
+
+class VisualGen:
+ """High-level API for visual generation."""
+
+ def __init__(
+ self,
+ model_path: Union[str, Path],
+ n_workers: int = 1,
+ diffusion_config: Optional[dict] = None,
+ ):
+ self.model_path = str(model_path)
+ self.n_workers = n_workers
+ self.diffusion_config = diffusion_config
+
+ self.executor = DiffusionRemoteClient(
+ model_path=self.model_path,
+ n_workers=self.n_workers,
+ diffusion_config=self.diffusion_config,
+ )
+ self.req_counter = 0
+
+ def generate(
+ self,
+ inputs: VisualGenInputs,
+ params: VisualGenParams,
+ ) -> MediaOutput:
+ """Synchronous generation. Blocks until complete.
+
+ Args:
+ inputs: Prompt input (a string or a VisualGen prompt dict).
+ params: Generation parameters.
+
+ Returns:
+ MediaOutput: Generated media with model-specific fields populated:
+ - FLUX2: MediaOutput(image=torch.Tensor)
+ - WAN: MediaOutput(video=torch.Tensor)
+ - LTX2: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
+ """
+ future = self.generate_async(
+ inputs=inputs,
+ params=params,
+ )
+
+ # Use the sync wrapper to get result
+ response = self.executor.await_responses_sync(future.request_id, timeout=None)
+ if response.error_msg:
+ raise RuntimeError(f"Generation failed: {response.error_msg}")
+ return response.output
+
+ def generate_async(
+ self,
+ inputs: VisualGenInputs,
+ params: VisualGenParams,
+ ) -> DiffusionGenerationResult:
+ """Async generation. Returns immediately with future-like object.
+
+ Args:
+ inputs: Prompt input (a string or a VisualGen prompt dict).
+ params: Generation parameters.
+
+ Returns:
+ DiffusionGenerationResult: Await result() to obtain the MediaOutput.
+ """
+ req_id = self.req_counter
+ self.req_counter += 1
+
+ if isinstance(inputs, dict):
+ prompt = inputs.get("prompt")
+ negative_prompt = inputs.get("negative_prompt", None)
+ elif isinstance(inputs, str):
+ prompt = inputs
+ negative_prompt = None
+ else:
+ # TODO: Support batch generation
+ raise ValueError(f"Invalid inputs type: {type(inputs)}")
+
+ request = DiffusionRequest(
+ request_id=req_id,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=params.height,
+ width=params.width,
+ num_inference_steps=params.num_inference_steps,
+ guidance_scale=params.guidance_scale,
+ max_sequence_length=params.max_sequence_length,
+ seed=params.seed,
+ num_frames=params.num_frames,
+ frame_rate=params.frame_rate,
+ num_images_per_prompt=params.num_images_per_prompt,
+ guidance_rescale=params.guidance_rescale,
+ output_type=params.output_type,
+ image=params.input_reference,
+ guidance_scale_2=params.guidance_scale_2,
+ boundary_ratio=params.boundary_ratio,
+ last_image=params.last_image,
+ )
+
+ self.executor.enqueue_requests([request])
+ return DiffusionGenerationResult(req_id, self.executor)
+
+ def shutdown(self):
+ """Shutdown executor and cleanup."""
+ logger.info("VisualGen: Shutting down")
+ self.executor.shutdown()
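
Putting the new `llmapi` surface together, a hedged end-to-end sketch (model path, worker layout, and parameter values are illustrative; the `MediaOutput` field name follows the `generate()` docstring above, and `MediaStorage` is added later in this change):

```python
# Hedged end-to-end sketch of the VisualGen API; paths and values are illustrative.
from tensorrt_llm.llmapi import VisualGen, VisualGenParams
from tensorrt_llm.serve.media_storage import MediaStorage

gen = VisualGen(
    model_path="/models/Wan2.1-T2V-1.3B-Diffusers",
    n_workers=2,  # e.g. dit_cfg_size=2, dit_ulysses_size=1
    diffusion_config={"parallel": {"dit_cfg_size": 2, "dit_ulysses_size": 1}},
)
try:
    out = gen.generate(
        inputs={"prompt": "A cute cat playing piano", "negative_prompt": "blurry"},
        params=VisualGenParams(height=480, width=832, num_frames=33,
                               num_inference_steps=30, seed=42),
    )
    # For WAN, MediaOutput.video is expected as a (T, H, W, C) uint8 tensor.
    MediaStorage.save_video(out.video, "output.mp4", frame_rate=24.0)
finally:
    gen.shutdown()
```
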
diff --git a/tensorrt_llm/ray_stub.py b/tensorrt_llm/ray_stub.py
index 9bd699d929..34d3b4e97c 100644
--- a/tensorrt_llm/ray_stub.py
+++ b/tensorrt_llm/ray_stub.py
@@ -16,10 +16,8 @@ from functools import wraps as _wraps
from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
-if _mpi_disabled():
- raise RuntimeError(
- "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
- )
+# Don't raise error on import - only when Ray functionality is actually used
+_RAY_NOT_INSTALLED_MSG = "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
def remote(*args, **kwargs):
@@ -42,6 +40,7 @@ def remote(*args, **kwargs):
def __getattr__(name):
- raise RuntimeError(
- f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
- )
+ msg = f'Ray not installed, so "ray.{name}" is unavailable.'
+ if _mpi_disabled():
+ msg = _RAY_NOT_INSTALLED_MSG
+ raise RuntimeError(msg)
diff --git a/tensorrt_llm/serve/media_storage.py b/tensorrt_llm/serve/media_storage.py
new file mode 100644
index 0000000000..acbdedae65
--- /dev/null
+++ b/tensorrt_llm/serve/media_storage.py
@@ -0,0 +1,426 @@
+#!/usr/bin/env python
+"""Media Storage for generated images and videos.
+
+This module provides storage handlers for persisting generated media assets
+(videos, images) and their associated metadata.
+"""
+
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Optional
+
+import torch
+from PIL import Image
+
+from tensorrt_llm.logger import logger
+
+
+class MediaStorage:
+ """Handler for storing images and videos in various formats."""
+
+ @staticmethod
+ def save_image(
+ image: Any, output_path: str, format: Optional[str] = None, quality: int = 95
+ ) -> str:
+ """Save image to file.
+
+ Args:
+ image: torch.Tensor (H, W, C) uint8
+ output_path: Path to save the image
+ format: Image format (png, jpg, webp). If None, infer from extension
+ quality: Quality for lossy formats (1-100, higher is better)
+
+ Returns:
+ Path where the image was saved
+ """
+ # Ensure output directory exists
+ output_dir = os.path.dirname(output_path)
+ if output_dir:
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Convert to PIL Image if needed
+ pil_image = MediaStorage._to_pil_image(image)
+
+ # Determine format
+ if format is None:
+ ext = os.path.splitext(output_path)[1].lower()
+ if ext in [".png"]:
+ format = "PNG"
+ elif ext in [".jpg", ".jpeg"]:
+ format = "JPEG"
+ elif ext in [".webp"]:
+ format = "WEBP"
+ else:
+ logger.warning(f"Unknown image extension {ext}, defaulting to PNG")
+ format = "PNG"
+ output_path = output_path.rsplit(".", 1)[0] + ".png"
+
+ # Save image with format-specific handling
+ MediaStorage._save_pil_image(pil_image, output_path, format, quality)
+
+ logger.info(f"Saved image to {output_path} (format={format})")
+ return output_path
+
+ @staticmethod
+ def convert_image_to_bytes(image: Any, format: str = "PNG", quality: int = 95) -> bytes:
+ """Convert image to bytes buffer.
+
+ Args:
+ image: torch.Tensor (H, W, C) uint8
+ format: Image format (PNG, JPEG, WEBP)
+ quality: Quality for lossy formats (1-100)
+
+ Returns:
+ Image bytes
+ """
+ pil_image = MediaStorage._to_pil_image(image)
+
+ # Save to bytes buffer
+ buffer = BytesIO()
+ MediaStorage._save_pil_image(pil_image, buffer, format, quality)
+
+ return buffer.getvalue()
+
+ @staticmethod
+ def _to_pil_image(image: torch.Tensor) -> Image.Image:
+ """Convert torch.Tensor to PIL Image.
+
+ Args:
+ image: torch.Tensor (H, W, C) uint8
+
+ Returns:
+ PIL Image
+ """
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"Expected torch.Tensor, got {type(image)}")
+
+ # Convert to numpy for PIL
+ image_np = image.cpu().numpy()
+ return Image.fromarray(image_np)
+
+ @staticmethod
+ def _save_pil_image(
+ pil_image: Image.Image,
+ output: Any, # Can be path string or BytesIO
+ format: str,
+ quality: int,
+ ):
+ """Save PIL Image to file or buffer.
+
+ Args:
+ pil_image: PIL Image to save
+ output: Output path (str) or BytesIO buffer
+ format: Image format (PNG, JPEG, WEBP)
+ quality: Quality for lossy formats (1-100)
+ """
+ format_upper = format.upper()
+
+ if format_upper in ["JPEG", "JPG"]:
+ # Convert RGBA to RGB for JPEG
+ if pil_image.mode in ("RGBA", "LA", "P"):
+ background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ if pil_image.mode == "P":
+ pil_image = pil_image.convert("RGBA")
+ background.paste(
+ pil_image, mask=pil_image.split()[-1] if pil_image.mode == "RGBA" else None
+ )
+ pil_image = background
+ pil_image.save(output, format="JPEG", quality=quality, optimize=True)
+ elif format_upper == "WEBP":
+ pil_image.save(output, format="WEBP", quality=quality)
+ else: # PNG or default
+ pil_image.save(output, format="PNG", optimize=True)
+
+ @staticmethod
+ def save_video(
+ video: Any,
+ output_path: str,
+ audio: Optional[Any] = None,
+ frame_rate: float = 24.0,
+ format: Optional[str] = None,
+ ) -> str:
+ """Save video to file with optional audio.
+
+ Args:
+ video: Video frames as torch.Tensor (T, H, W, C) uint8
+ output_path: Path to save the video
+ audio: Optional audio as torch.Tensor
+ frame_rate: Frames per second (default: 24.0)
+ format: Video format (mp4, gif, png). If None, infer from extension
+
+ Returns:
+ Path where the video was saved
+ """
+ # Ensure output directory exists
+ if isinstance(output_path, Path):
+ output_path = str(output_path)
+
+ output_dir = os.path.dirname(output_path)
+ if output_dir:
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Determine format
+ if format is None:
+ ext = os.path.splitext(output_path)[1].lower()
+ format = ext[1:] if ext else "mp4"
+
+ format = format.lower()
+
+ # Save based on format
+ if format == "mp4":
+ MediaStorage._save_mp4(video, audio, output_path, frame_rate)
+ elif format == "gif":
+ MediaStorage._save_gif(video, output_path, frame_rate)
+ elif format == "png":
+ MediaStorage._save_middle_frame(video, output_path)
+ else:
+ logger.warning(f"Unsupported video format: {format}, defaulting to mp4")
+ output_path = output_path.rsplit(".", 1)[0] + ".mp4"
+ MediaStorage._save_mp4(video, audio, output_path, frame_rate)
+
+ return output_path
+
+ @staticmethod
+ def convert_video_to_bytes(
+ video: Any, audio: Optional[Any] = None, frame_rate: float = 24.0, format: str = "mp4"
+ ) -> bytes:
+ """Convert video to bytes buffer.
+
+ Args:
+ video: Video frames as torch.Tensor (T, H, W, C) uint8
+ audio: Optional audio as torch.Tensor
+ frame_rate: Frames per second
+ format: Video format (mp4, gif)
+
+ Returns:
+ Video bytes
+ """
+ import tempfile
+
+ # Create temporary file
+ with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp_file:
+ tmp_path = tmp_file.name
+
+ try:
+ # Save to temporary file
+ MediaStorage.save_video(video, tmp_path, audio, frame_rate, format)
+
+ # Read bytes
+ with open(tmp_path, "rb") as f:
+ video_bytes = f.read()
+
+ return video_bytes
+ finally:
+ # Clean up temporary file
+ if os.path.exists(tmp_path):
+ os.unlink(tmp_path)
+
+ @staticmethod
+ def _save_mp4(
+ video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
+ ) -> str:
+ """Save video with optional audio as MP4.
+
+ Args:
+ video: Video frames as torch.Tensor (T, H, W, C) uint8
+ audio: Optional audio as torch.Tensor
+ output_path: Output path for MP4
+ frame_rate: Frames per second
+
+ Returns:
+ Path where the video was saved
+ """
+ try:
+ from fractions import Fraction
+
+ import av
+
+ if not isinstance(video, torch.Tensor):
+ raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
+
+ # Convert video tensor to numpy: (T, H, W, C) uint8
+ video_np = video.cpu().numpy()
+ num_frames, height, width, channels = video_np.shape
+
+ # Ensure RGB format (3 channels)
+ if channels != 3:
+ raise ValueError(f"Expected 3-channel RGB video, got {channels} channels")
+
+ # Open output container
+ container = av.open(output_path, mode="w")
+
+ # Add video stream (H.264 codec)
+ video_stream = container.add_stream("libx264", rate=int(frame_rate))
+ video_stream.width = width
+ video_stream.height = height
+ video_stream.pix_fmt = "yuv420p"
+ video_stream.options = {"preset": "medium", "crf": "23"}
+
+ # Pre-process audio and add audio stream BEFORE any muxing.
+ # All streams must be registered before the first mux() call
+ # (which triggers container header writing).
+ audio_stream = None
+ audio_tensor = None
+ audio_sample_rate = 24000 # Default sample rate
+ if audio is not None:
+ if not isinstance(audio, torch.Tensor):
+ raise ValueError(f"Expected torch.Tensor for audio, got {type(audio)}")
+
+ # Prepare audio tensor: convert to (samples, channels) format
+ audio_tensor = audio
+
+ # Handle different audio tensor dimensions
+ if audio_tensor.ndim == 1:
+ # Mono audio: (samples,) -> (samples, 1)
+ audio_tensor = audio_tensor[:, None]
+ elif audio_tensor.ndim == 2:
+ # If shape[1] != 2 and shape[0] == 2, transpose to (samples, channels)
+ if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
+ audio_tensor = audio_tensor.T
+ if audio_tensor.shape[1] > 2:
+ audio_tensor = audio_tensor[:, :2]
+ elif audio_tensor.ndim == 3:
+ if audio_tensor.shape[0] == 1:
+ audio_tensor = audio_tensor.squeeze(0)
+ else:
+ audio_tensor = audio_tensor[0]
+ if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
+ audio_tensor = audio_tensor.T
+ if audio_tensor.shape[1] > 2:
+ audio_tensor = audio_tensor[:, :2]
+ else:
+ raise ValueError(
+ f"Unsupported audio tensor shape: {audio_tensor.shape}. "
+ f"Expected 1D, 2D, or 3D tensor."
+ )
+
+ if audio_tensor.shape[1] > 2:
+ audio_tensor = audio_tensor[:, :2]
+
+ # Convert to int16 if needed
+ if audio_tensor.dtype != torch.int16:
+ audio_tensor = torch.clip(audio_tensor, -1.0, 1.0)
+ audio_tensor = (audio_tensor * 32767.0).to(torch.int16)
+
+ # Add audio stream now (before any muxing)
+ audio_stream = container.add_stream("aac", rate=audio_sample_rate)
+ audio_stream.codec_context.sample_rate = audio_sample_rate
+ audio_stream.codec_context.layout = "stereo"
+ audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
+
+ # --- Encode video frames ---
+ for frame_array in video_np:
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+ for packet in video_stream.encode(frame):
+ container.mux(packet)
+
+ # Flush video encoder
+ for packet in video_stream.encode():
+ container.mux(packet)
+
+ # --- Encode audio (after video is done) ---
+ if audio_stream is not None and audio_tensor is not None:
+ # Build packed int16 frame: (1, samples*channels)
+ audio_np = audio_tensor.contiguous().reshape(1, -1).cpu().numpy()
+
+ frame_in = av.AudioFrame.from_ndarray(audio_np, format="s16", layout="stereo")
+ frame_in.sample_rate = audio_sample_rate
+
+ # Use AudioResampler to convert s16 -> fltp (AAC's native format)
+ cc = audio_stream.codec_context
+ audio_resampler = av.audio.resampler.AudioResampler(
+ format=cc.format or "fltp",
+ layout=cc.layout or "stereo",
+ rate=cc.sample_rate or audio_sample_rate,
+ )
+
+ audio_next_pts = 0
+ for rframe in audio_resampler.resample(frame_in):
+ if rframe.pts is None:
+ rframe.pts = audio_next_pts
+ audio_next_pts += rframe.samples
+ rframe.sample_rate = audio_sample_rate
+ container.mux(audio_stream.encode(rframe))
+
+ # Flush audio encoder
+ for packet in audio_stream.encode():
+ container.mux(packet)
+
+ # Close container
+ container.close()
+
+ logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
+ return output_path
+
+ except ImportError:
+ logger.warning(
+ "PyAV (av) library not available. "
+ "Falling back to saving middle frame as PNG. "
+ "Install with: pip install av"
+ )
+ png_path = output_path.replace(".mp4", ".png")
+ return MediaStorage._save_middle_frame(video, png_path)
+ except Exception as e:
+ logger.error(f"Error encoding video with PyAV: {e}")
+ import traceback
+
+ logger.error(traceback.format_exc())
+ logger.warning("Falling back to saving middle frame as PNG.")
+ png_path = output_path.replace(".mp4", ".png")
+ return MediaStorage._save_middle_frame(video, png_path)
+
+ @staticmethod
+ def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float) -> str:
+ """Save video as animated GIF.
+
+ Args:
+ video: Video frames as torch.Tensor (T, H, W, C) uint8
+ output_path: Output path for GIF
+ frame_rate: Frames per second
+
+ Returns:
+ Path where the GIF was saved
+ """
+ if not isinstance(video, torch.Tensor):
+ raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
+
+ # Convert to numpy and then to list of PIL Images
+ video_np = video.cpu().numpy()
+ frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
+
+ # Save as GIF
+ duration_ms = int(1000 / frame_rate)
+ frames[0].save(
+ output_path,
+ save_all=True,
+ append_images=frames[1:],
+ optimize=False,
+ duration=duration_ms,
+ loop=0,
+ )
+ logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
+ return output_path
+
+ @staticmethod
+ def _save_middle_frame(video: torch.Tensor, output_path: str) -> str:
+ """Save middle frame of video as PNG.
+
+ Args:
+ video: Video frames as torch.Tensor (T, H, W, C) uint8
+ output_path: Output path for PNG
+
+ Returns:
+ Path where the frame was saved
+ """
+ if not isinstance(video, torch.Tensor):
+ raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
+
+ # Extract middle frame
+ video_np = video.cpu().numpy()
+ frame_idx = video_np.shape[0] // 2
+ image = Image.fromarray(video_np[frame_idx])
+
+ image.save(output_path)
+ logger.info(f"Saved frame {frame_idx} to {output_path}")
+ return output_path
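
A short illustration of the image path, which the OpenAI-style endpoints below build on (dummy tensor; the format is inferred from the extension as handled above):

```python
# Illustrative image save/encode round-trip with a dummy (H, W, C) uint8 tensor.
import base64

import torch

from tensorrt_llm.serve.media_storage import MediaStorage

image = torch.randint(0, 256, (256, 256, 3), dtype=torch.uint8)

# Format inferred from the extension (JPEG here); the output directory is created.
MediaStorage.save_image(image, "out/sample.jpg", quality=90)

# Bytes form, e.g. for response_format="b64_json" in the endpoints below.
png_bytes = MediaStorage.convert_image_to_bytes(image, format="PNG")
b64_json = base64.b64encode(png_bytes).decode("ascii")
```
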
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 3afcb989d4..21f411cc33 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -1,12 +1,14 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/4db5176d9758b720b05460c50ace3c01026eb158/vllm/entrypoints/openai/protocol.py
import base64
+import re
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union
import torch
import xgrammar
+from fastapi import UploadFile
from openai.types.chat import ChatCompletionAssistantMessageParam
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
@@ -1120,5 +1122,218 @@ def to_llm_disaggregated_params(
)
+# ============================================================================
+# Diffusion API Protocol Classes
+# ============================================================================
+
+
+class ImageGenerationRequest(OpenAIBaseModel):
+ """OpenAI-compatible image generation request.
+
+ Follows the OpenAI Images API specification:
+ https://platform.openai.com/docs/api-reference/images/create
+ """
+ prompt: str
+ model: Optional[str] = None
+ n: int = Field(default=1, ge=1, le=10)
+ output_format: Literal["png", "webp", "jpeg"] = "png"
+ size: Optional[str] = Field(
+ default="auto",
+ description=(
+ "The size of the generated images. Must be in 'WxH' format like "
+ "1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
+ "Use 'auto' for model default size."))
+ quality: Literal["standard", "hd"] = "standard"
+ response_format: Literal["url", "b64_json"] = "url"
+ style: Optional[Literal["vivid", "natural"]] = "vivid"
+ user: Optional[str] = None
+
+ # Extended parameters for diffusion control
+ num_inference_steps: Optional[int] = Field(
+ default=None,
+ description=
+ "Number of denoising steps. More steps = higher quality but slower.")
+ guidance_scale: Optional[float] = Field(
+ default=None,
+ description=
+ "Classifier-free guidance scale. Higher values follow prompt more closely."
+ )
+ guidance_rescale: Optional[float] = Field(
+ default=None, description="Classifier-free guidance rescale.")
+ negative_prompt: Optional[str] = Field(
+ default=None,
+ description="Text describing what to avoid in the generated image.")
+ seed: Optional[int] = Field(default=None,
+ description="Random seed for reproducibility.")
+
+ @field_validator("size")
+ @classmethod
+ def validate_size(cls, v):
+ """Validate size format is 'WxH' or 'auto'."""
+ if v is None or v == "auto":
+ return v
+ if not isinstance(v, str):
+ raise ValueError("size must be a string in 'WxH' format or 'auto'")
+ # Check format: should be like "1024x1024"
+ if not re.match(r'^\d+x\d+$', v):
+ raise ValueError(
+ f"Invalid size format '{v}'. Must be in 'WxH' format "
+ "(e.g., '1024x1024', '1536x1024') or 'auto'.")
+ return v
+
+
+class ImageObject(OpenAIBaseModel):
+ """Generated image object in the response."""
+ b64_json: Optional[str] = None
+ url: Optional[str] = None
+ revised_prompt: Optional[str] = None
+
+
+class ImageGenerationResponse(OpenAIBaseModel):
+ """Response from image generation endpoint."""
+ created: int = Field(default_factory=lambda: int(time.time()))
+ data: List[ImageObject]
+ output_format: Literal["png", "webp", "jpeg"] = "png"
+ quality: Literal["low", "medium", "high"] = "medium"
+ size: Optional[str] = None
+
+
+class ImageEditRequest(OpenAIBaseModel):
+ """Request for image editing endpoint.
+
+ Follows the OpenAI Images API specification:
+ https://platform.openai.com/docs/api-reference/images/createEdit
+ """
+ image: Union[List[str], str] = Field(
+ description="Base64-encoded source image(s) to edit")
+ prompt: str = Field(description="Text description of desired edits")
+ model: Optional[str] = None
+ mask: Optional[str] = Field(
+ default=None,
+ description=
+ "Base64-encoded mask image (optional, black areas will be edited)")
+ n: int = Field(default=1, ge=1, le=10)
+ size: Optional[str] = Field(
+ default="auto",
+ description=(
+ "The size of the edited images. Must be in 'WxH' format like "
+ "1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
+ "Use 'auto' to match source image size."))
+ response_format: Literal["url", "b64_json"] = "url"
+ user: Optional[str] = None
+
+ # Extended parameters for diffusion control
+ num_inference_steps: Optional[int] = Field(
+ default=None, description="Number of denoising steps.")
+ guidance_scale: Optional[float] = Field(
+ default=None, description="Classifier-free guidance scale.")
+ guidance_rescale: Optional[float] = Field(
+ default=None, description="Classifier-free guidance rescale.")
+ negative_prompt: Optional[str] = Field(
+ default=None,
+ description="Text describing what to avoid in the edited image.")
+ seed: Optional[int] = Field(default=None,
+ description="Random seed for reproducibility.")
+
+ @field_validator("size")
+ @classmethod
+ def validate_size(cls, v):
+ """Validate size format is 'WxH' or 'auto'."""
+ if v != "auto" and not re.match(r"^\d+x\d+$", v):
+ raise ValueError(
+ "Size must be 'auto' or in 'WxH' format (e.g., '1024x1024')")
+ return v
+
+
+class VideoGenerationRequest(OpenAIBaseModel):
+ """Video generation request (extended API).
+
+ This is an extension to the OpenAI API for video generation support.
+ """
+ prompt: str
+ input_reference: Optional[Union[str, UploadFile]] = Field(
+ default=None,
+ description="Optional image reference that guides generation.")
+ model: Optional[str] = None
+ size: Optional[str] = Field(
+ default="auto",
+ description=
+ ("The size of the generated video frames. Must be in 'WxH' format like "
+ "512x512, 1024x576 (landscape), 576x1024 (portrait), etc. "
+ "Use 'auto' for model default size."))
+ seconds: float = Field(default=2.0,
+ ge=1.0,
+ le=16.0,
+ description="Video duration in seconds.")
+
+ # Extended parameters for diffusion control
+ n: int = Field(default=1, ge=1, le=4)
+ fps: int = Field(default=24, ge=8, le=60, description="Frames per second.")
+ num_inference_steps: Optional[int] = Field(
+ default=None, description="Number of denoising steps.")
+ guidance_scale: Optional[float] = Field(
+ default=None, description="Classifier-free guidance scale.")
+ guidance_rescale: Optional[float] = Field(
+ default=None, description="Classifier-free guidance rescale.")
+ negative_prompt: Optional[str] = Field(
+ default=None,
+ description="Text describing what to avoid in the generated video.")
+ seed: Optional[int] = Field(default=None,
+ description="Random seed for reproducibility.")
+
+ @field_validator("size")
+ @classmethod
+ def validate_size(cls, v):
+ """Validate size format is 'WxH' or 'auto'."""
+ if v is None or v == "auto":
+ return v
+ if not isinstance(v, str):
+ raise ValueError("size must be a string in 'WxH' format or 'auto'")
+ if not re.match(r'^\d+x\d+$', v):
+ raise ValueError(
+ f"Invalid size format '{v}'. Must be in 'WxH' format "
+ "(e.g., '512x512', '1024x576') or 'auto'.")
+ return v
+
+
+class VideoJob(OpenAIBaseModel):
+ """Metadata for an asynchronous video generation job.
+
+ Follows the OpenAI Videos API specification:
+ https://platform.openai.com/docs/api-reference/videos
+ """
+ completed_at: Optional[int] = Field(
+ default=None, description="Unix timestamp of completion")
+ created_at: int = Field(description="Unix timestamp of creation")
+ error: Optional[str] = Field(default=None,
+ description="Error message if failed")
+ expires_at: Optional[int] = Field(
+ default=None, description="Unix timestamp of expiration")
+ id: str = Field(description="Unique identifier for the video")
+ model: str = Field(description="The model used for generation")
+ object: str = Field(default="video", description="Object type")
+ progress: Optional[int] = Field(
+ default=None,
+ description="Progress of the video generation job (0-100)")
+ prompt: str = Field(description="The prompt used to generate the video")
+ status: Literal["queued", "in_progress", "completed", "failed"] = Field(
+ description="Current status of the video generation job")
+
+ # Video properties
+ duration: Optional[float] = Field(default=None,
+ description="Video duration in seconds")
+ fps: Optional[int] = Field(default=None, description="Frames per second")
+ size: Optional[str] = Field(default=None,
+ description="Video dimensions in 'WxH' format")
+
+
+class VideoJobList(OpenAIBaseModel):
+ """Response from listing video jobs endpoint."""
+ data: List[VideoJob] = Field(description="List of video jobs")
+ object: str = Field(default="list", description="Object type")
+
+
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
UCompletionResponse = Union[CompletionResponse, ChatCompletionResponse]
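
As a usage sketch for the request/response models above, the snippet below posts an `ImageGenerationRequest`-shaped payload to `/v1/images/generations` and decodes the `b64_json` result. It assumes a visual-gen server is already listening on `localhost:8000` (the route is registered in `openai_server.py` below); the host, port, and prompt are illustrative.

```python
# Illustrative client for the image generation schema; payload fields mirror
# ImageGenerationRequest, and the response shape mirrors ImageGenerationResponse.
import base64

import requests

payload = {
    "prompt": "A watercolor fox in a snowy forest",
    "size": "1024x1024",            # must match 'WxH' or be "auto"
    "response_format": "b64_json",  # "url" is not implemented server-side yet
    "num_inference_steps": 30,
    "guidance_scale": 5.0,
    "seed": 42,
}
resp = requests.post("http://localhost:8000/v1/images/generations", json=payload)
resp.raise_for_status()

image_b64 = resp.json()["data"][0]["b64_json"]
with open("fox.png", "wb") as f:
    f.write(base64.b64decode(image_b64))
```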
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 9cb9d59918..c19dd93149 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -1,10 +1,13 @@
#!/usr/bin/env python
import asyncio
+import base64
import os
import re
import signal
import socket
+import time
import traceback
+import uuid
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime
@@ -16,7 +19,8 @@ from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
import uvicorn
from fastapi import Body, FastAPI, Request
from fastapi.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import (FileResponse, JSONResponse, Response,
+ StreamingResponse)
from starlette.routing import Mount
from transformers import AutoProcessor
@@ -26,11 +30,12 @@ from tensorrt_llm._torch.async_llm import AsyncLLM
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.inputs import prompt_inputs
-from tensorrt_llm.inputs.data import TokensPrompt
+from tensorrt_llm.inputs.data import TokensPrompt, visual_gen_inputs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import MultimodalEncoder, tracing
+from tensorrt_llm.llmapi import (MultimodalEncoder, VisualGen, VisualGenParams,
+ tracing)
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole)
from tensorrt_llm.llmapi.llm import RequestOutput
@@ -40,6 +45,7 @@ from tensorrt_llm.serve.chat_utils import (load_chat_template,
parse_chat_messages_coroutines)
from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
+from tensorrt_llm.serve.media_storage import MediaStorage
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
@@ -47,12 +53,17 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatMessage, CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
- ErrorResponse,
+ ErrorResponse, ImageEditRequest,
+ ImageGenerationRequest,
+ ImageGenerationResponse,
+ ImageObject,
MemoryUpdateRequest, ModelCard,
ModelList, PromptTokensDetails,
ResponsesRequest,
ResponsesResponse,
UpdateWeightsRequest, UsageInfo,
+ VideoGenerationRequest,
+ VideoJob, VideoJobList,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
@@ -69,6 +80,8 @@ from tensorrt_llm.serve.responses_utils import \
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
from tensorrt_llm.serve.responses_utils import \
request_preprocess as responses_api_request_preprocess
+from tensorrt_llm.serve.visual_gen_utils import (VIDEO_STORE,
+ parse_visual_gen_params)
from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
@@ -82,7 +95,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds.
class OpenAIServer:
def __init__(self,
- llm: Union[LLM, MultimodalEncoder],
+ generator: Union[LLM, MultimodalEncoder, VisualGen],
model: str,
tool_parser: Optional[str],
server_role: Optional[ServerRole],
@@ -90,40 +103,17 @@ class OpenAIServer:
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
multimodal_server_config: Optional[MultimodalServerConfig] = None,
chat_template: Optional[str] = None):
- self.llm = llm
- self.tokenizer = llm.tokenizer
+ self.generator = generator
+ self._is_visual_gen = isinstance(generator, VisualGen)
self.tool_parser = tool_parser
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.disagg_cluster_config = disagg_cluster_config
self.multimodal_server_config = multimodal_server_config
- self.chat_template = load_chat_template(chat_template)
self.server_role = server_role
# Will be set in __call__
self.binding_addr = None
self.host = None
self.port = None
- hf_tokenizer_path = llm._hf_model_dir or self.tokenizer.tokenizer.name_or_path
- trust_remote_code = llm.args.trust_remote_code
- try:
- self.processor = AutoProcessor.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
- except Exception:
- logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
- self.processor = None
- # load model config
- try:
- from tensorrt_llm._torch.pyexecutor.config_utils import \
- load_pretrained_config
- self.model_config = load_pretrained_config(hf_tokenizer_path,
- trust_remote_code=trust_remote_code,
- checkpoint_format=getattr(self.llm.args, "checkpoint_format", None))
- except Exception:
- logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
- self.model_config = None
-
- # Enable response storage for Responses API
- self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled
-
- self.conversation_store = ConversationHistoryStore()
model_dir = Path(model)
if model_dir.exists() and model_dir.is_dir():
@@ -135,35 +125,19 @@ class OpenAIServer:
self.perf_metrics_lock = None
# The steady clock offset (in seconds) between this server and the disagg server
self.disagg_server_steady_clock_offset = 0
- if self.llm.args.return_perf_metrics:
- set_prometheus_multiproc_dir()
- self.metrics_collector = MetricsCollector({
- "model_name": "undefined",
- "engine_type": "undefined"
- })
- max_perf_metrics = self.llm.args.perf_metrics_max_requests
- if max_perf_metrics > 0:
- self.perf_metrics = deque(maxlen=max_perf_metrics)
- self.perf_metrics_lock = asyncio.Lock()
-
- # gpt-oss
- self.harmony_adapter: HarmonyAdapter | None = None
- disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
- if disable_harmony:
- self.use_harmony = False
- else:
- self.use_harmony = (self.model_config.model_type == "gpt_oss")
-
- self.tool_call_id_type = "random" # default tool call id type is random
- if self.model_config.model_type == "kimi_k2":
- self.tool_call_id_type = "kimi_k2"
- elif self.model_config.model_type == "deepseek_v32":
- self.tool_call_id_type = "deepseek_v32"
# as disagg-worker
self.disagg_cluster_storage = None
self.disagg_cluster_worker = None
+ # Skip loading AutoProcessor and model_config for VISUAL_GEN models
+ # These are LLM-specific and can cause unnecessary memory usage
+ if self._is_visual_gen:
+ self._init_visual_gen()
+ else:
+ self._init_llm(chat_template)
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
if self.metadata_server is not None:
@@ -176,8 +150,8 @@ class OpenAIServer:
}
# TODO: add more metadata
# Register with ETCD using the existing key format
- self.metadata_server.put(f"trtllm/{self.llm.llm_id}", metadata)
- logger.info(f"trtllm/{self.llm.llm_id} is registered")
+ self.metadata_server.put(f"trtllm/{self.generator.llm_id}", metadata)
+ logger.info(f"trtllm/{self.generator.llm_id} is registered")
if self.disagg_cluster_config:
self.disagg_cluster_storage = create_cluster_storage_client(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name)
@@ -188,11 +162,11 @@ class OpenAIServer:
yield
if self.metadata_server is not None:
- self.metadata_server.remove(f"trtllm/{self.llm.llm_id}")
- logger.info(f"trtllm/{self.llm.llm_id} is unregistered")
+ self.metadata_server.remove(f"trtllm/{self.generator.llm_id}")
+ logger.info(f"trtllm/{self.generator.llm_id} is unregistered")
if self.disagg_cluster_worker:
await self.disagg_cluster_worker.deregister_worker()
- self.llm.shutdown()
+ self.generator.shutdown()
self.app = FastAPI(lifespan=lifespan)
@@ -200,15 +174,81 @@ class OpenAIServer:
async def validation_exception_handler(_, exc):
return JSONResponse(status_code=400, content={"error": str(exc)})
- if self.server_role is not ServerRole.MM_ENCODER:
- self.register_routes()
- else:
- assert isinstance(self.llm, MultimodalEncoder), "llm must be a MultimodalEncoder for multimodal encoder"
+ if self.server_role is ServerRole.VISUAL_GEN:
+ assert isinstance(self.generator, VisualGen), "generator must be a VisualGen for VISUAL_GEN server"
+ self.register_visual_gen_routes()
+ elif self.server_role is ServerRole.MM_ENCODER:
+ assert isinstance(self.generator, MultimodalEncoder), "generator must be a MultimodalEncoder for multimodal encoder"
self.register_mm_encoder_routes()
+ else:
+ self.register_routes()
self.app.add_middleware(ServerArrivalTimeMiddleware)
+ def _init_visual_gen(self):
+ self.processor = None
+ self.model_config = None
+ self.media_storage_path = Path(os.getenv("TRTLLM_MEDIA_STORAGE_PATH", "/tmp/trtllm_generated")) # nosec B108
+        self.media_storage_path.mkdir(exist_ok=True, parents=True)
+ self.video_gen_tasks = {}
+
+
+ def _init_llm(self, chat_template: Optional[str] = None):
+ self.tokenizer = self.generator.tokenizer
+ hf_tokenizer_path = self.generator._hf_model_dir or self.tokenizer.tokenizer.name_or_path
+ trust_remote_code = self.generator.args.trust_remote_code
+ try:
+ self.processor = AutoProcessor.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
+ except Exception:
+ logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
+ self.processor = None
+
+ # load model config
+ try:
+ from tensorrt_llm._torch.pyexecutor.config_utils import \
+ load_pretrained_config
+ self.model_config = load_pretrained_config(hf_tokenizer_path,
+ trust_remote_code=trust_remote_code,
+ checkpoint_format=getattr(self.generator.args, "checkpoint_format", None))
+ except Exception:
+ logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
+ self.model_config = None
+
+ self.chat_template = load_chat_template(chat_template)
+
+ # Enable response storage for Responses API
+ self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled
+
+ self.conversation_store = ConversationHistoryStore()
+
+ # gpt-oss
+ self.harmony_adapter: HarmonyAdapter | None = None
+ disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
+ if disable_harmony or self.model_config is None:
+ self.use_harmony = False
+ else:
+ self.use_harmony = (self.model_config.model_type == "gpt_oss")
+
+ self.tool_call_id_type = "random" # default tool call id type is random
+ if self.model_config is not None:
+ if self.model_config.model_type == "kimi_k2":
+ self.tool_call_id_type = "kimi_k2"
+ elif self.model_config.model_type == "deepseek_v32":
+ self.tool_call_id_type = "deepseek_v32"
+
+ if self.generator.args.return_perf_metrics:
+ set_prometheus_multiproc_dir()
+ self.metrics_collector = MetricsCollector({
+ "model_name": "undefined",
+ "engine_type": "undefined"
+ })
+ max_perf_metrics = self.generator.args.perf_metrics_max_requests
+ if max_perf_metrics > 0:
+ self.perf_metrics = deque(maxlen=max_perf_metrics)
+ self.perf_metrics_lock = asyncio.Lock()
+
+
async def await_disconnected(self, raw_request: Request, promise):
if raw_request is None:
return
@@ -221,7 +261,7 @@ class OpenAIServer:
@property
def postproc_worker_enabled(self) -> bool:
- return True if self.llm.args.num_postprocess_workers > 0 else False
+ return True if self.generator.args.num_postprocess_workers > 0 else False
@staticmethod
def create_error_response(
@@ -248,8 +288,20 @@ class OpenAIServer:
status_code=HTTPStatus.NOT_FOUND,
)
+ def _create_not_supported_error(self, message: str) -> Response:
+ return self.create_error_response(
+ err_type="NotImplementedError",
+ message=message,
+ status_code=HTTPStatus.NOT_IMPLEMENTED,
+ )
+
def _check_health(self) -> bool:
- return self.llm._check_health()
+ if isinstance(self.generator, LLM):
+ return self.generator._check_health()
+ # llmapi.LLM (e.g. PyTorch backend) is not isinstance(_tensorrt_engine.LLM)
+ if hasattr(self.generator, '_check_health'):
+ return self.generator._check_health()
+ return True
def register_routes(self):
self.app.add_api_route("/health", self.health, methods=["GET"])
@@ -293,7 +345,7 @@ class OpenAIServer:
self.app.add_api_route("/server_info",
self.get_server_info,
methods=["GET"])
- if self.llm.args.return_perf_metrics:
+ if self.generator.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
@@ -340,6 +392,45 @@ class OpenAIServer:
self.update_weights,
methods=["POST"])
+ def register_visual_gen_routes(self):
+ """Register routes for diffusion model serving."""
+ # Health and info endpoints
+ self.app.add_api_route("/health", self.health, methods=["GET"])
+ self.app.add_api_route("/version", self.version, methods=["GET"])
+ self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
+ self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"])
+
+ # Image generation endpoints (OpenAI compatible)
+ self.app.add_api_route("/v1/images/generations",
+ self.openai_image_generation,
+ methods=["POST"])
+ self.app.add_api_route("/v1/images/edits",
+ self.openai_image_edit,
+ methods=["POST"])
+
+ # Video generation endpoints (Extended OpenAI API)
+ # Asynchronous video generation (returns immediately with job metadata, OpenAI API)
+ self.app.add_api_route("/v1/videos",
+ self.openai_video_generation_async,
+ methods=["POST"])
+ # Synchronous video generation (waits for completion, extended API)
+ self.app.add_api_route("/v1/videos/generations",
+ self.openai_video_generation_sync,
+ methods=["POST"])
+ # Video management endpoints
+ self.app.add_api_route("/v1/videos",
+ self.list_videos,
+ methods=["GET"])
+ self.app.add_api_route("/v1/videos/{video_id}",
+ self.get_video_metadata,
+ methods=["GET"])
+ self.app.add_api_route("/v1/videos/{video_id}/content",
+ self.get_video_content,
+ methods=["GET"])
+ self.app.add_api_route("/v1/videos/{video_id}",
+ self.delete_video,
+ methods=["DELETE"])
+
async def health(self) -> Response:
if self._check_health():
return Response(status_code=200)
@@ -349,10 +440,10 @@ class OpenAIServer:
async def health_generate(self, raw_request: Request) -> Response:
"""Health check that performs a minimal generation."""
extra_args = {}
- if self.llm.args.max_beam_width > 1:
+ if self.generator.args.max_beam_width > 1:
extra_args = dict(
use_beam_search=True,
- best_of=self.llm.args.max_beam_width,
+ best_of=self.generator.args.max_beam_width,
n=1,
)
try:
@@ -396,7 +487,7 @@ class OpenAIServer:
async def get_iteration_stats(self) -> JSONResponse:
stats = []
- async for stat in self.llm.get_stats_async(2):
+ async for stat in self.generator.get_stats_async(2):
stats.append(stat)
return JSONResponse(content=stats)
@@ -416,7 +507,7 @@ class OpenAIServer:
return JSONResponse(content=[])
async with self.perf_metrics_lock:
perf_metrics = self.perf_metrics
- self.perf_metrics = deque(maxlen=self.llm.args.perf_metrics_max_requests)
+ self.perf_metrics = deque(maxlen=self.generator.args.perf_metrics_max_requests)
for metrics_dict in perf_metrics:
metrics = metrics_dict["perf_metrics"]
timing_metrics = metrics.timing_metrics
@@ -466,7 +557,7 @@ class OpenAIServer:
async def get_kv_cache_events(self) -> JSONResponse:
events = []
try:
- async for event in self.llm.get_kv_cache_events_async(2):
+ async for event in self.generator.get_kv_cache_events_async(2):
events.append(event)
except IndexError:
# queue is empty, no more events
@@ -478,7 +569,7 @@ class OpenAIServer:
return
if self.metrics_collector:
self.metrics_collector.log_metrics_dict(res.metrics_dict)
- if self.llm.args.return_perf_metrics:
+ if self.generator.args.return_perf_metrics:
output = res.outputs[0]
item = {
"request_id": res.request_id,
@@ -549,9 +640,9 @@ class OpenAIServer:
# expanded into an embedding bias tensor in the sampler.
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
- gather_generation_logits=self.llm.args.gather_generation_logits,
- reasoning_parser=self.llm.args.reasoning_parser,
- backend=self.llm.args.backend)
+ gather_generation_logits=self.generator.args.gather_generation_logits,
+ reasoning_parser=self.generator.args.reasoning_parser,
+ backend=self.generator.args.backend)
postproc_args = ChatPostprocArgs.from_request(request)
disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
@@ -582,7 +673,7 @@ class OpenAIServer:
if mm_data and mm_embeddings:
raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
- postproc_args.reasoning_parser = self.llm.args.reasoning_parser
+ postproc_args.reasoning_parser = self.generator.args.reasoning_parser
postproc_args.tool_parser = self.tool_parser
postproc_args.tool_call_id_type = self.tool_call_id_type
if conversation and conversation[-1].get(
@@ -596,7 +687,7 @@ class OpenAIServer:
trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))
- promise = self.llm.generate_async(
+ promise = self.generator.generate_async(
inputs=prompt,
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
@@ -701,7 +792,7 @@ class OpenAIServer:
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
- promise = self.llm.generate_async(
+ promise = self.generator.generate_async(
inputs=prompt,
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
@@ -819,8 +910,8 @@ class OpenAIServer:
# expanded into an embedding bias tensor in the sampler.
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
- gather_generation_logits=self.llm.args.gather_generation_logits,
- backend=self.llm.args.backend)
+ gather_generation_logits=self.generator.args.gather_generation_logits,
+ backend=self.generator.args.backend)
# TODO: better way to enable metrics
if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
sampling_params.return_perf_metrics = True
@@ -839,12 +930,12 @@ class OpenAIServer:
prompt = prompt_inputs(prompt)
if prompt.get("prompt") is not None:
- prompt_token_ids, extra_processed_inputs = await asyncio.to_thread(self.llm.input_processor, prompt, sampling_params)
+ prompt_token_ids, extra_processed_inputs = await asyncio.to_thread(self.generator.input_processor, prompt, sampling_params)
tokens_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, query_token_ids=extra_processed_inputs.get("query_token_ids") if extra_processed_inputs is not None else None)
else:
tokens_prompt = prompt
- promise = self.llm.generate_async(
+ promise = self.generator.generate_async(
inputs=tokens_prompt,
sampling_params=sampling_params,
_postproc_params=postproc_params,
@@ -947,7 +1038,7 @@ class OpenAIServer:
)
# Generate
- promise = self.llm.generate_async(
+ promise = self.generator.generate_async(
inputs=harmony_tokens,
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
@@ -1040,7 +1131,7 @@ class OpenAIServer:
tokenizer=self.tokenizer if not self.use_harmony else None,
model_config=self.model_config if not self.use_harmony else None,
processor=self.processor if not self.use_harmony else None,
- reasoning_parser=self.llm.args.reasoning_parser if not self.use_harmony else "gpt_oss",
+ reasoning_parser=self.generator.args.reasoning_parser if not self.use_harmony else "gpt_oss",
)
streaming_processor = None
@@ -1053,7 +1144,7 @@ class OpenAIServer:
conversation_store=self.conversation_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
- reasoning_parser=self.llm.args.reasoning_parser,
+ reasoning_parser=self.generator.args.reasoning_parser,
tool_parser=self.tool_parser,
)
@@ -1062,7 +1153,7 @@ class OpenAIServer:
request=request,
sampling_params=sampling_params,
use_harmony=self.use_harmony,
- reasoning_parser=self.llm.args.reasoning_parser,
+ reasoning_parser=self.generator.args.reasoning_parser,
tool_parser=self.tool_parser,
streaming_processor=streaming_processor,
)
@@ -1071,7 +1162,7 @@ class OpenAIServer:
if request.stream else responses_api_post_processor,
postproc_args=postproc_args,
)
- promise = self.llm.generate_async(
+ promise = self.generator.generate_async(
inputs=input_tokens,
sampling_params=sampling_params,
streaming=request.stream,
@@ -1134,22 +1225,500 @@ class OpenAIServer:
})
async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
- assert isinstance(self.llm, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()"
- await self.llm.collective_rpc('sleep', args=(request.tags,))
+ assert isinstance(self.generator, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()"
+ await self.generator.collective_rpc('sleep', args=(request.tags,))
return JSONResponse(content={"status": "success"})
async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
- assert isinstance(self.llm, AsyncLLM), "/resume_memory endpoint is only supported with AsyncLLM()"
- await self.llm.collective_rpc('wakeup', args=(request.tags,))
+ assert isinstance(self.generator, AsyncLLM), "/resume_memory endpoint is only supported with AsyncLLM()"
+ await self.generator.collective_rpc('wakeup', args=(request.tags,))
return JSONResponse(content={"status": "success"})
async def update_weights(self, request: UpdateWeightsRequest) -> JSONResponse:
- assert isinstance(self.llm, AsyncLLM), "/update_weights endpoint is only supported with AsyncLLM()"
- await self.llm.collective_rpc('update_weights', args=(request.weights,))
+ assert isinstance(self.generator, AsyncLLM), "/update_weights endpoint is only supported with AsyncLLM()"
+ await self.generator.collective_rpc('update_weights', args=(request.weights,))
return JSONResponse(content={"status": "success"})
async def get_server_info(self) -> JSONResponse:
- return JSONResponse(content={"disaggregated_params": self.llm.disaggregated_params})
+ return JSONResponse(content={"disaggregated_params": self.generator.disaggregated_params})
+
+ async def openai_image_generation(
+ self,
+ request: ImageGenerationRequest,
+ raw_request: Request
+ ) -> Response:
+ """OpenAI-compatible image generation endpoint.
+
+ Follows the OpenAI Images API specification for image generation.
+ """
+ try:
+ image_id = f"image_{uuid.uuid4().hex}"
+ params = parse_visual_gen_params(request, image_id)
+ logger.info(f"Generating image: {image_id} with params: {params} and prompt: {request.prompt}")
+
+ if request.negative_prompt is not None:
+ inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt})
+ else:
+ inputs = visual_gen_inputs(request.prompt)
+ output = self.generator.generate(inputs=inputs, params=params)
+ if output.image is None:
+ return self.create_error_response(
+ message="Image generation failed",
+ err_type="InternalServerError",
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+
+ # Build response
+ output_images = output.image
+ MediaStorage.save_image(
+ output_images,
+ self.media_storage_path / f"{image_id}.png",
+ )
+
+ if not isinstance(output_images, list):
+ output_images = [output_images]
+
+ if request.response_format == "b64_json":
+ data = [
+ ImageObject(
+ b64_json=base64.b64encode(MediaStorage.convert_image_to_bytes(image)).decode('utf-8'),
+ revised_prompt=request.prompt
+ ) for image in output_images
+ ]
+
+ response = ImageGenerationResponse(
+ created=int(time.time()),
+ data=data,
+ size=f"{params.width}x{params.height}",
+ )
+
+ elif request.response_format == "url":
+ # TODO: Support URL mode
+ return self._create_not_supported_error("URL mode is not supported for image generation")
+
+ return JSONResponse(content=response.model_dump())
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+
+ async def openai_image_edit(
+ self,
+ request: ImageEditRequest,
+ raw_request: Request
+ ) -> Response:
+ """OpenAI-compatible image editing endpoint.
+
+ Follows the OpenAI Images API specification for image editing.
+ Creates an edited or extended image given an original image and a prompt.
+ """
+ try:
+ image_id = f"image_{uuid.uuid4().hex}"
+ params = parse_visual_gen_params(request, image_id)
+ logger.info(f"Editing image: {image_id} with params: {params} and prompt: {request.prompt}")
+
+ if request.negative_prompt is not None:
+ inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt})
+ else:
+ inputs = visual_gen_inputs(request.prompt)
+ output = self.generator.generate(inputs=inputs, params=params)
+ if output.image is None:
+ return self.create_error_response(
+ message="Image editing failed",
+ err_type="InternalServerError",
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+
+ # Build response
+ output_images = output.image
+ MediaStorage.save_image(
+ output_images,
+ self.media_storage_path / f"{image_id}.png",
+ )
+
+ if not isinstance(output_images, list):
+ output_images = [output_images]
+
+ response = ImageGenerationResponse(
+ created=int(time.time()),
+ data=[
+ ImageObject(
+ b64_json=base64.b64encode(MediaStorage.convert_image_to_bytes(image)).decode('utf-8'),
+ revised_prompt=request.prompt
+ ) for image in output_images
+ ],
+ size=f"{params.width}x{params.height}",
+ )
+
+ return JSONResponse(content=response.model_dump())
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(message=str(e), err_type="InternalServerError", status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+ async def openai_video_generation_sync(
+ self,
+ raw_request: Request
+ ) -> Response:
+ """Synchronous video generation endpoint.
+
+ Waits for video generation to complete before returning.
+ Compatible with simple use cases where waiting is acceptable.
+
+ Supports both JSON and multipart/form-data requests:
+ - JSON: Send VideoGenerationRequest as application/json
+ - Multipart: Send form fields + optional input_reference file
+ """
+ try:
+ # Parse request based on content-type
+ request = await self._parse_video_generation_request(raw_request)
+
+ video_id = f"video_{uuid.uuid4().hex}"
+ params = parse_visual_gen_params(request, video_id, media_storage_path=str(self.media_storage_path))
+ logger.info(f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}")
+
+ if request.negative_prompt is not None:
+ inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt})
+ else:
+ inputs = visual_gen_inputs(request.prompt)
+ output = self.generator.generate(inputs=inputs, params=params)
+ if output.video is None:
+ return self.create_error_response(
+ message="Video generation failed",
+ err_type="InternalServerError",
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+
+ MediaStorage.save_video(
+ video=output.video,
+ output_path=self.media_storage_path / f"{video_id}.mp4",
+ audio=output.audio,
+ frame_rate=request.fps or params.frame_rate,
+ )
+
+ return FileResponse(
+ self.media_storage_path / f"{video_id}.mp4",
+ media_type="video/mp4",
+ filename=f"{video_id}.mp4",
+ )
+
+ except ValueError as e:
+ logger.error(f"Request parsing error: {e}")
+ return self.create_error_response(str(e))
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+ async def _parse_video_generation_request(
+ self,
+ raw_request: Request,
+ ) -> VideoGenerationRequest:
+ """Parse video generation request from either JSON or multipart/form-data.
+
+ Supports both:
+ - application/json: Standard JSON request with VideoGenerationRequest model
+ - multipart/form-data: Form fields + file upload for input_reference
+ """
+ content_type = raw_request.headers.get("content-type", "")
+
+ if "application/json" in content_type:
+ # Parse as JSON using Pydantic model
+ body = await raw_request.json()
+ return VideoGenerationRequest(**body)
+
+ if "multipart/form-data" in content_type:
+ # Parse multipart/form-data manually
+ form = await raw_request.form()
+
+ # Extract all fields and convert to proper types
+ data = {}
+
+ # Required field
+ if "prompt" in form:
+ data["prompt"] = form["prompt"]
+ else:
+ raise ValueError("'prompt' is required")
+
+ # Optional string fields
+ for field in ["model", "size", "negative_prompt"]:
+ if field in form and form[field]:
+ data[field] = form[field]
+
+ # Optional numeric fields
+ if "seconds" in form and form["seconds"]:
+ data["seconds"] = float(form["seconds"])
+ if "fps" in form and form["fps"]:
+ data["fps"] = int(form["fps"])
+ if "n" in form and form["n"]:
+ data["n"] = int(form["n"])
+ if "num_inference_steps" in form and form["num_inference_steps"]:
+ data["num_inference_steps"] = int(form["num_inference_steps"])
+ if "guidance_scale" in form and form["guidance_scale"]:
+ data["guidance_scale"] = float(form["guidance_scale"])
+ if "guidance_rescale" in form and form["guidance_rescale"]:
+ data["guidance_rescale"] = float(form["guidance_rescale"])
+ if "seed" in form and form["seed"]:
+ data["seed"] = int(form["seed"])
+
+ # Handle file upload for input_reference
+ if "input_reference" in form:
+ input_ref = form["input_reference"]
+ if hasattr(input_ref, "file"): # It's an UploadFile
+ data["input_reference"] = input_ref
+
+ return VideoGenerationRequest(**data)
+
+ else:
+ raise ValueError(f"Unsupported content-type: {content_type}. Use 'application/json' or 'multipart/form-data'")
+
+ async def openai_video_generation_async(
+ self,
+ raw_request: Request,
+ ) -> Response:
+ """Asynchronous video generation endpoint (OpenAI Videos API compatible).
+
+ Creates a video generation job and returns immediately with job metadata.
+ The video is generated in the background and stored in media storage.
+ Client can poll GET /v1/videos/{video_id} to check status and retrieve the video.
+
+ Supports both JSON and multipart/form-data requests:
+ - JSON: Send VideoGenerationRequest as application/json
+ - Multipart: Send form fields + optional input_reference file
+ """
+ try:
+ # Parse request based on content-type
+ request = await self._parse_video_generation_request(raw_request)
+
+ video_id = f"video_{uuid.uuid4().hex}"
+ params = parse_visual_gen_params(request, video_id, media_storage_path=str(self.media_storage_path))
+ logger.info(f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}")
+
+ # Start background generation task
+ self.video_gen_tasks[video_id] = asyncio.create_task(
+ self._generate_video_background(
+ video_id=video_id,
+ request=request,
+ params=params,
+ )
+ )
+
+ # Return job metadata immediately
+ video_job = VideoJob(
+ created_at=int(time.time()),
+ id=video_id,
+ model=request.model or self.model,
+ prompt=request.prompt,
+ status="queued",
+ duration=request.seconds,
+ fps=request.fps,
+ size=f"{params.width}x{params.height}",
+ )
+ await VIDEO_STORE.upsert(video_id, video_job)
+
+ return JSONResponse(content=video_job.model_dump(), status_code=202)
+
+ except ValueError as e:
+ logger.error(f"Request parsing error: {e}")
+ return self.create_error_response(str(e))
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+ async def _generate_video_background(
+ self,
+ video_id: str,
+ request: VideoGenerationRequest,
+ params: VisualGenParams,
+ ):
+ """Background task to generate video and save to storage."""
+ try:
+ if request.negative_prompt is not None:
+ inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt})
+ else:
+ inputs = visual_gen_inputs(request.prompt)
+ future = self.generator.generate_async(inputs=inputs, params=params)
+ output = await future.result()
+
+            if output.video is None:
+                # The return value of a background task never reaches the client,
+                # so raise and let the except block below mark the job as failed.
+                raise RuntimeError("Video generation produced no output")
+
+ MediaStorage.save_video(
+ video=output.video,
+ output_path=self.media_storage_path / f"{video_id}.mp4",
+ audio=output.audio,
+ frame_rate=request.fps or params.frame_rate,
+ )
+ job = await VIDEO_STORE.get(video_id)
+ if job:
+ job.status = "completed"
+ job.completed_at = int(time.time())
+ await VIDEO_STORE.upsert(video_id, job)
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ job = await VIDEO_STORE.get(video_id)
+ if job:
+ job.status = "failed"
+ job.completed_at = int(time.time())
+ job.error = str(e)
+ await VIDEO_STORE.upsert(video_id, job)
+
+ async def list_videos(
+ self,
+ raw_request: Request
+ ) -> Response:
+ """List all generated videos.
+
+ GET /v1/videos
+ Returns a list of generated video metadata (job details).
+ """
+ try:
+ # List videos from storage
+ video_jobs = await VIDEO_STORE.list_values()
+
+ # Convert to API format
+ response = VideoJobList(
+ data=video_jobs,
+ )
+ return JSONResponse(content=response.model_dump())
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+ async def get_video_metadata(
+ self,
+ video_id: str,
+ raw_request: Request
+ ) -> Response:
+ """Get video metadata by ID.
+
+ GET /v1/videos/{video_id}
+ Retrieves the metadata (job status and details) for a specific generated video.
+ """
+ try:
+ logger.info(f"Getting video metadata: {video_id}")
+ # Get metadata from storage
+ job = await VIDEO_STORE.get(video_id)
+ if not job:
+ return self.create_error_response(
+ f"Video {video_id} not found",
+ err_type="NotFoundError",
+ status_code=HTTPStatus.NOT_FOUND
+ )
+
+ # Ensure it's a video
+ if job.object != "video":
+ return self.create_error_response(
+ f"Resource {video_id} is not a video",
+ err_type="BadRequestError",
+ status_code=HTTPStatus.BAD_REQUEST
+ )
+
+ return JSONResponse(content=job.model_dump())
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+ async def get_video_content(
+ self,
+ video_id: str,
+ raw_request: Request
+ ) -> Response:
+ """Download video file by ID.
+
+ GET /v1/videos/{video_id}/content
+ Downloads the generated video file.
+ """
+ try:
+ # Get metadata first to check status
+ job = await VIDEO_STORE.get(video_id)
+ if not job:
+ return self.create_error_response(
+ f"Video {video_id} not found",
+ err_type="NotFoundError",
+ status_code=HTTPStatus.NOT_FOUND
+ )
+
+ # Ensure it's a video and completed
+ if job.object != "video":
+ return self.create_error_response(
+ f"Resource {video_id} is not a video",
+ err_type="BadRequestError",
+ status_code=HTTPStatus.BAD_REQUEST
+ )
+
+ if job.status != "completed":
+ return self.create_error_response(
+ f"Video {video_id} is not ready (status: {job.status})",
+ err_type="BadRequestError",
+ status_code=HTTPStatus.BAD_REQUEST
+ )
+
+ video_file_name = f"{video_id}.mp4"
+ if os.path.exists(self.media_storage_path / video_file_name):
+ return FileResponse(
+ self.media_storage_path / video_file_name,
+ media_type="video/mp4",
+ filename=video_file_name,
+ )
+ else:
+ return self.create_error_response(
+ f"Video {video_id} not found",
+ err_type="NotFoundError",
+ status_code=HTTPStatus.NOT_FOUND
+ )
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
+
+ async def delete_video(
+ self,
+ video_id: str,
+ raw_request: Request
+ ) -> Response:
+ """Delete a video by ID.
+
+ DELETE /v1/videos/{video_id}
+ Deletes a generated video by its ID.
+ """
+ try:
+ # Check if video exists
+ job = await VIDEO_STORE.get(video_id)
+ if not job:
+ return self.create_error_response(
+ f"Video {video_id} not found",
+ err_type="NotFoundError",
+ status_code=HTTPStatus.NOT_FOUND
+ )
+
+ # Ensure it's a video
+ if job.object != "video":
+ return self.create_error_response(
+ f"Resource {video_id} is not a video",
+ err_type="BadRequestError",
+ status_code=HTTPStatus.BAD_REQUEST
+ )
+
+ # Delete the video
+ success = await VIDEO_STORE.pop(video_id)
+ video_file_name = f"{video_id}.mp4"
+
+ if os.path.exists(self.media_storage_path / video_file_name):
+ os.remove(self.media_storage_path / video_file_name)
+
+ return JSONResponse(content={"deleted": success is not None})
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return self.create_error_response(str(e))
async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
# Store the binding address for server registration
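
To tie the video endpoints together, here is a hedged client-side sketch of the asynchronous flow: POST `/v1/videos` to queue a job, poll `GET /v1/videos/{video_id}` until the `VideoJob` status leaves `queued`/`in_progress`, then download from `/v1/videos/{video_id}/content`. The base URL and request values are placeholders.

```python
# Illustrative polling loop for the asynchronous video API registered above.
import time

import requests

BASE = "http://localhost:8000"  # assumed server address

job = requests.post(f"{BASE}/v1/videos", json={
    "prompt": "A cute cat playing piano",
    "size": "832x480",   # 'WxH'
    "seconds": 2.0,
    "fps": 16,
}).json()

video_id = job["id"]
while job["status"] in ("queued", "in_progress"):
    time.sleep(5)
    job = requests.get(f"{BASE}/v1/videos/{video_id}").json()

if job["status"] == "completed":
    content = requests.get(f"{BASE}/v1/videos/{video_id}/content")
    with open(f"{video_id}.mp4", "wb") as f:
        f.write(content.content)
else:
    print(f"Generation failed: {job.get('error')}")
```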
diff --git a/tensorrt_llm/serve/visual_gen_utils.py b/tensorrt_llm/serve/visual_gen_utils.py
new file mode 100644
index 0000000000..f0cd31f7fb
--- /dev/null
+++ b/tensorrt_llm/serve/visual_gen_utils.py
@@ -0,0 +1,112 @@
+import asyncio
+import base64
+import os
+import shutil
+from typing import Any, Dict, List, Optional
+
+from tensorrt_llm.llmapi.visual_gen import VisualGenParams
+from tensorrt_llm.serve.openai_protocol import (
+ ImageEditRequest,
+ ImageGenerationRequest,
+ VideoGenerationRequest,
+)
+
+
+def parse_visual_gen_params(
+ request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest,
+ id: str,
+ media_storage_path: Optional[str] = None,
+) -> VisualGenParams:
+ params = VisualGenParams()
+ params.prompt = request.prompt
+ if request.negative_prompt is not None:
+ params.negative_prompt = request.negative_prompt
+ if request.size is not None and request.size != "auto":
+ params.width, params.height = map(int, request.size.split("x"))
+ if request.guidance_scale is not None:
+ params.guidance_scale = request.guidance_scale
+ if request.guidance_rescale is not None:
+ params.guidance_rescale = request.guidance_rescale
+
+    if isinstance(request, (ImageGenerationRequest, ImageEditRequest)):
+ if request.num_inference_steps is not None:
+ params.num_inference_steps = request.num_inference_steps
+ elif isinstance(request, ImageGenerationRequest) and request.quality == "hd":
+ params.num_inference_steps = 30
+ if request.n is not None:
+ params.num_images_per_prompt = request.n
+ if isinstance(request, ImageEditRequest):
+ if request.image is not None:
+ if isinstance(request.image, list):
+ params.image = [base64.b64decode(image) for image in request.image]
+ else:
+ params.image = [base64.b64decode(request.image)]
+ if request.mask is not None:
+ if isinstance(request.mask, list):
+ params.mask = [base64.b64decode(mask) for mask in request.mask]
+ else:
+ params.mask = base64.b64decode(request.mask)
+
+ elif isinstance(request, VideoGenerationRequest):
+ if request.num_inference_steps is not None:
+ params.num_inference_steps = request.num_inference_steps
+ if request.input_reference is not None:
+ if media_storage_path is None:
+ raise ValueError("media_storage_path is required when input_reference is provided")
+ params.input_reference = os.path.join(media_storage_path, f"{id}_reference.png")
+ if isinstance(request.input_reference, str):
+ with open(params.input_reference, "wb") as f:
+ f.write(base64.b64decode(request.input_reference))
+ else:
+ with open(params.input_reference, "wb") as f:
+ shutil.copyfileobj(request.input_reference.file, f)
+
+ params.frame_rate = request.fps
+ params.num_frames = int(request.seconds * request.fps)
+
+ if request.seed is not None:
+ params.seed = int(request.seed)
+
+ return params
+
+
+class AsyncDictStore:
+ """A small async-safe in-memory key-value store for dict items.
+
+ This encapsulates the usual pattern of a module-level dict guarded by
+ an asyncio.Lock and provides simple CRUD methods that are safe to call
+ concurrently from FastAPI request handlers and background tasks.
+ """
+
+ def __init__(self) -> None:
+ self._items: Dict[str, Dict[str, Any]] = {}
+ self._lock = asyncio.Lock()
+
+ async def upsert(self, key: str, value: Dict[str, Any]) -> None:
+ async with self._lock:
+ self._items[key] = value
+
+ async def update_fields(self, key: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ async with self._lock:
+ item = self._items.get(key)
+ if item is None:
+ return None
+ item.update(updates)
+ return item
+
+ async def get(self, key: str) -> Optional[Dict[str, Any]]:
+ async with self._lock:
+ return self._items.get(key)
+
+ async def pop(self, key: str) -> Optional[Dict[str, Any]]:
+ async with self._lock:
+ return self._items.pop(key, None)
+
+ async def list_values(self) -> List[Dict[str, Any]]:
+ async with self._lock:
+ return list(self._items.values())
+
+
+# Global store shared by the OpenAI entrypoints, mapping video_id -> VideoJob
+VIDEO_STORE = AsyncDictStore()
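
A small usage sketch for `AsyncDictStore`, mirroring how the server's background video tasks upsert and update `VIDEO_STORE` entries; the job IDs and field values here are made up.

```python
# Illustrative only: concurrent writers sharing one AsyncDictStore.
import asyncio

from tensorrt_llm.serve.visual_gen_utils import AsyncDictStore


async def worker(store: AsyncDictStore, job_id: str) -> None:
    await store.upsert(job_id, {"status": "queued"})
    await asyncio.sleep(0.01)  # stand-in for real work
    await store.update_fields(job_id, {"status": "completed"})


async def main() -> None:
    store = AsyncDictStore()
    await asyncio.gather(*(worker(store, f"video_{i}") for i in range(4)))
    print(await store.list_values())  # four entries, all "completed"


asyncio.run(main())
```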
diff --git a/tests/integration/defs/examples/test_visual_gen.py b/tests/integration/defs/examples/test_visual_gen.py
new file mode 100644
index 0000000000..2b48bab322
--- /dev/null
+++ b/tests/integration/defs/examples/test_visual_gen.py
@@ -0,0 +1,288 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Integration tests: VBench dimension scores for WAN and LTX-2 (TRT-LLM vs diffusers reference)."""
+
+import glob
+import json
+import os
+
+import pytest
+from defs.common import venv_check_call
+from defs.conftest import llm_models_root
+from defs.trt_test_alternative import check_call
+
+WAN_T2V_MODEL_SUBPATH = "Wan2.1-T2V-1.3B-Diffusers"
+VISUAL_GEN_OUTPUT_VIDEO = "trtllm_output.mp4"
+DIFFUSERS_REFERENCE_VIDEO = "diffusers_reference.mp4"
+WAN_T2V_PROMPT = "A cute cat playing piano"
+WAN_T2V_HEIGHT = 480
+WAN_T2V_WIDTH = 832
+WAN_T2V_NUM_FRAMES = 165
+
+# Dimensions to evaluate
+VBENCH_DIMENSIONS = [
+ "subject_consistency",
+ "background_consistency",
+ "motion_smoothness",
+ "dynamic_degree",
+ "aesthetic_quality",
+ "imaging_quality",
+]
+
+# Golden VBench scores from HF reference video (WAN); TRT-LLM is compared against these.
+VBENCH_WAN_GOLDEN_SCORES = {
+ "subject_consistency": 0.9381,
+ "background_consistency": 0.9535,
+ "motion_smoothness": 0.9923,
+ "dynamic_degree": 1.0000,
+ "aesthetic_quality": 0.5033,
+ "imaging_quality": 0.3033,
+}
+
+VBENCH_REPO = "https://github.com/Vchitect/VBench.git"
+VBENCH_BRANCH = "master"
+# Pin to a fixed commit for reproducible runs
+VBENCH_COMMIT = "98b19513678e99c80d8377fda25ba53b81a491a6"
+
+
+@pytest.fixture(scope="session")
+def vbench_repo_root(llm_venv):
+ """Clone VBench repo into workspace and install; return repo root path."""
+ workspace = llm_venv.get_working_directory()
+ repo_path = os.path.join(workspace, "VBench_repo")
+ if os.path.exists(repo_path):
+ return repo_path
+ # Clone without --depth=1 so we can checkout a specific commit
+ check_call(
+ ["git", "clone", "--single-branch", "--branch", VBENCH_BRANCH, VBENCH_REPO, repo_path],
+ shell=False,
+ )
+ check_call(["git", "-C", repo_path, "checkout", VBENCH_COMMIT], shell=False)
+ # # Install VBench dependencies explicitly
+ # llm_venv.run_cmd([
+ # "-m", "pip", "install",
+ # "tqdm>=4.60.0",
+ # "openai-clip>=1.0",
+ # "pyiqa>=0.1.0", # install this might also install transformers=4.37.2, which is incompatible
+ # "easydict",
+ # "decord>=0.6.0",
+ # ])
+ return repo_path
+
+
+@pytest.fixture(scope="session")
+def wan_trtllm_video_path(llm_venv, llm_root):
+ """Generate input video via visual_gen_wan_t2v.py and return path to trtllm_output.mp4."""
+ scratch_space = llm_models_root()
+ model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
+ if not os.path.isdir(model_path):
+ pytest.skip(
+ f"Wan T2V model not found: {model_path} "
+ f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
+ )
+ out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
+ os.makedirs(out_dir, exist_ok=True)
+ output_path = os.path.join(out_dir, VISUAL_GEN_OUTPUT_VIDEO)
+ if os.path.isfile(output_path):
+ return output_path
+ # Install av and diffusers from main branch
+ llm_venv.run_cmd(["-m", "pip", "install", "av"])
+ llm_venv.run_cmd(
+ [
+ "-m",
+ "pip",
+ "install",
+ "git+https://github.com/huggingface/diffusers.git",
+ ]
+ )
+ script_path = os.path.join(llm_root, "examples", "visual_gen", "visual_gen_wan_t2v.py")
+ assert os.path.isfile(script_path), f"Visual gen script not found: {script_path}"
+ venv_check_call(
+ llm_venv,
+ [
+ script_path,
+ "--height",
+ str(WAN_T2V_HEIGHT),
+ "--width",
+ str(WAN_T2V_WIDTH),
+ "--num_frames",
+ str(WAN_T2V_NUM_FRAMES),
+ "--model_path",
+ model_path,
+ "--prompt",
+ WAN_T2V_PROMPT,
+ "--output_path",
+ output_path,
+ ],
+ )
+ assert os.path.isfile(output_path), f"Visual gen did not produce {output_path}"
+ return output_path
+
+
+@pytest.fixture(scope="session")
+def wan_reference_video_path(llm_venv, llm_root):
+ """Generate reference video via diffusers (hf_wan.py) using the same model checkpoint."""
+ scratch_space = llm_models_root()
+ model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
+ if not os.path.isdir(model_path):
+ pytest.skip(
+ f"Wan T2V model not found: {model_path} "
+ f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
+ )
+ out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
+ os.makedirs(out_dir, exist_ok=True)
+ reference_path = os.path.join(out_dir, DIFFUSERS_REFERENCE_VIDEO)
+ if os.path.isfile(reference_path):
+ return reference_path
+ hf_script = os.path.join(llm_root, "examples", "visual_gen", "hf_wan.py")
+ assert os.path.isfile(hf_script), f"Diffusers script not found: {hf_script}"
+ venv_check_call(
+ llm_venv,
+ [
+ hf_script,
+ "--model_path",
+ model_path,
+ "--prompt",
+ WAN_T2V_PROMPT,
+ "--output_path",
+ reference_path,
+ "--height",
+ str(WAN_T2V_HEIGHT),
+ "--width",
+ str(WAN_T2V_WIDTH),
+ "--num_frames",
+ str(WAN_T2V_NUM_FRAMES),
+ ],
+ )
+ assert os.path.isfile(reference_path), f"Diffusers did not produce {reference_path}"
+ return reference_path
+
+
+def _visual_gen_out_dir(llm_venv, subdir=""):
+ """Output directory for generated media; subdir e.g. 'ltx2' for model-specific outputs."""
+ base = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
+ return os.path.join(base, subdir) if subdir else base
+
+
+def _normalize_score(val):
+ """Normalize to 0-1 scale (e.g. imaging_quality can be 0-100)."""
+ if isinstance(val, bool):
+ return float(val)
+ if isinstance(val, (int, float)) and val > 1.5:
+ return val / 100.0
+ return float(val)
+
+
+def _get_per_video_scores(results, video_path_substr):
+ """From VBench results, get per-dimension score for the video whose path contains video_path_substr."""
+ scores = {}
+ for dim in VBENCH_DIMENSIONS:
+ dim_result = results[dim]
+ assert isinstance(dim_result, list) and len(dim_result) >= 2, (
+ f"Dimension '{dim}' result must be [overall_score, video_results]; got {type(dim_result)}"
+ )
+ video_results = dim_result[1]
+ for entry in video_results:
+ if video_path_substr in entry.get("video_path", ""):
+ raw = entry.get("video_results")
+ scores[dim] = _normalize_score(raw)
+ break
+ else:
+ raise AssertionError(
+ f"No video matching '{video_path_substr}' in dimension '{dim}'; "
+ f"paths: {[e.get('video_path') for e in video_results]}"
+ )
+ return scores
+
+
+def _run_vbench_and_compare_to_golden(
+ vbench_repo_root,
+ videos_dir,
+ trtllm_filename,
+ golden_scores,
+ llm_venv,
+ title,
+ max_score_diff=0.1,
+):
+ """Run VBench on videos_dir (TRT-LLM output only), compare to golden HF reference scores."""
+ output_path = os.path.join(
+ llm_venv.get_working_directory(), "vbench_eval_output", title.replace(" ", "_").lower()
+ )
+ os.makedirs(output_path, exist_ok=True)
+ evaluate_script = os.path.join(vbench_repo_root, "evaluate.py")
+ cmd = [
+ evaluate_script,
+ "--videos_path",
+ videos_dir,
+ "--output_path",
+ output_path,
+ "--mode",
+ "custom_input",
+ ]
+ cmd.extend(["--dimension"] + VBENCH_DIMENSIONS)
+ venv_check_call(llm_venv, cmd)
+ pattern = os.path.join(output_path, "*_eval_results.json")
+ result_files = glob.glob(pattern)
+ assert result_files, (
+ f"No eval results found matching {pattern}; output dir: {os.listdir(output_path)}"
+ )
+ with open(result_files[0], "r") as f:
+ results = json.load(f)
+ for dim in VBENCH_DIMENSIONS:
+ assert dim in results, (
+ f"Expected dimension '{dim}' in results; keys: {list(results.keys())}"
+ )
+ scores_trtllm = _get_per_video_scores(results, trtllm_filename)
+ scores_ref = golden_scores
+ max_len = max(len(d) for d in VBENCH_DIMENSIONS)
+ header = f"{'Dimension':<{max_len}} | {'TRT-LLM':>10} | {'HF Ref':>10} | {'Diff':>8}"
+ sep = "-" * len(header)
+ print("\n" + "=" * len(header))
+ print(f"VBench dimension scores ({title}): TRT-LLM vs golden HF reference scores")
+ print("=" * len(header))
+ print(header)
+ print(sep)
+ max_diff_val = 0.0
+ for dim in VBENCH_DIMENSIONS:
+ t, r = scores_trtllm[dim], scores_ref[dim]
+ diff = abs(t - r)
+ max_diff_val = max(max_diff_val, diff)
+ print(f"{dim:<{max_len}} | {t:>10.4f} | {r:>10.4f} | {diff:>8.4f}")
+ print(sep)
+ print(
+ f"{' (all dimensions)':<{max_len}} | (TRT-LLM) | (golden) | max_diff={max_diff_val:.4f}"
+ )
+ print("=" * len(header) + "\n")
+ for dim in VBENCH_DIMENSIONS:
+ diff = abs(scores_trtllm[dim] - scores_ref[dim])
+ assert diff < max_score_diff or scores_trtllm[dim] >= scores_ref[dim], (
+ f"Dimension '{dim}' score difference {diff:.4f} >= {max_score_diff} "
+ f"(TRT-LLM={scores_trtllm[dim]:.4f}, golden={scores_ref[dim]:.4f})"
+ )
+
+
+def test_vbench_dimension_score_wan(vbench_repo_root, wan_trtllm_video_path, llm_venv):
+ """Run VBench on WAN TRT-LLM video; compare to golden HF reference scores (diff < 0.05 or TRT-LLM >= golden)."""
+ videos_dir = os.path.dirname(wan_trtllm_video_path)
+ assert os.path.isfile(wan_trtllm_video_path), "TRT-LLM video must exist"
+ _run_vbench_and_compare_to_golden(
+ vbench_repo_root,
+ videos_dir,
+ VISUAL_GEN_OUTPUT_VIDEO,
+ VBENCH_WAN_GOLDEN_SCORES,
+ llm_venv,
+ title="WAN",
+ max_score_diff=0.05,
+ )
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index fc45b1eb8b..c8e88a6e1f 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -91,6 +91,17 @@ l0_b200:
- unittest/tools/test_layer_wise_benchmarks.py::test_performance_alignment[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/kv_cache_manager_v2_tests/
+ # ------------- Visual Gen tests ---------------
+ - unittest/_torch/visual_gen/test_fused_qkv.py
+ - unittest/_torch/visual_gen/test_quant_ops.py
+ - unittest/_torch/visual_gen/test_attention_integration.py
+ - unittest/_torch/visual_gen/test_attention_perf.py
+ - unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
+ - unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
+ - unittest/_torch/visual_gen/test_wan.py -k "not TestWanTwoStageTransformer"
+ - unittest/_torch/visual_gen/test_wan_i2v.py
+ - unittest/_torch/visual_gen/test_model_loader.py
+ # - examples/test_visual_gen.py
- condition:
ranges:
system_gpu_count:
@@ -161,6 +172,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype
+ - unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer
# ------------- AutoDeploy Backend Stages ---------------
- condition:
ranges:
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index 3c8eb1b6bc..1592d1247f 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -30,6 +30,12 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
+ # ------------- VisualGen multi-GPU tests ---------------
+ - unittest/_torch/visual_gen/multi_gpu
+ - unittest/_torch/visual_gen/test_wan.py::TestWanParallelism::test_cfg_2gpu_correctness
+ - unittest/_torch/visual_gen/test_wan.py::TestWanCombinedOptimizations::test_all_optimizations_combined
+ - unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VParallelism::test_cfg_2gpu_correctness
+ - unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VCombinedOptimizations::test_all_optimizations_combined
- condition:
ranges:
system_gpu_count:
diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/__init__.py b/tests/unittest/_torch/visual_gen/multi_gpu/__init__.py
new file mode 100644
index 0000000000..fac2aaa011
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/multi_gpu/__init__.py
@@ -0,0 +1 @@
+"""Multi-GPU tests for visual generation modules."""
diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py
new file mode 100644
index 0000000000..0d691cf9ae
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py
@@ -0,0 +1,505 @@
+"""Multi-GPU tests for Ulysses Attention.
+
+These tests use torch.multiprocessing.spawn to launch multiple processes internally.
+Run with:
+    pytest tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py -v
+"""
+
+import os
+
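+# Must be set before the tensorrt_llm imports below; these tests drive their own
+# torch.distributed process groups via mp.spawn rather than MPI.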
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import math
+from typing import Callable
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+
+# Try to import the modules - skip tests if not available
+try:
+ from tensorrt_llm._torch.attention_backend.interface import PredefinedAttentionMask
+ from tensorrt_llm._torch.distributed import all_to_all_4d
+ from tensorrt_llm._torch.visual_gen.attention_backend import UlyssesAttention, VanillaAttention
+ from tensorrt_llm._utils import get_free_port
+
+ MODULES_AVAILABLE = True
+except ImportError:
+ MODULES_AVAILABLE = False
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+ """Clean up TLLM_DISABLE_MPI env var after tests complete."""
+ yield
+ os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+def init_distributed_worker(rank: int, world_size: int, backend: str = "gloo", port: int = 29500):
+ """Initialize distributed environment for a worker process."""
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+
+ # Use gloo backend for CPU, nccl for GPU
+ if backend == "nccl" and torch.cuda.is_available():
+ torch.cuda.set_device(rank % torch.cuda.device_count())
+ dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+ else:
+ dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+
+
+def cleanup_distributed():
+ """Clean up distributed environment."""
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+
+def _distributed_worker(rank, world_size, backend, test_fn, port):
+ """Worker function that runs in each process. Module-level for pickling."""
+ try:
+ init_distributed_worker(rank, world_size, backend, port)
+ test_fn(rank, world_size)
+ except Exception as e:
+ print(f"Rank {rank} failed with error: {e}")
+ raise
+ finally:
+ cleanup_distributed()
+
+
+def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = True):
+ """Run a test function in a distributed environment with multiple processes.
+
+ Args:
+ world_size: Number of processes to spawn
+ test_fn: Test function to run (must be module-level for pickling).
+ Should accept (rank, world_size) as arguments.
+ use_cuda: Whether to use CUDA (requires sufficient GPUs)
+ """
+ if not MODULES_AVAILABLE:
+ pytest.skip("Required modules not available")
+
+ if use_cuda and torch.cuda.device_count() < world_size:
+ pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available")
+
+ backend = "nccl" if use_cuda else "gloo"
+
+ port = get_free_port()
+
+ # Spawn processes
+ mp.spawn(
+ _distributed_worker, args=(world_size, backend, test_fn, port), nprocs=world_size, join=True
+ )
+
+
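+# Example usage (sketch): pair a module-level logic function with a pytest wrapper, mirroring
+# the test classes further below.
+#
+#     def _logic_example(rank, world_size):
+#         assert 0 <= rank < world_size
+#
+#     def test_example():
+#         run_test_in_distributed(world_size=2, test_fn=_logic_example, use_cuda=False)
+
+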
+# =============================================================================
+# Test logic functions (module-level so they can be pickled by mp.spawn)
+# =============================================================================
+
+
+def _logic_a2a_seq_to_head(rank, world_size):
+ """all_to_all_4d: sequence sharding to head sharding."""
+ batch = 2
+ seq_per_rank = 4
+ heads = 8
+ head_dim = 64
+
+ if heads % world_size != 0:
+ heads = world_size * 2
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ input_tensor = (
+ torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
+ + rank * 100
+ )
+
+ output = all_to_all_4d(
+ input_tensor,
+ scatter_dim=2,
+ gather_dim=1,
+ process_group=None,
+ )
+
+ expected_shape = (batch, seq_per_rank * world_size, heads // world_size, head_dim)
+ assert output.shape == expected_shape, (
+ f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
+ )
+ assert output.device == device
+
+
+def _logic_a2a_head_to_seq(rank, world_size):
+ """all_to_all_4d: head sharding to sequence sharding."""
+ batch = 2
+ seq = 16
+ heads_per_rank = 2
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ input_tensor = torch.randn(
+ batch, seq, heads_per_rank, head_dim, device=device, dtype=torch.float32
+ )
+
+ output = all_to_all_4d(
+ input_tensor,
+ scatter_dim=1,
+ gather_dim=2,
+ process_group=None,
+ )
+
+ expected_shape = (batch, seq // world_size, heads_per_rank * world_size, head_dim)
+ assert output.shape == expected_shape, (
+ f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
+ )
+
+
+def _logic_a2a_roundtrip(rank, world_size):
+ """all_to_all_4d: forward and backward are inverses."""
+ batch = 2
+ seq_per_rank = 4
+ heads = world_size * 4
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ original = torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
+
+ intermediate = all_to_all_4d(original, scatter_dim=2, gather_dim=1, process_group=None)
+ reconstructed = all_to_all_4d(intermediate, scatter_dim=1, gather_dim=2, process_group=None)
+
+ assert reconstructed.shape == original.shape
+ torch.testing.assert_close(reconstructed, original, rtol=1e-5, atol=1e-5)
+
+
+def _logic_a2a_single_process(rank, world_size):
+ """all_to_all_4d: single process returns input unchanged."""
+ batch, seq, heads, head_dim = 2, 8, 4, 64
+ device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
+
+ input_tensor = torch.randn(batch, seq, heads, head_dim, device=device)
+
+ output = all_to_all_4d(input_tensor, scatter_dim=2, gather_dim=1, process_group=None)
+
+ torch.testing.assert_close(output, input_tensor)
+
+
+def _logic_ulysses_init(rank, world_size):
+ """UlyssesAttention initialization."""
+ num_heads = world_size * 4
+ head_dim = 64
+
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ )
+
+ assert attention.num_heads == num_heads
+ assert attention.head_dim == head_dim
+ assert attention.world_size == world_size
+ assert rank >= 0 and rank < world_size
+
+
+def _logic_ulysses_forward(rank, world_size):
+ """UlyssesAttention forward pass."""
+ batch = 2
+ seq_per_rank = 8
+ num_heads = world_size * 4
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+
+ output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
+
+ assert output.shape == q.shape, f"Rank {rank}: Expected shape {q.shape}, got {output.shape}"
+ assert output.device == device
+
+
+def _logic_ulysses_with_mask(rank, world_size):
+ """UlyssesAttention with attention mask."""
+ batch = 2
+ seq_per_rank = 8
+ seq_full = seq_per_rank * world_size
+ num_heads = world_size * 4
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+
+ mask = PredefinedAttentionMask.CAUSAL
+
+ output = attention(q, k, v, batch_size=batch, seq_len=seq_full, attention_mask=mask)
+
+ assert output.shape == q.shape
+
+
+def _logic_ulysses_vs_standard_multi_gpu(rank, world_size):
+ """UlyssesAttention across multiple GPUs matches standard attention on the full sequence."""
+ batch = 2
+ seq_per_rank = 8
+ seq_full = seq_per_rank * world_size
+ num_heads = world_size * 4
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
+
+ # Every rank generates identical full tensors using the same seed.
+ torch.manual_seed(42)
+ q_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
+ k_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
+ v_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
+
+ # Each rank takes its sequence shard.
+ q_shard = q_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
+ k_shard = k_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
+ v_shard = v_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
+
+ # Ulysses attention on shards.
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ ulysses_output = attention(q_shard, k_shard, v_shard, batch_size=batch, seq_len=seq_full)
+
+ # Standard attention on the full tensors.
+ q_std = q_full.transpose(1, 2) # [B, H, S, D]
+ k_std = k_full.transpose(1, 2)
+ v_std = v_full.transpose(1, 2)
+
+ std_output = F.scaled_dot_product_attention(
+ q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
+ )
+ std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
+
+ # Compare the shard slice.
+ expected_shard = std_output[:, rank * seq_per_rank : (rank + 1) * seq_per_rank]
+ torch.testing.assert_close(
+ ulysses_output,
+ expected_shard,
+ rtol=1e-4,
+ atol=1e-4,
+ msg=f"Rank {rank}: Ulysses multi-GPU output differs from standard attention",
+ )
+
+
+def _logic_ulysses_invalid_heads(rank, world_size):
+ """Invalid head count (not divisible by world_size) cannot be sharded."""
+ assert rank >= 0 and rank < world_size
+
+ num_heads = world_size * 4 + 1 # Not divisible
+ head_dim = 64
+
+ # With the decorator pattern, the caller is responsible for sharding heads.
+ # num_heads // world_size truncates, so the wrapper's computed full head
+ # count won't match the original.
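+    # e.g. with world_size=2: 9 heads -> 4 per rank -> the wrapper reports 8 total, not 9.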
+ sharded_heads = num_heads // world_size
+ inner = VanillaAttention(num_heads=sharded_heads, head_dim=head_dim)
+ attention = UlyssesAttention(inner_backend=inner, process_group=None)
+ assert attention.num_heads != num_heads # Truncation means mismatch
+
+
+def _logic_different_batch_sizes(rank, world_size):
+ """Various batch sizes."""
+ num_heads = world_size * 4
+ head_dim = 64
+ seq_per_rank = 8
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ for batch_size in [1, 2, 4, 8]:
+ q = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
+ k = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
+ v = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
+
+ output = attention(q, k, v, batch_size=batch_size, seq_len=seq_per_rank * world_size)
+ assert output.shape == q.shape
+
+
+def _logic_different_head_dims(rank, world_size):
+ """Various head dims."""
+ batch = 2
+ seq_per_rank = 8
+ num_heads = world_size * 4
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ for head_dim in [32, 64, 128]:
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+
+ output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
+ assert output.shape == q.shape
+
+
+def _logic_world_size_4(rank, world_size):
+ """4-GPU test."""
+ batch = 2
+ seq_per_rank = 16
+ num_heads = world_size * 8 # 32 heads total
+ head_dim = 64
+
+ device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
+
+ inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
+ attention = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+ v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
+
+ output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
+ assert output.shape == q.shape
+
+
+# =============================================================================
+# Test classes
+# =============================================================================
+
+
+class TestAllToAll4D:
+ """Tests for all_to_all_4d function."""
+
+ def test_all_to_all_4d_sequence_to_head(self):
+ """Test sequence sharding to head sharding transformation."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_a2a_seq_to_head, use_cuda=True)
+
+ def test_all_to_all_4d_head_to_sequence(self):
+ """Test head sharding to sequence sharding transformation."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_a2a_head_to_seq, use_cuda=True)
+
+ def test_all_to_all_4d_roundtrip(self):
+ """Test that forward and backward all-to-all are inverses."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_a2a_roundtrip, use_cuda=True)
+
+ def test_all_to_all_4d_single_process(self):
+ """Test that single process returns input unchanged."""
+ run_test_in_distributed(world_size=1, test_fn=_logic_a2a_single_process, use_cuda=True)
+
+
+class TestUlyssesAttention:
+ """Tests for UlyssesAttention module."""
+
+ def test_ulysses_attention_initialization(self):
+ """Test UlyssesAttention initialization."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_init, use_cuda=True)
+
+ def test_ulysses_attention_forward(self):
+ """Test UlyssesAttention forward pass."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_forward, use_cuda=True)
+
+ def test_ulysses_attention_with_mask(self):
+ """Test UlyssesAttention with attention mask."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_with_mask, use_cuda=True)
+
+ def test_ulysses_vs_standard_attention_single_gpu(self):
+ """Compare UlyssesAttention with standard attention on single GPU."""
+ if not MODULES_AVAILABLE:
+ pytest.skip("Required modules not available")
+
+ if not torch.cuda.is_available():
+ pytest.skip("Test requires CUDA")
+
+ batch = 2
+ seq = 16
+ num_heads = 8
+ head_dim = 64
+ device = torch.device("cuda:0")
+
+ inner = VanillaAttention(num_heads=num_heads, head_dim=head_dim)
+ ulysses_attn = UlyssesAttention(
+ inner_backend=inner,
+ process_group=None,
+ ).to(device)
+
+ torch.manual_seed(42)
+ q = torch.randn(batch, seq, num_heads, head_dim, device=device)
+ k = torch.randn(batch, seq, num_heads, head_dim, device=device)
+ v = torch.randn(batch, seq, num_heads, head_dim, device=device)
+
+ ulysses_output = ulysses_attn(q, k, v, batch_size=batch, seq_len=seq)
+
+ q_std = q.transpose(1, 2) # [B, H, S, D]
+ k_std = k.transpose(1, 2)
+ v_std = v.transpose(1, 2)
+
+ std_output = F.scaled_dot_product_attention(
+ q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
+ )
+ std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
+
+ torch.testing.assert_close(
+ ulysses_output,
+ std_output,
+ rtol=1e-4,
+ atol=1e-4,
+ msg="Ulysses attention output differs from standard attention",
+ )
+
+ def test_ulysses_vs_standard_attention_multi_gpu(self):
+ """Compare UlyssesAttention across GPUs with standard attention on full sequence."""
+ run_test_in_distributed(
+ world_size=2, test_fn=_logic_ulysses_vs_standard_multi_gpu, use_cuda=True
+ )
+
+ def test_ulysses_attention_invalid_heads(self):
+ """Test that invalid head count raises error."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_invalid_heads, use_cuda=False)
+
+
+class TestUlyssesAttentionEdgeCases:
+ """Edge case tests for UlyssesAttention."""
+
+ def test_different_batch_sizes(self):
+ """Test with various batch sizes."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_different_batch_sizes, use_cuda=True)
+
+ def test_different_head_dims(self):
+ """Test with various head dims."""
+ run_test_in_distributed(world_size=2, test_fn=_logic_different_head_dims, use_cuda=True)
+
+ def test_world_size_4(self):
+ """Test with 4 GPUs."""
+ run_test_in_distributed(world_size=4, test_fn=_logic_world_size_4, use_cuda=True)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/unittest/_torch/visual_gen/test_attention_integration.py b/tests/unittest/_torch/visual_gen/test_attention_integration.py
new file mode 100644
index 0000000000..e346421d76
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_attention_integration.py
@@ -0,0 +1,540 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Test WAN Attention Integration.
+
+Compares the new integrated attention (using TRT-LLM backend) with the original
+naive implementation to ensure numerical equivalence.
+"""
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tensorrt_llm._torch.modules.rms_norm import RMSNorm
+from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
+
+# Import new integrated versions
+from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb
+
+# ============================================================================
+# Original naive implementations for comparison
+# ============================================================================
+
+
+class NaiveWanSelfAttention(nn.Module):
+ """Original naive self-attention implementation (for comparison)."""
+
+ def __init__(
+ self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.hidden_size = hidden_size
+
+ # fused QKV projection
+ self.to_qkv = nn.Linear(hidden_size, 3 * hidden_size, dtype=dtype)
+ self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
+ self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
+ self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
+
+ def forward(self, hidden_states, freqs_cos, freqs_sin):
+ B, S = hidden_states.shape[:2]
+
+ q, k, v = self.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ q = self.norm_q(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+ k = self.norm_k(k).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if freqs_cos is not None and freqs_sin is not None:
+ q = apply_rotary_emb(q, freqs_cos, freqs_sin)
+ k = apply_rotary_emb(k, freqs_cos, freqs_sin)
+
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+ out = out.transpose(1, 2).flatten(2)
+ out = self.to_out[0](out)
+ return out
+
+
+class NaiveWanCrossAttention(nn.Module):
+ """Original naive cross-attention implementation (for comparison)."""
+
+ def __init__(
+ self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.hidden_size = hidden_size
+
+ self.to_q = nn.Linear(hidden_size, hidden_size, dtype=dtype)
+ self.to_k = nn.Linear(hidden_size, hidden_size, dtype=dtype)
+ self.to_v = nn.Linear(hidden_size, hidden_size, dtype=dtype)
+ self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
+ self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
+ self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
+
+ def forward(self, hidden_states, encoder_hidden_states):
+ B, S = hidden_states.shape[:2]
+
+ q = self.norm_q(self.to_q(hidden_states))
+ k = self.norm_k(self.to_k(encoder_hidden_states))
+ v = self.to_v(encoder_hidden_states)
+
+ q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+ out = out.transpose(1, 2).flatten(2)
+ out = self.to_out[0](out)
+ return out
+
+
+# ============================================================================
+# Test utilities
+# ============================================================================
+
+
+def create_model_config(
+ hidden_size: int,
+ num_heads: int,
+ head_dim: int,
+ eps: float = 1e-6,
+ attn_backend: str = "VANILLA",
+):
+ """Create a mock DiffusionModelConfig for testing."""
+ pretrained_config = SimpleNamespace(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ attention_head_dim=head_dim,
+ eps=eps,
+ )
+
+ # Create a minimal config without quantization
+ config = DiffusionModelConfig(
+ pretrained_config=pretrained_config,
+ attention=AttentionConfig(backend=attn_backend),
+ skip_create_weights_in_init=False,
+ )
+ return config
+
+
+def copy_weights_self_attention(naive: NaiveWanSelfAttention, integrated: Attention):
+ """Copy weights from naive to integrated self-attention."""
+ # QKV projection: naive has to_qkv, integrated has qkv_proj
+ integrated.qkv_proj.weight.data.copy_(naive.to_qkv.weight.data)
+ if naive.to_qkv.bias is not None and integrated.qkv_proj.bias is not None:
+ integrated.qkv_proj.bias.data.copy_(naive.to_qkv.bias.data)
+
+ # QK norms
+ integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
+ integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
+
+ # Output projection
+ integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
+ if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
+ integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
+
+
+def copy_weights_cross_attention(naive: NaiveWanCrossAttention, integrated: Attention):
+ """Copy weights from naive to integrated cross-attention."""
+ # Q, K, V projections
+ integrated.to_q.weight.data.copy_(naive.to_q.weight.data)
+ integrated.to_k.weight.data.copy_(naive.to_k.weight.data)
+ integrated.to_v.weight.data.copy_(naive.to_v.weight.data)
+
+ if naive.to_q.bias is not None and integrated.to_q.bias is not None:
+ integrated.to_q.bias.data.copy_(naive.to_q.bias.data)
+ if naive.to_k.bias is not None and integrated.to_k.bias is not None:
+ integrated.to_k.bias.data.copy_(naive.to_k.bias.data)
+ if naive.to_v.bias is not None and integrated.to_v.bias is not None:
+ integrated.to_v.bias.data.copy_(naive.to_v.bias.data)
+
+ # QK norms
+ integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
+ integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
+
+ # Output projection
+ integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
+ if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
+ integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
+
+
+def generate_rope_embeddings(
+ seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
+):
+ """Generate RoPE embeddings with full head_dim.
+
+ apply_rotary_emb expects freqs with full head_dim, then slices with [..., 0::2] and [..., 1::2].
+
+ Args:
+ is_HSD: If True, returns [1, 1, S, D] for broadcasting with [B, H, S, D] (naive)
+ If False, returns [1, S, 1, D] for broadcasting with [B, S, H, D] (integrated)
+ """
+ position = torch.arange(seq_len, device=device).unsqueeze(1)
+ # Use full head_dim - apply_rotary_emb will slice with 0::2 and 1::2
+ div_term = torch.exp(
+ torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
+ )
+
+ if is_HSD:
+ freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
+ freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
+ else:
+ freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
+ freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
+
+ return freqs_cos, freqs_sin
+
+
+# ============================================================================
+# Test functions
+# ============================================================================
+@pytest.mark.parametrize("attn_backend", ["VANILLA", "TRTLLM"])
+def test_self_attention_equivalence(attn_backend: str):
+ """Test that integrated self-attention produces same output as naive."""
+ print("\n" + "=" * 60)
+ print("Testing Self-Attention Equivalence")
+ print("=" * 60)
+
+ # Config
+ batch_size = 2
+ seq_len = 16
+ hidden_size = 128
+ num_heads = 4
+ head_dim = hidden_size // num_heads
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
+
+ print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
+ print(f"Device: {device}, dtype: {dtype}")
+
+ # Create models
+ naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
+
+ model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
+ integrated = Attention(
+ hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
+ ).to(device) # self attention
+
+ # Copy weights
+ copy_weights_self_attention(naive, integrated)
+
+ # Set to eval mode
+ naive.eval()
+ integrated.eval()
+
+ # Create inputs
+ torch.manual_seed(42)
+ hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
+ # Naive uses [1, 1, S, D] (HSD format) - broadcasts with [B, H, S, D]
+ freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
+ # Integrated uses [1, S, 1, D] (SHD format) - broadcasts with [B, S, H, D]
+ freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
+
+ # Forward pass
+ with torch.no_grad():
+ out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
+ out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
+
+ # Compare (using looser tolerance for bf16)
+ max_diff = (out_naive - out_integrated).abs().max().item()
+ mean_diff = (out_naive - out_integrated).abs().mean().item()
+ is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
+
+ print("\nResults:")
+ print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
+ print(f" Max absolute difference: {max_diff:.2e}")
+ print(f" Mean absolute difference: {mean_diff:.2e}")
+ print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
+
+ if is_close:
+ print(" ā
PASS: Self-attention outputs match!")
+ else:
+ print(" ā FAIL: Self-attention outputs differ!")
+
+ assert is_close, (
+ f"Self-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
+ )
+ return is_close
+
+
+@pytest.mark.parametrize("attn_backend", ["VANILLA"])
+def test_cross_attention_equivalence(attn_backend: str):
+ """Test that integrated cross-attention produces same output as naive."""
+ print("\n" + "=" * 60)
+ print("Testing Cross-Attention Equivalence")
+ print("=" * 60)
+
+ # Config
+ batch_size = 2
+ seq_len = 16
+ encoder_seq_len = 24 # Different from query seq_len
+ hidden_size = 128
+ num_heads = 4
+ head_dim = hidden_size // num_heads
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
+
+ print(
+ f"Config: B={batch_size}, S_q={seq_len}, S_kv={encoder_seq_len}, H={hidden_size}, heads={num_heads}"
+ )
+ print(f"Device: {device}, dtype: {dtype}")
+
+ # Create models
+ naive = NaiveWanCrossAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
+
+ model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
+ integrated = Attention(
+ hidden_size, num_heads, qkv_mode=QKVMode.SEPARATE_QKV, config=model_config
+ ).to(device) # cross attention
+
+ # Copy weights
+ copy_weights_cross_attention(naive, integrated)
+
+ # Set to eval mode
+ naive.eval()
+ integrated.eval()
+
+ # Create inputs
+ torch.manual_seed(42)
+ hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
+ encoder_hidden_states = torch.randn(
+ batch_size, encoder_seq_len, hidden_size, device=device, dtype=dtype
+ )
+
+ # Forward pass
+ with torch.no_grad():
+ out_naive = naive(hidden_states, encoder_hidden_states)
+ out_integrated = integrated(hidden_states, encoder_hidden_states)
+
+ # Compare (using looser tolerance for bf16)
+ max_diff = (out_naive - out_integrated).abs().max().item()
+ mean_diff = (out_naive - out_integrated).abs().mean().item()
+ is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
+
+ print("\nResults:")
+ print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
+ print(f" Max absolute difference: {max_diff:.2e}")
+ print(f" Mean absolute difference: {mean_diff:.2e}")
+ print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
+
+ if is_close:
+ print(" ā
PASS: Cross-attention outputs match!")
+ else:
+ print(" ā FAIL: Cross-attention outputs differ!")
+
+ assert is_close, (
+ f"Cross-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
+ )
+ return is_close
+
+
+def test_trtllm_cached_prepare():
+ """Test that TRTLLM attention cached prepare works correctly.
+
+ This test verifies that when running multiple forward passes with same B/S
+ but different q/k/v values, the cached prepare phase doesn't cause incorrect
+ results (i.e., outputs should differ when inputs differ).
+ """
+ print("\n" + "=" * 60)
+ print("Testing TRTLLM Cached Prepare Phase")
+ print("=" * 60)
+
+ # Config - same B, S for all iterations
+ batch_size = 2
+ seq_len = 16
+ hidden_size = 128
+ num_heads = 4
+ head_dim = hidden_size // num_heads
+ num_iterations = 5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16
+
+ print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
+ print(f"Running {num_iterations} iterations with same B/S but different inputs")
+
+ # Create models - single instance to test caching
+ naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
+ model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
+ integrated = Attention(
+ hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
+ ).to(device) # self attention
+
+ # Copy weights
+ copy_weights_self_attention(naive, integrated)
+
+ naive.eval()
+ integrated.eval()
+
+ # Generate freqs (same for all iterations since S is same)
+ freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
+ freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
+
+ all_passed = True
+ outputs_integrated = []
+
+ with torch.no_grad():
+ for i in range(num_iterations):
+ # Different random inputs for each iteration
+ torch.manual_seed(42 + i) # Different seed each time
+ hidden_states = torch.randn(
+ batch_size, seq_len, hidden_size, device=device, dtype=dtype
+ )
+
+ out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
+ out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
+
+ # Check this iteration matches naive
+ max_diff = (out_naive - out_integrated).abs().max().item()
+ is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
+
+ status = "ā
" if is_close else "ā"
+ print(f" Iteration {i + 1}: max_diff={max_diff:.2e} {status}")
+
+ if not is_close:
+ all_passed = False
+
+ outputs_integrated.append(out_integrated.clone())
+
+ # Additional check: outputs should be DIFFERENT across iterations
+ # (since inputs were different)
+ print("\n Checking outputs differ across iterations (inputs were different):")
+ outputs_differ = True
+ for i in range(1, num_iterations):
+ diff = (outputs_integrated[i] - outputs_integrated[0]).abs().max().item()
+        if diff < 1e-6:
+            print(
+                f"    ⚠️ Iteration {i + 1} output same as iteration 1 (diff={diff:.2e}) - possible caching bug!"
+            )
+            outputs_differ = False
+        else:
+            print(f"  Iteration {i + 1} vs 1: diff={diff:.2e} ✅")
+
+ if all_passed and outputs_differ:
+ print("\n ā
PASS: Cached prepare works correctly!")
+ else:
+ print("\n ā FAIL: Cached prepare may have issues!")
+ all_passed = False
+
+ assert all_passed, "Cached prepare: outputs did not match naive reference"
+ assert outputs_differ, (
+ "Cached prepare: outputs should differ across iterations with different inputs"
+ )
+ return all_passed
+
+
+def test_trtllm_varying_seq_len():
+ """Test TRTLLM attention with varying sequence lengths.
+
+ This tests that the prepare phase correctly handles different seq_lens
+ and doesn't incorrectly reuse cached metadata.
+ """
+ print("\n" + "=" * 60)
+ print("Testing TRTLLM with Varying Sequence Lengths")
+ print("=" * 60)
+
+ batch_size = 2
+ hidden_size = 128
+ num_heads = 4
+ head_dim = hidden_size // num_heads
+ seq_lens = [8, 16, 32, 16, 8] # Vary seq_len, including repeats
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16
+
+ print(f"Config: B={batch_size}, H={hidden_size}, heads={num_heads}")
+ print(f"Testing seq_lens: {seq_lens}")
+
+ # Create models - single instance to test caching across different seq_lens
+ naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
+ model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
+ integrated = Attention(
+ hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
+ ).to(device) # self attention
+
+ copy_weights_self_attention(naive, integrated)
+
+ naive.eval()
+ integrated.eval()
+
+ all_passed = True
+
+ with torch.no_grad():
+ for i, seq_len in enumerate(seq_lens):
+ torch.manual_seed(42 + i)
+ hidden_states = torch.randn(
+ batch_size, seq_len, hidden_size, device=device, dtype=dtype
+ )
+
+ freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(
+ seq_len, head_dim, device, is_HSD=True
+ )
+ freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(
+ seq_len, head_dim, device, is_HSD=False
+ )
+
+ out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
+ out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
+
+ max_diff = (out_naive - out_integrated).abs().max().item()
+ is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
+
+ status = "ā
" if is_close else "ā"
+ print(f" seq_len={seq_len:3d}: max_diff={max_diff:.2e} {status}")
+
+ if not is_close:
+ all_passed = False
+
+ if all_passed:
+ print("\n ā
PASS: Varying seq_len handled correctly!")
+ else:
+ print("\n ā FAIL: Issues with varying seq_len!")
+
+ assert all_passed, "Varying seq_len: outputs did not match naive reference"
+ return all_passed
+
+
+def run_all_tests():
+ """Run all tests and report results."""
+ print("\n" + "=" * 60)
+ print("WAN Attention Integration Tests")
+ print("=" * 60)
+
+ results = {}
+
+ # Run self-attention tests with different backends
+ for backend in ["VANILLA", "TRTLLM"]:
+ results[f"self_attention_{backend}"] = test_self_attention_equivalence(backend)
+
+ # Run cross-attention test (VANILLA only)
+ results["cross_attention_VANILLA"] = test_cross_attention_equivalence("VANILLA")
+
+ # Run TRTLLM-specific caching tests
+ results["trtllm_cached_prepare"] = test_trtllm_cached_prepare()
+ results["trtllm_varying_seq_len"] = test_trtllm_varying_seq_len()
+
+ print("\n" + "=" * 60)
+ print("Summary")
+ print("=" * 60)
+
+ all_passed = all(results.values())
+ for name, passed in results.items():
+ status = "ā
PASS" if passed else "ā FAIL"
+ print(f" {name}: {status}")
+
+ print()
+ if all_passed:
+ print("All tests passed! ā
")
+ else:
+ print("Some tests failed! ā")
+
+ return all_passed
+
+
+if __name__ == "__main__":
+ run_all_tests()
diff --git a/tests/unittest/_torch/visual_gen/test_attention_perf.py b/tests/unittest/_torch/visual_gen/test_attention_perf.py
new file mode 100644
index 0000000000..d2662105dc
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_attention_perf.py
@@ -0,0 +1,622 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""WAN Attention Performance Benchmark.
+
+Compares VANILLA vs TRTLLM attention backends for visual generation models.
+Uses CUDA events for precise GPU timing and supports NVTX profiling.
+
+Usage:
+ # Run all tests
+ python test_attention_perf.py
+
+ # With Nsight Systems profiling
+ nsys profile -t cuda,nvtx --nvtx-capture=range -o wan_attn_perf python test_attention_perf.py
+
+ # Run specific tests with pytest
+ pytest test_attention_perf.py -v -k "test_self_attention_perf"
+"""
+
+import time
+from contextlib import contextmanager
+from types import SimpleNamespace
+from typing import Dict, Optional, Tuple
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
+from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
+
+# NVTX support for profiling
+try:
+ import nvtx
+
+ NVTX_AVAILABLE = True
+ if hasattr(nvtx, "annotate"):
+ NVTX_METHOD = "annotate"
+ elif hasattr(nvtx, "range_start") and hasattr(nvtx, "range_end"):
+ NVTX_METHOD = "range"
+ else:
+ NVTX_METHOD = None
+ NVTX_AVAILABLE = False
+except ImportError:
+ NVTX_AVAILABLE = False
+ NVTX_METHOD = None
+
+# Torch profiler support
+try:
+ from torch.profiler import record_function
+
+ TORCH_PROFILER_AVAILABLE = True
+except ImportError:
+ TORCH_PROFILER_AVAILABLE = False
+
+
+# ============================================================================
+# Timing utilities
+# ============================================================================
+
+
+@contextmanager
+def cuda_timer(device: torch.device):
+ """Context manager for precise GPU timing using CUDA events."""
+ if device.type == "cuda":
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+ start_event.record()
+
+ def get_elapsed_time():
+ end_event.record()
+ torch.cuda.synchronize()
+ return start_event.elapsed_time(end_event)
+
+ yield get_elapsed_time
+ else:
+ start_time = time.perf_counter()
+
+ def get_elapsed_time():
+ return (time.perf_counter() - start_time) * 1000
+
+ yield get_elapsed_time
+
+
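+# Example usage (as in benchmark_single below):
+#     with cuda_timer(self.device) as get_time:
+#         _ = model(hidden_states, freqs=freqs)
+#     times.append(get_time())
+
+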
+@contextmanager
+def nvtx_range(name: str):
+ """Context manager for NVTX range profiling."""
+ if NVTX_AVAILABLE and NVTX_METHOD:
+ if NVTX_METHOD == "annotate":
+ with nvtx.annotate(name):
+ yield
+ elif NVTX_METHOD == "range":
+ range_id = nvtx.range_start(name)
+ try:
+ yield
+ finally:
+ nvtx.range_end(range_id)
+ else:
+ yield
+ else:
+ yield
+
+
+@contextmanager
+def torch_profiler_range(name: str):
+ """Context manager for torch profiler range."""
+ if TORCH_PROFILER_AVAILABLE:
+ with record_function(name):
+ yield
+ else:
+ yield
+
+
+# ============================================================================
+# Test utilities
+# ============================================================================
+
+
+def create_model_config(
+ hidden_size: int,
+ num_heads: int,
+ head_dim: int,
+ eps: float = 1e-6,
+ attn_backend: str = "VANILLA",
+) -> DiffusionModelConfig:
+ """Create a mock DiffusionModelConfig for testing."""
+ pretrained_config = SimpleNamespace(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ attention_head_dim=head_dim,
+ eps=eps,
+ )
+
+ config = DiffusionModelConfig(
+ pretrained_config=pretrained_config,
+ attention=AttentionConfig(backend=attn_backend),
+ skip_create_weights_in_init=False,
+ )
+ return config
+
+
+def generate_rope_embeddings(
+ seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generate RoPE embeddings.
+
+ Args:
+ seq_len: Sequence length
+ head_dim: Head dimension
+ device: Target device
+ is_HSD: If True, returns [1, 1, S, D] for HSD format, else [1, S, 1, D] for SHD
+
+ Returns:
+ Tuple of (freqs_cos, freqs_sin)
+ """
+ position = torch.arange(seq_len, device=device).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
+ )
+
+ if is_HSD:
+ freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0)
+ freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0)
+ else:
+ freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2)
+ freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2)
+
+ return freqs_cos, freqs_sin
+
+
+# ============================================================================
+# Performance benchmark class
+# ============================================================================
+
+
+class WanAttentionPerformanceBenchmark:
+ """Performance benchmark for WAN attention backends."""
+
+ # WAN model configurations: (batch_size, num_heads, seq_len, head_dim, description)
+ TEST_SIZES = [
+ # Wan2.1-T2V-1.3B configurations
+ (1, 24, 14040, 64, "Wan-1.3B 480p 2s"),
+ (1, 24, 3510, 64, "Wan-1.3B 480p 2s ring4"),
+ (1, 24, 7020, 64, "Wan-1.3B 480p 2s ring2"),
+ # Wan2.1-T2V-14B configurations
+ (1, 40, 75600, 128, "Wan-14B 720p 5s"),
+ (1, 40, 37800, 128, "Wan-14B 720p 5s ring2"),
+ (1, 40, 18900, 128, "Wan-14B 720p 5s ring4"),
+ (1, 40, 9450, 128, "Wan-14B 720p 5s ring8"),
+ # Ulysses parallelism configurations
+ (1, 20, 75600, 128, "Wan-14B 720p ulysses2"),
+ (1, 10, 75600, 128, "Wan-14B 720p ulysses4"),
+ (1, 5, 75600, 128, "Wan-14B 720p ulysses8"),
+ # Smaller test cases for quick validation
+ (2, 24, 1024, 64, "Small batch2"),
+ (1, 24, 4096, 64, "Medium 4k"),
+ (1, 40, 8192, 128, "Large 8k"),
+ ]
+
+ # Quick test sizes for CI/pytest
+ QUICK_TEST_SIZES = [
+ (1, 24, 1024, 64, "Quick 1k"),
+ (1, 24, 2048, 64, "Quick 2k"),
+ (2, 24, 1024, 64, "Quick batch2"),
+ ]
+
+ def __init__(
+ self,
+ device: Optional[torch.device] = None,
+ dtype: torch.dtype = torch.bfloat16,
+ warmup_iterations: int = 10,
+ benchmark_iterations: int = 50,
+ ):
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.dtype = dtype
+ self.warmup_iterations = warmup_iterations
+ self.benchmark_iterations = benchmark_iterations
+ self.backends = ["VANILLA", "TRTLLM"]
+
+ def create_attention_model(
+ self, hidden_size: int, num_heads: int, head_dim: int, backend: str
+ ) -> Attention:
+ """Create a WAN self-attention model with specified backend."""
+ config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=backend)
+ model = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=config).to(
+ self.device
+ )
+ model.eval()
+ return model
+
+ def create_test_data(
+ self, batch_size: int, seq_len: int, hidden_size: int, head_dim: int
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Create test input data and RoPE embeddings."""
+ hidden_states = torch.randn(
+ batch_size, seq_len, hidden_size, device=self.device, dtype=self.dtype
+ )
+ freqs = generate_rope_embeddings(seq_len, head_dim, self.device, is_HSD=False)
+ return hidden_states, freqs
+
+ def estimate_memory_gb(
+ self, batch_size: int, num_heads: int, seq_len: int, head_dim: int
+ ) -> float:
+ """Estimate tensor memory usage in GB."""
+ hidden_size = num_heads * head_dim
+ # Input: [B, S, H] + Q, K, V: [B, S, num_heads, head_dim] each
+ bytes_per_element = 2 # bf16
+ input_bytes = batch_size * seq_len * hidden_size * bytes_per_element
+ qkv_bytes = 3 * batch_size * seq_len * num_heads * head_dim * bytes_per_element
+ output_bytes = batch_size * seq_len * hidden_size * bytes_per_element
+ # Attention matrix can be O(S^2) but flash attention avoids materializing it
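+        # e.g. (B=1, heads=24, S=14040, D=64) in bf16 comes to roughly 0.2 GB for these tensors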
+ return (input_bytes + qkv_bytes + output_bytes) / (1024**3)
+
+ def benchmark_single(
+ self,
+ batch_size: int,
+ num_heads: int,
+ seq_len: int,
+ head_dim: int,
+ backend: str,
+ verbose: bool = True,
+ ) -> Optional[Dict]:
+ """Benchmark a single configuration.
+
+ Returns:
+ Dict with timing statistics or None if test failed/skipped
+ """
+ hidden_size = num_heads * head_dim
+
+ # Memory check
+ est_memory = self.estimate_memory_gb(batch_size, num_heads, seq_len, head_dim)
+ if est_memory > 8.0:
+ if verbose:
+ print(f" Skipping - estimated memory {est_memory:.2f}GB > 8GB limit")
+ return None
+
+ try:
+ # Create model and data
+ model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
+ hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim)
+
+ # Warmup
+ with nvtx_range(f"warmup_{backend}"):
+ with torch_profiler_range(f"warmup_{backend}"):
+ with torch.no_grad():
+ for _ in range(self.warmup_iterations):
+ _ = model(hidden_states, freqs=freqs)
+
+ if self.device.type == "cuda":
+ torch.cuda.synchronize()
+
+ # Benchmark
+ times = []
+ with nvtx_range(f"benchmark_{backend}"):
+ with torch_profiler_range(f"benchmark_{backend}"):
+ with torch.no_grad():
+ for i in range(self.benchmark_iterations):
+ with nvtx_range(f"iter_{backend}_{i}"):
+ with cuda_timer(self.device) as get_time:
+ _ = model(hidden_states, freqs=freqs)
+ times.append(get_time())
+
+ # Statistics
+ times_tensor = torch.tensor(times)
+ stats = {
+ "avg_ms": times_tensor.mean().item(),
+ "min_ms": times_tensor.min().item(),
+ "max_ms": times_tensor.max().item(),
+ "std_ms": times_tensor.std().item(),
+ "median_ms": times_tensor.median().item(),
+ "p95_ms": torch.quantile(times_tensor, 0.95).item(),
+ "p99_ms": torch.quantile(times_tensor, 0.99).item(),
+ }
+
+ # Calculate throughput (approximate TOPS)
+ total_ops = batch_size * num_heads * seq_len * seq_len * head_dim
+ stats["throughput_tops"] = (total_ops / 1e12) / (stats["avg_ms"] / 1000)
+
+ if verbose:
+ print(
+ f" {backend}: avg={stats['avg_ms']:.3f}ms, "
+ f"median={stats['median_ms']:.3f}ms, "
+ f"throughput={stats['throughput_tops']:.2f} TOPS"
+ )
+
+ return stats
+
+ except Exception as e:
+ if verbose:
+ print(f" {backend}: ERROR - {e}")
+ return None
+
+ def benchmark_comparison(
+ self,
+ batch_size: int,
+ num_heads: int,
+ seq_len: int,
+ head_dim: int,
+ description: str = "",
+ verbose: bool = True,
+ ) -> Dict[str, Optional[Dict]]:
+ """Benchmark and compare all backends for a given configuration."""
+ if verbose:
+ print(
+ f"\nBenchmarking: ({batch_size}, {num_heads}, {seq_len}, {head_dim}) {description}"
+ )
+ print(f" Device: {self.device}, dtype: {self.dtype}")
+ print(f" Warmup: {self.warmup_iterations}, Iterations: {self.benchmark_iterations}")
+
+ results = {}
+ for backend in self.backends:
+ results[backend] = self.benchmark_single(
+ batch_size, num_heads, seq_len, head_dim, backend, verbose
+ )
+
+ # Print comparison
+ if verbose and results.get("VANILLA") and results.get("TRTLLM"):
+ vanilla_avg = results["VANILLA"]["avg_ms"]
+ trtllm_avg = results["TRTLLM"]["avg_ms"]
+ speedup = vanilla_avg / trtllm_avg
+ print(f" TRTLLM vs VANILLA: {speedup:.2f}x {'faster' if speedup > 1 else 'slower'}")
+
+ return results
+
+ def run_full_benchmark(self, use_quick_sizes: bool = False) -> Dict:
+ """Run benchmark on all configured sizes."""
+ test_sizes = self.QUICK_TEST_SIZES if use_quick_sizes else self.TEST_SIZES
+
+ print("\n" + "=" * 70)
+ print("WAN ATTENTION PERFORMANCE BENCHMARK")
+ print("=" * 70)
+ print(f"Device: {self.device}")
+ print(f"dtype: {self.dtype}")
+ print(f"Backends: {self.backends}")
+ print(f"NVTX: {'Enabled' if NVTX_AVAILABLE else 'Disabled'}")
+ print(f"Torch Profiler: {'Enabled' if TORCH_PROFILER_AVAILABLE else 'Disabled'}")
+
+ all_results = {}
+
+ with nvtx_range("wan_attention_benchmark"):
+ with torch_profiler_range("wan_attention_benchmark"):
+ for batch_size, num_heads, seq_len, head_dim, desc in test_sizes:
+ key = f"{desc}_{batch_size}x{num_heads}x{seq_len}x{head_dim}"
+ results = self.benchmark_comparison(
+ batch_size, num_heads, seq_len, head_dim, desc
+ )
+ all_results[key] = {
+ "config": {
+ "batch_size": batch_size,
+ "num_heads": num_heads,
+ "seq_len": seq_len,
+ "head_dim": head_dim,
+ "description": desc,
+ },
+ "results": results,
+ }
+
+ # Print summary
+ self._print_summary(all_results)
+ return all_results
+
+ def _print_summary(self, all_results: Dict) -> None:
+ """Print benchmark summary table."""
+ print("\n" + "=" * 70)
+ print("BENCHMARK SUMMARY")
+ print("=" * 70)
+ print(f"{'Configuration':<40} {'VANILLA (ms)':<15} {'TRTLLM (ms)':<15} {'Speedup':<10}")
+ print("-" * 70)
+
+ for key, data in all_results.items():
+ desc = data["config"]["description"]
+ results = data["results"]
+
+ vanilla = results.get("VANILLA")
+ trtllm = results.get("TRTLLM")
+
+ vanilla_str = f"{vanilla['avg_ms']:.2f}" if vanilla else "N/A"
+ trtllm_str = f"{trtllm['avg_ms']:.2f}" if trtllm else "N/A"
+
+ if vanilla and trtllm:
+ speedup = vanilla["avg_ms"] / trtllm["avg_ms"]
+ speedup_str = f"{speedup:.2f}x"
+ else:
+ speedup_str = "N/A"
+
+ print(f"{desc:<40} {vanilla_str:<15} {trtllm_str:<15} {speedup_str:<10}")
+
+ def test_memory_usage(
+ self,
+ batch_size: int = 1,
+ num_heads: int = 24,
+ seq_len: int = 4096,
+ head_dim: int = 64,
+ ) -> Dict[str, Dict]:
+ """Test memory usage of different backends."""
+ if self.device.type != "cuda":
+ print("Memory test requires CUDA device")
+ return {}
+
+ print("\n" + "=" * 70)
+ print("MEMORY USAGE TEST")
+ print("=" * 70)
+ print(f"Config: ({batch_size}, {num_heads}, {seq_len}, {head_dim})")
+
+ hidden_size = num_heads * head_dim
+ memory_results = {}
+
+ for backend in self.backends:
+ print(f"\nTesting {backend}...")
+
+ try:
+ # Clear cache
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ # Create model and data
+ model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
+ hidden_states, freqs = self.create_test_data(
+ batch_size, seq_len, hidden_size, head_dim
+ )
+
+ # Warmup
+ with torch.no_grad():
+ _ = model(hidden_states, freqs=freqs)
+
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ # Forward pass
+ with nvtx_range(f"memory_test_{backend}"):
+ with torch.no_grad():
+ _ = model(hidden_states, freqs=freqs)
+
+ torch.cuda.synchronize()
+
+ peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
+ current_memory_gb = torch.cuda.memory_allocated() / (1024**3)
+
+ memory_results[backend] = {
+ "peak_memory_gb": peak_memory_gb,
+ "current_memory_gb": current_memory_gb,
+ }
+
+ print(f" Peak memory: {peak_memory_gb:.3f} GB")
+ print(f" Current memory: {current_memory_gb:.3f} GB")
+
+ except Exception as e:
+ print(f" ERROR: {e}")
+ memory_results[backend] = None
+
+ return memory_results
+
+
+# ============================================================================
+# Pytest test functions
+# ============================================================================
+
+
+class TestWanAttentionPerformance:
+ """Pytest test class for WAN attention performance."""
+
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ """Setup test environment."""
+ self.benchmark = WanAttentionPerformanceBenchmark(
+ warmup_iterations=5,
+ benchmark_iterations=20,
+ )
+
+ @pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"])
+ def test_self_attention_perf(self, backend: str):
+ """Test that attention backend runs without errors."""
+ batch_size, num_heads, seq_len, head_dim = 1, 24, 1024, 64
+
+ result = self.benchmark.benchmark_single(
+ batch_size, num_heads, seq_len, head_dim, backend, verbose=True
+ )
+
+ if result is not None:
+ assert result["avg_ms"] > 0, "Average time should be positive"
+ assert result["min_ms"] <= result["avg_ms"], "Min should be <= avg"
+ assert result["max_ms"] >= result["avg_ms"], "Max should be >= avg"
+ print(f" {backend}: avg={result['avg_ms']:.3f}ms OK")
+
+ @pytest.mark.parametrize(
+ "batch_size,num_heads,seq_len,head_dim",
+ [
+ (1, 24, 1024, 64),
+ (1, 24, 2048, 64),
+ (2, 24, 1024, 64),
+ ],
+ )
+ def test_backend_comparison(self, batch_size: int, num_heads: int, seq_len: int, head_dim: int):
+ """Test VANILLA vs TRTLLM comparison."""
+ results = self.benchmark.benchmark_comparison(
+ batch_size, num_heads, seq_len, head_dim, verbose=True
+ )
+
+ # At least one backend should work
+ assert any(r is not None for r in results.values()), "All backends failed"
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+ def test_memory_usage(self):
+ """Test memory usage tracking."""
+ memory_results = self.benchmark.test_memory_usage(
+ batch_size=1, num_heads=24, seq_len=2048, head_dim=64
+ )
+
+ for backend, result in memory_results.items():
+ if result is not None:
+ assert result["peak_memory_gb"] > 0, f"{backend} peak memory should be positive"
+
+ def test_quick_benchmark(self):
+ """Run quick benchmark for CI validation."""
+ results = self.benchmark.run_full_benchmark(use_quick_sizes=True)
+ assert len(results) > 0, "Should have benchmark results"
+
+
+# ============================================================================
+# Main entry point
+# ============================================================================
+
+
+def main():
+ """Run full benchmark suite."""
+ print("\n" + "=" * 70)
+ print("WAN ATTENTION PERFORMANCE BENCHMARK SUITE")
+ print("=" * 70)
+
+ if not torch.cuda.is_available():
+ print("WARNING: CUDA not available, results will not be meaningful")
+
+ # Print profiling instructions
+ if torch.cuda.is_available():
+ print("\nPROFILING INSTRUCTIONS:")
+ print("-" * 50)
+ if NVTX_AVAILABLE:
+ print("NVTX Profiling (Nsight Systems):")
+ print(" nsys profile -t cuda,nvtx --nvtx-capture=range \\")
+ print(" -o wan_attn_perf python test_attention_perf.py")
+ else:
+ print("NVTX not available. Install with: pip install nvtx")
+
+ print("\nPyTorch Profiler:")
+ print(" The benchmark includes record_function() calls for profiling")
+ print("-" * 50)
+
+ # Create benchmark instance
+ benchmark = WanAttentionPerformanceBenchmark(
+ warmup_iterations=10,
+ benchmark_iterations=50,
+ )
+
+ # Run full benchmark
+ print("\n" + "=" * 70)
+ print("FULL BENCHMARK")
+ print("=" * 70)
+ all_results = benchmark.run_full_benchmark(use_quick_sizes=False)
+
+ # Memory test
+ if torch.cuda.is_available():
+ benchmark.test_memory_usage(batch_size=1, num_heads=24, seq_len=4096, head_dim=64)
+
+ print("\n" + "=" * 70)
+ print("BENCHMARK COMPLETE")
+ print("=" * 70)
+
+ return all_results
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/unittest/_torch/visual_gen/test_fused_qkv.py b/tests/unittest/_torch/visual_gen/test_fused_qkv.py
new file mode 100644
index 0000000000..c918b44d06
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_fused_qkv.py
@@ -0,0 +1,126 @@
+"""Tests for fused QKV support in diffusion models.
+
+Tests:
+1. Weight loading into fused QKV layers (self-attention uses a single qkv_proj by default)
+2. Cross-attention (attn2) keeping separate to_q/to_k/to_v projections
+"""
+
+import unittest
+from types import SimpleNamespace
+from typing import Dict
+
+import torch
+
+from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
+
+
+def _create_test_config(hidden_size: int = 64) -> DiffusionModelConfig:
+ """Create a test DiffusionModelConfig."""
+ num_heads = hidden_size // 8 # e.g., 64 // 8 = 8 heads
+ head_dim = 8
+ return DiffusionModelConfig(
+ pretrained_config=SimpleNamespace(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ attention_head_dim=head_dim,
+ num_layers=2,
+ ffn_dim=256,
+ out_channels=16,
+ patch_size=[1, 2, 2],
+ in_channels=16,
+ text_dim=64,
+ freq_dim=32,
+ ),
+ )
+
+
+class TestFusedQKVWeightLoading(unittest.TestCase):
+ """Test weight loading for fused QKV layers."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ torch.manual_seed(42)
+ self.hidden_size = 64
+
+ def _create_mock_checkpoint_weights(self) -> Dict[str, torch.Tensor]:
+ """Create mock checkpoint weights with separate to_q, to_k, to_v."""
+ dtype = torch.bfloat16 # Match model dtype
+ weights = {}
+ for block_idx in range(2):
+ for attn_name in ["attn1", "attn2"]:
+ prefix = f"blocks.{block_idx}.{attn_name}"
+ # Separate QKV weights (as in checkpoint)
+ weights[f"{prefix}.to_q.weight"] = torch.randn(
+ self.hidden_size, self.hidden_size, dtype=dtype
+ )
+ weights[f"{prefix}.to_q.bias"] = torch.randn(self.hidden_size, dtype=dtype)
+ weights[f"{prefix}.to_k.weight"] = torch.randn(
+ self.hidden_size, self.hidden_size, dtype=dtype
+ )
+ weights[f"{prefix}.to_k.bias"] = torch.randn(self.hidden_size, dtype=dtype)
+ weights[f"{prefix}.to_v.weight"] = torch.randn(
+ self.hidden_size, self.hidden_size, dtype=dtype
+ )
+ weights[f"{prefix}.to_v.bias"] = torch.randn(self.hidden_size, dtype=dtype)
+ # Output projection
+ weights[f"{prefix}.to_out.0.weight"] = torch.randn(
+ self.hidden_size, self.hidden_size, dtype=dtype
+ )
+ weights[f"{prefix}.to_out.0.bias"] = torch.randn(self.hidden_size, dtype=dtype)
+
+ # FFN weights
+ ffn_dim = 256
+ prefix = f"blocks.{block_idx}.ffn"
+ weights[f"{prefix}.net.0.proj.weight"] = torch.randn(
+ ffn_dim, self.hidden_size, dtype=dtype
+ )
+ weights[f"{prefix}.net.0.proj.bias"] = torch.randn(ffn_dim, dtype=dtype)
+ weights[f"{prefix}.net.2.weight"] = torch.randn(self.hidden_size, ffn_dim, dtype=dtype)
+ weights[f"{prefix}.net.2.bias"] = torch.randn(self.hidden_size, dtype=dtype)
+
+ # proj_out
+ weights["proj_out.weight"] = torch.randn(64, self.hidden_size, dtype=dtype)
+ weights["proj_out.bias"] = torch.randn(64, dtype=dtype)
+
+ return weights
+
+ def test_load_weights_fused(self):
+ """Test loading weights with fused QKV (default for self-attention)."""
+ from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel
+
+ config = _create_test_config(self.hidden_size)
+
+ # Create model - self-attention (attn1) uses fused QKV by default
+ model = WanTransformer3DModel(model_config=config)
+ weights = self._create_mock_checkpoint_weights()
+
+ # Load weights (model handles fused QKV internally via DynamicLinearWeightLoader)
+ model.load_weights(weights)
+
+ # Verify fused weights were loaded correctly for self-attention
+ attn1 = model.blocks[0].attn1
+ qkv_weight = attn1.qkv_proj.weight.data
+
+ # Expected: concatenation of to_q, to_k, to_v weights
+ expected_weight = torch.cat(
+ [
+ weights["blocks.0.attn1.to_q.weight"],
+ weights["blocks.0.attn1.to_k.weight"],
+ weights["blocks.0.attn1.to_v.weight"],
+ ],
+ dim=0,
+ )
+
+ self.assertEqual(qkv_weight.shape, expected_weight.shape)
+ self.assertTrue(torch.allclose(qkv_weight, expected_weight))
+
+ # Also verify cross-attention (attn2) uses separate Q/K/V
+ attn2 = model.blocks[0].attn2
+ self.assertTrue(hasattr(attn2, "to_q"), "Cross-attention should have separate to_q")
+ self.assertTrue(
+ torch.allclose(attn2.to_q.weight.data, weights["blocks.0.attn2.to_q.weight"])
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unittest/_torch/visual_gen/test_model_loader.py b/tests/unittest/_torch/visual_gen/test_model_loader.py
new file mode 100644
index 0000000000..9cc8cba70e
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_model_loader.py
@@ -0,0 +1,494 @@
+"""Test PipelineLoader with DiffusionArgs API."""
+
+import os
+from pathlib import Path
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.visual_gen.config import PipelineComponent
+
+
+def _llm_models_root() -> str:
+ """Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
+ root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+ if "LLM_MODELS_ROOT" in os.environ:
+ root = Path(os.environ["LLM_MODELS_ROOT"])
+ if not root.exists():
+ root = Path("/scratch.trt_llm_data/llm-models/")
+    assert root.exists(), (
+        "Set the LLM_MODELS_ROOT env var or make scratch.trt_llm_data accessible to run this test"
+    )
+ return str(root)
+
+
+# Skip if checkpoint not available
+# Set DIFFUSION_MODEL_PATH env var to run integration tests
+CHECKPOINT_PATH = os.environ.get(
+ "DIFFUSION_MODEL_PATH",
+ os.path.join(_llm_models_root(), "Wan2.1-T2V-1.3B-Diffusers"),
+)
+
+# Skip heavy components (text_encoder ~44GB, vae ~300MB) to speed up tests
+# These components are loaded via diffusers and don't need quantization testing
+SKIP_HEAVY_COMPONENTS = [
+ PipelineComponent.TEXT_ENCODER,
+ PipelineComponent.VAE,
+ PipelineComponent.TOKENIZER,
+ PipelineComponent.SCHEDULER,
+]
+
+
+@pytest.fixture
+def checkpoint_exists():
+ return CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH)
+
+
+def test_meta_init_mode_creates_meta_tensors(checkpoint_exists):
+ """Test that MetaInitMode creates tensors on meta device (no GPU memory)."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs
+ from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
+ from tensorrt_llm._torch.visual_gen.models import AutoPipeline
+
+ # Load config directly
+ args = DiffusionArgs(checkpoint_path=CHECKPOINT_PATH)
+ config = DiffusionModelConfig.from_pretrained(
+ CHECKPOINT_PATH,
+ args=args,
+ )
+
+ # Create pipeline WITH MetaInitMode
+ with MetaInitMode():
+ pipeline = AutoPipeline.from_config(config, CHECKPOINT_PATH)
+
+ # Verify tensors are on meta device (no GPU memory allocated)
+ param = next(pipeline.transformer.parameters())
+ assert param.device.type == "meta", f"Expected meta device, got {param.device}"
+
+
+def test_load_wan_pipeline_basic(checkpoint_exists):
+ """Test basic loading without quantization using DiffusionArgs."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
+
+ # Simple one-liner with DiffusionArgs
+ # Skip text_encoder/vae to speed up test (focus on transformer)
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify pipeline type
+ assert pipeline.__class__.__name__ == "WanPipeline"
+ assert pipeline.transformer is not None
+
+ # Verify text_encoder/vae were skipped
+ assert pipeline.text_encoder is None, "text_encoder should be skipped"
+ assert pipeline.vae is None, "vae should be skipped"
+
+ # Verify weights are loaded (not meta tensors)
+ param = next(pipeline.transformer.parameters())
+ assert param.device.type == "cuda"
+ assert param.dtype in [torch.float32, torch.bfloat16, torch.float16]
+
+
+def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists):
+ """Test loading with FP8 dynamic quantization using DiffusionArgs.
+
+ Verifies the dynamic quantization flow:
+ 1. Config has dynamic_weight_quant=True when linear.type="trtllm-fp8-per-tensor"
+ 2. Model Linear layers have FP8 weight buffers
+ 3. BF16 checkpoint weights are quantized on-the-fly
+ 4. Quantized weights are in FP8 format
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.modules.linear import Linear
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
+
+ # Use DiffusionArgs with FP8 quantization
+ # Skip text_encoder/vae to speed up test (focus on transformer quantization)
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify model config has dynamic_weight_quant enabled
+ assert pipeline.model_config.dynamic_weight_quant is True, (
+ "dynamic_weight_quant should be True when linear.type specifies FP8"
+ )
+
+ # Verify FP8 weights in transformer Linear layers
+ found_fp8_linear = False
+ for name, module in pipeline.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "weight") and module.weight is not None:
+ assert module.weight.dtype == torch.float8_e4m3fn, (
+ f"Linear {name} weight dtype is {module.weight.dtype}, expected float8_e4m3fn"
+ )
+ assert hasattr(module, "weight_scale") and module.weight_scale is not None, (
+ f"Linear {name} missing weight_scale buffer"
+ )
+ found_fp8_linear = True
+ break
+
+ assert found_fp8_linear, "No FP8 Linear modules found in transformer"
+
+
+def test_load_wan_pipeline_with_fp8_blockwise(checkpoint_exists):
+ """Test loading with FP8 blockwise quantization using DiffusionArgs."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.modules.linear import Linear
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
+
+ # Skip text_encoder/vae to speed up test (focus on transformer quantization)
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify FP8 weights
+ for name, module in pipeline.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "weight") and module.weight is not None:
+ assert module.weight.dtype == torch.float8_e4m3fn, (
+ f"Linear {name} should have FP8 weight"
+ )
+ break
+
+
+def test_diffusion_args_to_quant_config():
+ """Test that DiffusionArgs correctly parses quant_config dict to QuantConfig."""
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs
+ from tensorrt_llm.quantization.mode import QuantAlgo
+
+ # Default - no quantization
+ args = DiffusionArgs(checkpoint_path="/fake/path")
+ assert args.quant_config.quant_algo is None
+
+ # FP8 per-tensor (dict is coerced to QuantConfig by model_validator)
+ args = DiffusionArgs(
+ checkpoint_path="/fake/path",
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ )
+ qc = args.quant_config
+ assert qc is not None
+ assert qc.quant_algo == QuantAlgo.FP8
+ assert args.dynamic_weight_quant is True
+
+ # FP8 blockwise
+ args = DiffusionArgs(
+ checkpoint_path="/fake/path",
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ )
+ qc = args.quant_config
+ assert qc.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+
+ # NVFP4
+ args = DiffusionArgs(
+ checkpoint_path="/fake/path",
+ quant_config={"quant_algo": "NVFP4", "dynamic": True},
+ )
+ qc = args.quant_config
+ assert qc.quant_algo == QuantAlgo.NVFP4
+
+ # With ignore patterns (exclude_modules)
+ args = DiffusionArgs(
+ checkpoint_path="/fake/path",
+ quant_config={
+ "quant_algo": "FP8",
+ "ignore": ["blocks.0.attn1.*", "proj_out"],
+ "config_groups": {
+ "group_0": {
+ "weights": {"dynamic": True, "num_bits": 8, "type": "float"},
+ "targets": ["Linear"],
+ }
+ },
+ },
+ )
+ qc = args.quant_config
+ assert qc is not None
+ assert qc.quant_algo == QuantAlgo.FP8
+ assert qc.exclude_modules == ["blocks.0.attn1.*", "proj_out"]
+ assert args.dynamic_weight_quant is True
+
+
+def test_diffusion_args_to_mapping():
+ """Test that DiffusionArgs correctly generates Mapping from ParallelConfig."""
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, ParallelConfig
+
+ # ParallelConfig validator requires WORLD_SIZE >= total parallel (tp*cp = 4)
+ old_world = os.environ.get("WORLD_SIZE")
+ try:
+ os.environ["WORLD_SIZE"] = "4"
+ args = DiffusionArgs(
+ checkpoint_path="/fake/path",
+ parallel=ParallelConfig(dit_tp_size=2, dit_cp_size=2),
+ )
+ mapping = args.to_mapping()
+ assert mapping.tp_size == 2
+ assert mapping.cp_size == 2
+ # world_size = tp_size * pp_size * cp_size (DP is handled separately)
+ assert mapping.world_size == 4
+ finally:
+ if old_world is not None:
+ os.environ["WORLD_SIZE"] = old_world
+ elif "WORLD_SIZE" in os.environ:
+ del os.environ["WORLD_SIZE"]
+
+
+def test_load_without_quant_config_no_fp8(checkpoint_exists):
+ """Test that loading without quant_config does NOT produce FP8 weights."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.modules.linear import Linear
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
+
+ # No quantization specified
+ # Skip text_encoder/vae to speed up test (focus on transformer)
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify dynamic_weight_quant is False
+ assert pipeline.model_config.dynamic_weight_quant is False, (
+ "dynamic_weight_quant should be False when no quant_config"
+ )
+
+ # Verify NO FP8 weights
+ for name, module in pipeline.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "weight") and module.weight is not None:
+ assert module.weight.dtype != torch.float8_e4m3fn, (
+ f"Linear {name} should NOT be FP8 without quant_config"
+ )
+ break
+
+
+def test_diffusion_args_from_dict():
+ """Test DiffusionArgs can be created from a dictionary."""
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs
+ from tensorrt_llm.quantization.mode import QuantAlgo
+
+ config_dict = {
+ "checkpoint_path": "/path/to/model",
+ "quant_config": {"quant_algo": "FP8", "dynamic": True},
+ "parallel": {"dit_tp_size": 2},
+ "pipeline": {"fuse_qkv": True},
+ }
+ # ParallelConfig validator requires WORLD_SIZE >= total parallel (dit_tp_size=2)
+ old_world = os.environ.get("WORLD_SIZE")
+ try:
+ os.environ["WORLD_SIZE"] = "2"
+ args = DiffusionArgs.from_dict(config_dict)
+ assert args.checkpoint_path == "/path/to/model"
+ assert args.quant_config.quant_algo == QuantAlgo.FP8
+ assert args.dynamic_weight_quant is True
+ assert args.parallel.dit_tp_size == 2
+ assert args.pipeline.fuse_qkv is True
+ finally:
+ if old_world is not None:
+ os.environ["WORLD_SIZE"] = old_world
+ elif "WORLD_SIZE" in os.environ:
+ del os.environ["WORLD_SIZE"]
+
+
+# =============================================================================
+# Memory and Performance Tests
+# =============================================================================
+
+
+def _get_module_memory_gb(module):
+ """Get GPU memory usage of a module in GB."""
+ return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3
+
+
+def _get_cuda_memory_gb():
+ """Get current CUDA memory allocated in GB."""
+ return torch.cuda.memory_allocated() / 1024**3
+
+
+def _get_cuda_peak_memory_gb():
+ """Get peak CUDA memory allocated in GB."""
+ return torch.cuda.max_memory_allocated() / 1024**3
+
+
+def test_fp8_vs_bf16_memory_comparison(checkpoint_exists):
+ """Test FP8 dynamic quant uses ~2x less memory than BF16, including peak memory.
+
+ This test verifies that dynamic quantization doesn't create unnecessary
+ intermediate buffers that would negate the memory savings.
+
+ Expected for Wan 1.3B transformer:
+ - BF16: ~2.6 GB model memory, similar peak during loading
+ - FP8: ~1.3 GB model memory, peak should be < 2x BF16 peak
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available")
+
+ from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
+
+ # =========================================================================
+ # Test 1: Load BF16 (no quantization)
+ # =========================================================================
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ bf16_model_mem = _get_module_memory_gb(pipeline_bf16.transformer)
+ bf16_total_mem = _get_cuda_memory_gb()
+ bf16_peak_mem = _get_cuda_peak_memory_gb()
+
+ print(f"\n[BF16] Transformer model memory: {bf16_model_mem:.2f} GB")
+ print(f"[BF16] Total CUDA memory: {bf16_total_mem:.2f} GB")
+ print(f"[BF16] Peak CUDA memory: {bf16_peak_mem:.2f} GB")
+
+ # Cleanup BF16
+ del pipeline_bf16
+ torch.cuda.empty_cache()
+
+ # =========================================================================
+ # Test 2: Load FP8 (dynamic quantization)
+ # =========================================================================
+ torch.cuda.reset_peak_memory_stats()
+
+ args_fp8 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline_fp8 = PipelineLoader(args_fp8).load()
+
+ fp8_model_mem = _get_module_memory_gb(pipeline_fp8.transformer)
+ fp8_total_mem = _get_cuda_memory_gb()
+ fp8_peak_mem = _get_cuda_peak_memory_gb()
+
+ print(f"\n[FP8] Transformer model memory: {fp8_model_mem:.2f} GB")
+ print(f"[FP8] Total CUDA memory: {fp8_total_mem:.2f} GB")
+ print(f"[FP8] Peak CUDA memory: {fp8_peak_mem:.2f} GB")
+
+ # =========================================================================
+ # Verify memory savings
+ # =========================================================================
+ model_mem_ratio = bf16_model_mem / fp8_model_mem
+ peak_mem_ratio = bf16_peak_mem / fp8_peak_mem
+
+ print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x")
+ print(f"[Comparison] Peak memory ratio (BF16/FP8): {peak_mem_ratio:.2f}x")
+
+ # Model memory should be ~2x smaller for FP8
+ assert model_mem_ratio > 1.8, (
+ f"FP8 model memory should be ~2x smaller than BF16, got {model_mem_ratio:.2f}x"
+ )
+
+ # Peak memory during loading should also show savings
+ # Allow some overhead for dynamic quant, but should still be significantly better
+ assert peak_mem_ratio > 1.5, (
+ f"FP8 peak memory should be significantly smaller than BF16, got {peak_mem_ratio:.2f}x. "
+ f"This may indicate unnecessary intermediate buffers during dynamic quantization."
+ )
+
+ # FP8 peak should not be much higher than FP8 final (no large temp buffers)
+ fp8_peak_overhead = fp8_peak_mem / fp8_total_mem
+ print(f"[FP8 Per-Tensor] Peak/Final memory ratio: {fp8_peak_overhead:.2f}x")
+
+    # Peak should stay close to final; allow < 2x overhead during loading
+ assert fp8_peak_overhead < 2.0, (
+ f"FP8 peak memory ({fp8_peak_mem:.2f} GB) is too high compared to final "
+ f"({fp8_total_mem:.2f} GB). Ratio: {fp8_peak_overhead:.2f}x. "
+ f"This suggests unnecessary buffer allocation during dynamic quantization."
+ )
+
+ # Cleanup
+ del pipeline_fp8
+ torch.cuda.empty_cache()
+
+ # =========================================================================
+ # Test 3: Load FP8 Blockwise (dynamic quantization with block scales)
+ # =========================================================================
+ torch.cuda.reset_peak_memory_stats()
+
+ args_fp8_block = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ skip_components=SKIP_HEAVY_COMPONENTS,
+ )
+ pipeline_fp8_block = PipelineLoader(args_fp8_block).load()
+
+ fp8_block_model_mem = _get_module_memory_gb(pipeline_fp8_block.transformer)
+ fp8_block_total_mem = _get_cuda_memory_gb()
+ fp8_block_peak_mem = _get_cuda_peak_memory_gb()
+
+ print(f"\n[FP8 Blockwise] Transformer model memory: {fp8_block_model_mem:.2f} GB")
+ print(f"[FP8 Blockwise] Total CUDA memory: {fp8_block_total_mem:.2f} GB")
+ print(f"[FP8 Blockwise] Peak CUDA memory: {fp8_block_peak_mem:.2f} GB")
+
+ # =========================================================================
+ # Verify FP8 Blockwise memory savings
+ # =========================================================================
+ block_model_mem_ratio = bf16_model_mem / fp8_block_model_mem
+ block_peak_mem_ratio = bf16_peak_mem / fp8_block_peak_mem
+
+ print(f"\n[Comparison] Model memory ratio (BF16/FP8-Block): {block_model_mem_ratio:.2f}x")
+ print(f"[Comparison] Peak memory ratio (BF16/FP8-Block): {block_peak_mem_ratio:.2f}x")
+
+ # FP8 Blockwise has additional scale tensors, so slightly less than 2x savings
+ # But should still be significantly better than BF16
+ assert block_model_mem_ratio > 1.5, (
+ f"FP8 Blockwise model memory should be significantly smaller than BF16, got {block_model_mem_ratio:.2f}x"
+ )
+
+ # Peak memory check
+ assert block_peak_mem_ratio > 1.3, (
+ f"FP8 Blockwise peak memory should be smaller than BF16, got {block_peak_mem_ratio:.2f}x"
+ )
+
+ fp8_block_peak_overhead = fp8_block_peak_mem / fp8_block_total_mem
+ print(f"[FP8 Blockwise] Peak/Final memory ratio: {fp8_block_peak_overhead:.2f}x")
+
+ assert fp8_block_peak_overhead < 2.0, (
+ f"FP8 Blockwise peak memory ({fp8_block_peak_mem:.2f} GB) is too high compared to final "
+ f"({fp8_block_total_mem:.2f} GB). Ratio: {fp8_block_peak_overhead:.2f}x."
+ )
+
+ # Cleanup
+ del pipeline_fp8_block
+ torch.cuda.empty_cache()
+
+ # =========================================================================
+ # Summary
+ # =========================================================================
+ print("\n" + "=" * 60)
+ print("SUMMARY")
+ print("=" * 60)
+ print(f"BF16: {bf16_model_mem:.2f} GB model, {bf16_peak_mem:.2f} GB peak")
+ print(
+ f"FP8 Per-Tensor: {fp8_model_mem:.2f} GB model, {fp8_peak_mem:.2f} GB peak "
+ f"({model_mem_ratio:.2f}x savings)"
+ )
+ print(
+ f"FP8 Blockwise: {fp8_block_model_mem:.2f} GB model, {fp8_block_peak_mem:.2f} GB peak "
+ f"({block_model_mem_ratio:.2f}x savings)"
+ )
diff --git a/tests/unittest/_torch/visual_gen/test_quant_ops.py b/tests/unittest/_torch/visual_gen/test_quant_ops.py
new file mode 100644
index 0000000000..5b141ee74d
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_quant_ops.py
@@ -0,0 +1,120 @@
+"""Unit tests for diffusion quantization operations."""
+
+import unittest
+
+import torch
+
+from tensorrt_llm._torch.visual_gen.quantization.ops import (
+ quantize_fp8_blockwise,
+ quantize_fp8_per_tensor,
+)
+
+
+def _dequant_fp8_per_tensor(qweight, scale):
+ """Dequantize per-tensor FP8 weight."""
+ return qweight.to(torch.float32) * scale
+
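+
+def _dequant_fp8_blockwise(qweight, scales, block_size=128):
+    """Reference blockwise dequantization (illustrative sketch, not the CUDA kernel).
+
+    Assumes ``scales`` has shape (ceil(out/block), ceil(in/block)), the layout the
+    tests below check, with each scale broadcast over its 2D block.
+    """
+    out_dim, in_dim = qweight.shape
+    expanded = scales.repeat_interleave(block_size, dim=0)[:out_dim]
+    expanded = expanded.repeat_interleave(block_size, dim=1)[:, :in_dim]
+    return qweight.to(torch.float32) * expanded
+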
+
+class TestQuantOps(unittest.TestCase):
+ """Test quantization operations."""
+
+ def setUp(self):
+ """Set random seed for reproducibility."""
+ torch.manual_seed(42)
+ if not torch.cuda.is_available():
+ self.skipTest("CUDA not available")
+
+ def test_fp8_per_tensor(self):
+ """Test FP8 per-tensor quantization using CUDA kernel."""
+ weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
+ qweight, scale = quantize_fp8_per_tensor(weight)
+
+ # Check output types
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertEqual(qweight.shape, weight.shape)
+ self.assertEqual(scale.dtype, torch.float32)
+ self.assertEqual(scale.shape, (1, 1))
+
+ # Verify dequantization (approximate)
+ dequant = _dequant_fp8_per_tensor(qweight, scale)
+ error = (dequant - weight.to(torch.float32)).abs().mean()
+ self.assertLess(error, 0.15) # Reasonable quantization error
+
+ def test_fp8_per_tensor_different_shapes(self):
+ """Test FP8 per-tensor quantization with various shapes."""
+ shapes = [(128, 256), (256, 512), (512, 1024), (1024, 2048)]
+ for shape in shapes:
+ with self.subTest(shape=shape):
+ weight = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
+ qweight, scale = quantize_fp8_per_tensor(weight)
+
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertEqual(qweight.shape, weight.shape)
+ self.assertEqual(scale.dtype, torch.float32)
+
+ def test_fp8_blockwise(self):
+ """Test FP8 128x128 blockwise quantization."""
+ weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
+ block_size = 128
+ qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
+
+ # Check output types
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertEqual(qweight.shape, weight.shape)
+ self.assertEqual(scales.dtype, torch.float32)
+
+ # Check scales shape: (num_blocks_out, num_blocks_in) for 128x128 blocks
+ num_blocks_out = (512 + block_size - 1) // block_size # 4
+ num_blocks_in = (512 + block_size - 1) // block_size # 4
+ self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
+
+ def test_fp8_blockwise_non_divisible(self):
+ """Test FP8 blockwise quantization with non-divisible dimensions."""
+ weight = torch.randn(300, 500, dtype=torch.bfloat16, device="cuda")
+ block_size = 128
+ qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
+
+ # Check output types
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertEqual(qweight.shape, weight.shape)
+
+ # Check scales shape (should handle non-divisible dimensions)
+ num_blocks_out = (300 + block_size - 1) // block_size # 3
+ num_blocks_in = (500 + block_size - 1) // block_size # 4
+ self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
+
+ def test_fp8_blockwise_different_block_sizes(self):
+ """Test FP8 blockwise quantization with different block sizes."""
+ weight = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
+
+ for block_size in [64, 128, 256]:
+ with self.subTest(block_size=block_size):
+ qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
+
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertEqual(qweight.shape, weight.shape)
+
+ num_blocks = (256 + block_size - 1) // block_size
+ self.assertEqual(scales.shape, (num_blocks, num_blocks))
+
+ def test_fp8_per_tensor_zero_weight(self):
+ """Test FP8 per-tensor quantization with zero weight."""
+ weight = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
+ qweight, scale = quantize_fp8_per_tensor(weight)
+
+ # Should handle zero weights gracefully
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
+
+ def test_fp8_blockwise_zero_weight(self):
+ """Test FP8 blockwise quantization with zero weight."""
+ weight = torch.zeros(256, 256, dtype=torch.bfloat16, device="cuda")
+ qweight, scales = quantize_fp8_blockwise(weight, block_size=128)
+
+ # Should handle zero weights gracefully
+ self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
+ self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
new file mode 100644
index 0000000000..c785fe8bae
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
@@ -0,0 +1,398 @@
+"""End-to-end tests for trtllm-serve visual_gen with real models.
+
+Tests text-to-video (t2v) and text+image-to-video (ti2v) generation through
+the full ``trtllm-serve`` stack backed by real VisualGen models.
+
+The server is launched as a subprocess (same pattern as
+``tests/unittest/llmapi/apps/openai_server.py``), so each test class gets an
+isolated ``trtllm-serve`` process.
+
+Usage::
+
+    # Run all real-model tests (requires a GPU + models under LLM_MODELS_ROOT)
+    pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v
+
+    # Run only t2v tests
+    pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanTextToVideo
+
+    # Run only ti2v tests
+    pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanImageToVideo
+"""
+
+import os
+import subprocess
+import sys
+import tempfile
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import pytest
+import requests
+import yaml
+
+from tensorrt_llm._utils import get_free_port
+
+# ---------------------------------------------------------------------------
+# Model paths
+# ---------------------------------------------------------------------------
+
+
+def _llm_models_root() -> str:
+ """Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
+ root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+ if "LLM_MODELS_ROOT" in os.environ:
+ root = Path(os.environ["LLM_MODELS_ROOT"])
+ if not root.exists():
+ root = Path("/scratch.trt_llm_data/llm-models/")
+    assert root.exists(), (
+        "Set the LLM_MODELS_ROOT env var or make scratch.trt_llm_data accessible to run this test"
+    )
+ return str(root)
+
+
+_WAN_T2V_PATH = Path(_llm_models_root()) / "Wan2.1-T2V-1.3B-Diffusers"
+_WAN_I2V_PATH = Path(_llm_models_root()) / "Wan2.2-I2V-A14B-Diffusers"
+
+# Reference image used for image-to-video (ti2v) tests
+_PROJECT_ROOT = Path(__file__).resolve().parents[4] # repo root
+_REF_IMAGE_PATH = _PROJECT_ROOT / "examples" / "visual_gen" / "cat_piano.png"
+
+
+# ---------------------------------------------------------------------------
+# Remote server helper (follows RemoteOpenAIServer pattern)
+# ---------------------------------------------------------------------------
+
+
+class RemoteVisualGenServer:
+ """Launch ``trtllm-serve`` for a visual-gen model as a subprocess.
+
+ Mirrors the interface of ``tests.unittest.llmapi.apps.openai_server.RemoteOpenAIServer``
+ adapted for diffusion / visual-gen models.
+ """
+
+    MAX_SERVER_START_WAIT_S = 1200  # 20 min; large models need time to load
+
+ def __init__(
+ self,
+ model: str,
+ extra_visual_gen_options: Optional[dict] = None,
+ cli_args: Optional[List[str]] = None,
+ host: str = "localhost",
+ port: Optional[int] = None,
+ env: Optional[dict] = None,
+ ) -> None:
+ self.host = host
+ self.port = port if port is not None else get_free_port()
+ self._config_file: Optional[str] = None
+ self.proc: Optional[subprocess.Popen] = None
+
+ args = ["--host", self.host, "--port", str(self.port)]
+ if cli_args:
+ args += cli_args
+
+ # Write the visual-gen YAML config to a temp file
+ if extra_visual_gen_options:
+ fd, self._config_file = tempfile.mkstemp(suffix=".yml", prefix="vg_cfg_")
+ with os.fdopen(fd, "w") as f:
+ yaml.dump(extra_visual_gen_options, f)
+ args += ["--extra_visual_gen_options", self._config_file]
+
+ launch_cmd = ["trtllm-serve", model] + args
+
+ if env is None:
+ env = os.environ.copy()
+
+ self.proc = subprocess.Popen(
+ launch_cmd,
+ env=env,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ )
+ self._wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S)
+
+ # -- lifecycle ---------------------------------------------------------
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.terminate()
+
+ def terminate(self):
+ if self.proc is None:
+ return
+ self.proc.terminate()
+ try:
+ self.proc.wait(timeout=30)
+ except subprocess.TimeoutExpired:
+ self.proc.kill()
+ self.proc.wait(timeout=30)
+ self.proc = None
+
+ if self._config_file:
+ try:
+ os.remove(self._config_file)
+ except OSError:
+ pass
+ self._config_file = None
+
+ # -- readiness ---------------------------------------------------------
+
+ def _wait_for_server(self, timeout: float):
+ url = self.url_for("health")
+ start = time.time()
+ while True:
+ try:
+ if requests.get(url).status_code == 200:
+ return
+ except Exception as err:
+ result = self.proc.poll()
+ if result is not None and result != 0:
+ raise RuntimeError("Visual-gen server exited unexpectedly.") from err
+ time.sleep(1)
+ if time.time() - start > timeout:
+ self.terminate()
+ raise RuntimeError(f"Visual-gen server failed to start within {timeout}s.")
+
+ # -- URL helpers -------------------------------------------------------
+
+ @property
+ def url_root(self) -> str:
+ return f"http://{self.host}:{self.port}"
+
+ def url_for(self, *parts: str) -> str:
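+        # e.g. url_for("v1", "videos", "generations") -> "http://<host>:<port>/v1/videos/generations"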
+ return self.url_root + "/" + "/".join(parts)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _model_available(path: Path) -> bool:
+ return path.is_dir()
+
+
+def _av_available() -> bool:
+ """Check if PyAV is installed (required for video encoding in E2E tests)."""
+ try:
+ import av # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
+def _make_visual_gen_options(**extra) -> dict:
+ """Build the YAML dict passed via ``--extra_visual_gen_options``."""
+ config = {
+ "linear": {"type": "default"},
+ "parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1},
+ }
+ config.update(extra)
+ return config
+
+
+# =========================================================================
+# WAN 2.1 - Text-to-Video (t2v)
+# =========================================================================
+
+
+@pytest.mark.skipif(
+ not _model_available(_WAN_T2V_PATH), reason=f"Wan2.1-T2V model not found at {_WAN_T2V_PATH}"
+)
+@pytest.mark.skipif(
+ not _av_available(), reason="PyAV (av) not installed ā required for video encoding in E2E tests"
+)
+class TestWanTextToVideo:
+ """Test Wan2.1-T2V-1.3B-Diffusers text-to-video generation via serve API."""
+
+ @pytest.fixture(scope="class")
+ def server(self):
+ with RemoteVisualGenServer(
+ model=str(_WAN_T2V_PATH),
+ extra_visual_gen_options=_make_visual_gen_options(),
+ ) as srv:
+ yield srv
+
+ # ------------------------------------------------------------------
+
+ def test_health(self, server):
+ resp = requests.get(server.url_for("health"))
+ assert resp.status_code == 200
+
+ def test_t2v_sync(self, server):
+ """Synchronous text-to-video via POST /v1/videos/generations."""
+ resp = requests.post(
+ server.url_for("v1", "videos", "generations"),
+ json={
+ "prompt": "A cute cat playing piano",
+ "size": "480x320",
+ "seconds": 1.0,
+ "fps": 8,
+ "num_inference_steps": 4,
+ "seed": 42,
+ },
+ )
+ assert resp.status_code == 200, resp.text
+ assert resp.headers["content-type"] == "video/mp4"
+ assert len(resp.content) > 1000, "Video file too small"
+
+ def test_t2v_async_lifecycle(self, server):
+ """Async video generation: create job ā poll ā download ā delete."""
+ base = server.url_for("v1", "videos")
+
+ # 1. Create job
+ create_resp = requests.post(
+ base,
+ json={
+ "prompt": "A rocket launching into a starry sky",
+ "size": "480x320",
+ "seconds": 1.0,
+ "fps": 8,
+ "num_inference_steps": 4,
+ "seed": 42,
+ },
+ )
+ assert create_resp.status_code == 202, create_resp.text
+ job = create_resp.json()
+ video_id = job["id"]
+ assert video_id.startswith("video_")
+ assert job["status"] == "queued"
+
+ # 2. Poll until completed (or timeout)
+ deadline = time.time() + 600 # 10 min
+ status = "queued"
+ while status not in ("completed", "failed") and time.time() < deadline:
+ time.sleep(2)
+ meta_resp = requests.get(f"{base}/{video_id}")
+ assert meta_resp.status_code == 200
+ status = meta_resp.json()["status"]
+
+ assert status == "completed", f"Video generation did not complete: {status}"
+
+ # 3. Download video content
+ content_resp = requests.get(f"{base}/{video_id}/content")
+ assert content_resp.status_code == 200
+ assert "video/mp4" in content_resp.headers.get("content-type", "")
+ assert len(content_resp.content) > 1000
+
+ # 4. Verify it appears in list
+ list_resp = requests.get(base)
+ assert list_resp.status_code == 200
+ ids = [v["id"] for v in list_resp.json()["data"]]
+ assert video_id in ids
+
+ # 5. Delete
+ del_resp = requests.delete(f"{base}/{video_id}")
+ assert del_resp.status_code == 200
+ assert del_resp.json()["deleted"] is True
+
+ # 6. Confirm gone
+ gone_resp = requests.get(f"{base}/{video_id}")
+ assert gone_resp.status_code == 404
+
+
+# =========================================================================
+# WAN 2.2 - Image-to-Video (ti2v)
+# =========================================================================
+
+
+@pytest.mark.skipif(
+ not _model_available(_WAN_I2V_PATH), reason=f"Wan2.2-I2V model not found at {_WAN_I2V_PATH}"
+)
+@pytest.mark.skipif(
+ not _REF_IMAGE_PATH.is_file(), reason=f"Reference image not found at {_REF_IMAGE_PATH}"
+)
+@pytest.mark.skipif(
+ not _av_available(), reason="PyAV (av) not installed ā required for video encoding in E2E tests"
+)
+class TestWanImageToVideo:
+ """Test Wan2.2-I2V-A14B-Diffusers image-to-video generation via serve API."""
+
+ @pytest.fixture(scope="class")
+ def server(self):
+ with RemoteVisualGenServer(
+ model=str(_WAN_I2V_PATH),
+ extra_visual_gen_options=_make_visual_gen_options(),
+ ) as srv:
+ yield srv
+
+ # ------------------------------------------------------------------
+
+ def test_health(self, server):
+ resp = requests.get(server.url_for("health"))
+ assert resp.status_code == 200
+
+ def test_ti2v_sync(self, server):
+ """Synchronous image-to-video via multipart POST /v1/videos/generations."""
+ with open(_REF_IMAGE_PATH, "rb") as f:
+ resp = requests.post(
+ server.url_for("v1", "videos", "generations"),
+ data={
+ "prompt": "The cat starts playing piano, keys moving",
+ "size": "480x320",
+ "seconds": "1.0",
+ "fps": "8",
+ "num_inference_steps": "4",
+ "seed": "42",
+ },
+ files={
+ "input_reference": ("cat_piano.png", f, "image/png"),
+ },
+ )
+ assert resp.status_code == 200, resp.text
+ assert resp.headers["content-type"] == "video/mp4"
+ assert len(resp.content) > 1000, "Video file too small"
+
+ def test_ti2v_async_lifecycle(self, server):
+ """Async i2v: create job with image ā poll ā download ā delete."""
+ base = server.url_for("v1", "videos")
+
+ # 1. Create job via multipart
+ with open(_REF_IMAGE_PATH, "rb") as f:
+ create_resp = requests.post(
+ base,
+ data={
+ "prompt": "Snow falls on the piano and the cat",
+ "size": "480x320",
+ "seconds": "1.0",
+ "fps": "8",
+ "num_inference_steps": "4",
+ "seed": "42",
+ },
+ files={
+ "input_reference": ("cat_piano.png", f, "image/png"),
+ },
+ )
+ assert create_resp.status_code == 202, create_resp.text
+ job = create_resp.json()
+ video_id = job["id"]
+ assert job["status"] == "queued"
+
+ # 2. Poll until completed
+ deadline = time.time() + 600
+ status = "queued"
+ while status not in ("completed", "failed") and time.time() < deadline:
+ time.sleep(2)
+ meta_resp = requests.get(f"{base}/{video_id}")
+ assert meta_resp.status_code == 200
+ status = meta_resp.json()["status"]
+
+ assert status == "completed", f"Video generation did not complete: {status}"
+
+ # 3. Download
+ content_resp = requests.get(f"{base}/{video_id}/content")
+ assert content_resp.status_code == 200
+ assert "video/mp4" in content_resp.headers.get("content-type", "")
+ assert len(content_resp.content) > 1000
+
+ # 4. Delete
+ del_resp = requests.delete(f"{base}/{video_id}")
+ assert del_resp.status_code == 200
+ assert del_resp.json()["deleted"] is True
+
+ # 5. Confirm gone
+ gone_resp = requests.get(f"{base}/{video_id}")
+ assert gone_resp.status_code == 404
diff --git a/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py b/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
new file mode 100644
index 0000000000..a66e742447
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
@@ -0,0 +1,876 @@
+"""trtllm-serve visual_gen endpoints tests.
+
+Tests all endpoints registered for the VISUAL_GEN server role
+in OpenAIServer.register_visual_gen_routes():
+
+ POST /v1/images/generations
+ POST /v1/images/edits
+ POST /v1/videos/generations (sync)
+ POST /v1/videos (async)
+ GET /v1/videos (list)
+ GET /v1/videos/{video_id} (metadata)
+ GET /v1/videos/{video_id}/content (download)
+ DELETE /v1/videos/{video_id} (delete)
+"""
+
+import asyncio
+import base64
+import os
+from io import BytesIO
+from typing import Optional
+from unittest.mock import patch
+
+import pytest
+import torch
+from fastapi.testclient import TestClient
+from PIL import Image
+
+from tensorrt_llm._torch.visual_gen.output import MediaOutput
+from tensorrt_llm.serve.media_storage import MediaStorage
+from tensorrt_llm.serve.openai_protocol import VideoJob
+from tensorrt_llm.serve.visual_gen_utils import VIDEO_STORE
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_dummy_image_tensor(height: int = 64, width: int = 64) -> torch.Tensor:
+ """Create a small dummy uint8 image tensor (H, W, C)."""
+ return torch.randint(0, 256, (height, width, 3), dtype=torch.uint8)
+
+
+def _make_dummy_video_tensor(
+ num_frames: int = 4, height: int = 64, width: int = 64
+) -> torch.Tensor:
+ """Create a small dummy uint8 video tensor (T, H, W, C)."""
+ return torch.randint(0, 256, (num_frames, height, width, 3), dtype=torch.uint8)
+
+
+def _make_dummy_audio_tensor(length: int = 16000) -> torch.Tensor:
+ """Create a small dummy float32 audio tensor."""
+ return torch.randn(1, length, dtype=torch.float32)
+
+
+def _b64_white_png_1x1() -> str:
+ """Return a base64-encoded 1x1 white PNG for image edit tests."""
+ buf = BytesIO()
+ Image.new("RGB", (1, 1), (255, 255, 255)).save(buf, format="PNG")
+ return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+
+def _run_async(coro):
+ """Run an async coroutine in a new event loop (for test helpers)."""
+ loop = asyncio.new_event_loop()
+ try:
+ return loop.run_until_complete(coro)
+ finally:
+ loop.close()
+
+
+# ---------------------------------------------------------------------------
+# Mock VisualGen
+# ---------------------------------------------------------------------------
+
+
+class MockVisualGen:
+ """Lightweight stand-in for VisualGen that avoids GPU / model loading."""
+
+ def __init__(
+ self,
+ image_output: Optional[torch.Tensor] = None,
+ video_output: Optional[torch.Tensor] = None,
+ audio_output: Optional[torch.Tensor] = None,
+ should_fail: bool = False,
+ ):
+ self._image = image_output
+ self._video = video_output
+ self._audio = audio_output
+ self._should_fail = should_fail
+ self._healthy = True
+ self.req_counter = 0
+
+ # --- VisualGen interface ---
+
+ def generate(self, inputs=None, params=None) -> MediaOutput:
+ if self._should_fail:
+ raise RuntimeError("Generation intentionally failed")
+ return MediaOutput(
+ image=self._image,
+ video=self._video,
+ audio=self._audio,
+ )
+
+ def generate_async(self, inputs=None, params=None) -> "MockDiffusionGenerationResult":
+ return MockDiffusionGenerationResult(
+ image=self._image,
+ video=self._video,
+ audio=self._audio,
+ should_fail=self._should_fail,
+ )
+
+ def _check_health(self) -> bool:
+ return self._healthy
+
+ async def get_stats_async(self, timeout: int):
+ return
+
+ def shutdown(self):
+ pass
+
+
+class MockDiffusionGenerationResult:
+ """Mock future-like result for generate_async."""
+
+ def __init__(
+ self,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ audio: Optional[torch.Tensor] = None,
+ should_fail: bool = False,
+ ):
+ self._image = image
+ self._video = video
+ self._audio = audio
+ self._should_fail = should_fail
+
+ async def result(self, timeout=None):
+ if self._should_fail:
+ raise RuntimeError("Async generation intentionally failed")
+ return MediaOutput(
+ image=self._image,
+ video=self._video,
+ audio=self._audio,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Server factory
+# ---------------------------------------------------------------------------
+
+
+def _create_server(generator: MockVisualGen, model_name: str = "test-model") -> TestClient:
+ """Instantiate an OpenAIServer for VISUAL_GEN with a mocked generator.
+
+ We patch the ``VisualGen`` name inside the ``openai_server`` module so that
+ ``isinstance(generator, VisualGen)`` returns True for our mock.
+ """
+ from tensorrt_llm.llmapi.disagg_utils import ServerRole
+ from tensorrt_llm.serve.openai_server import OpenAIServer
+
+ with patch("tensorrt_llm.serve.openai_server.VisualGen", MockVisualGen):
+ server = OpenAIServer(
+ generator=generator,
+ model=model_name,
+ tool_parser=None,
+ server_role=ServerRole.VISUAL_GEN,
+ metadata_server_cfg=None,
+ )
+ return TestClient(server.app)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture()
+def image_client(tmp_path):
+ """TestClient backed by a MockVisualGen that produces images."""
+ gen = MockVisualGen(image_output=_make_dummy_image_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ yield client
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+
+@pytest.fixture()
+def video_client(tmp_path):
+ """TestClient backed by a MockVisualGen that produces videos."""
+ gen = MockVisualGen(video_output=_make_dummy_video_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ yield client
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+
+@pytest.fixture()
+def video_audio_client(tmp_path):
+ """TestClient backed by a MockVisualGen that produces videos with audio."""
+ gen = MockVisualGen(
+ video_output=_make_dummy_video_tensor(),
+ audio_output=_make_dummy_audio_tensor(),
+ )
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ yield client
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+
+@pytest.fixture()
+def failing_client(tmp_path):
+ """TestClient backed by a MockVisualGen that always fails."""
+ gen = MockVisualGen(should_fail=True)
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ yield client
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+
+@pytest.fixture(autouse=True)
+def _clear_video_store():
+ """Reset the global VIDEO_STORE before each test."""
+ VIDEO_STORE._items.clear()
+ yield
+ VIDEO_STORE._items.clear()
+
+
+@pytest.fixture(autouse=True)
+def _mock_video_encoding():
+ """Mock MP4 encoding to avoid PyAV dependency in unit tests.
+
+ Replaces MediaStorage._save_mp4 with a stub that writes a small
+ dummy file so FileResponse can serve it.
+ """
+
+ def _dummy_save_mp4(video, audio, output_path, frame_rate):
+ os.makedirs(os.path.dirname(str(output_path)) or ".", exist_ok=True)
+ with open(str(output_path), "wb") as f:
+ f.write(b"\x00\x00\x00\x1cftypisom" + b"\x00" * 32)
+ return str(output_path)
+
+ with patch.object(MediaStorage, "_save_mp4", staticmethod(_dummy_save_mp4)):
+ yield
+
+
+# =========================================================================
+# POST /v1/images/generations
+# =========================================================================
+
+
+class TestImageGeneration:
+ def test_basic_image_generation_b64(self, image_client):
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "A cat sitting on a mat",
+ "response_format": "b64_json",
+ "size": "64x64",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "data" in data
+ assert len(data["data"]) >= 1
+ img_obj = data["data"][0]
+ assert img_obj["b64_json"] is not None
+ # Verify it decodes to valid bytes
+ decoded = base64.b64decode(img_obj["b64_json"])
+ assert len(decoded) > 0
+ assert img_obj["revised_prompt"] == "A cat sitting on a mat"
+
+ def test_image_generation_with_optional_params(self, image_client):
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "Sunset over ocean",
+ "response_format": "b64_json",
+ "size": "128x64",
+ "num_inference_steps": 20,
+ "guidance_scale": 7.5,
+ "seed": 123,
+ "negative_prompt": "blurry",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["size"] == "128x64"
+
+ def test_image_generation_url_format_not_supported(self, image_client):
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "A dog",
+ "response_format": "url",
+ },
+ )
+ assert resp.status_code == 501
+
+ def test_image_generation_auto_size(self, image_client):
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "A tree",
+ "response_format": "b64_json",
+ "size": "auto",
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_image_generation_failure(self, failing_client):
+ resp = failing_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "A bird",
+ "response_format": "b64_json",
+ },
+ )
+ assert resp.status_code == 400
+
+ def test_image_generation_invalid_size(self, image_client):
+ """Invalid size triggers RequestValidationError ā custom handler ā 400."""
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "A mountain",
+ "response_format": "b64_json",
+ "size": "invalid",
+ },
+ )
+ assert resp.status_code == 400
+
+ def test_image_generation_null_output(self, tmp_path):
+ """Generator returns MediaOutput with image=None."""
+ gen = MockVisualGen(image_output=None)
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ resp = client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "null image",
+ "response_format": "b64_json",
+ },
+ )
+ assert resp.status_code == 500
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+ def test_image_generation_multiple_n(self, image_client):
+ """Request n=2 images in one call."""
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "Flowers",
+ "response_format": "b64_json",
+ "size": "64x64",
+ "n": 2,
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_image_generation_hd_quality(self, image_client):
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "HD landscape",
+ "response_format": "b64_json",
+ "quality": "hd",
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_missing_prompt_image_generation(self, image_client):
+ """Missing required field ā RequestValidationError ā custom handler ā 400."""
+ resp = image_client.post(
+ "/v1/images/generations",
+ json={},
+ )
+ assert resp.status_code == 400
+
+
+# =========================================================================
+# POST /v1/images/edits
+# =========================================================================
+
+
+class TestImageEdit:
+ def test_basic_image_edit(self, image_client):
+ b64_img = _b64_white_png_1x1()
+ resp = image_client.post(
+ "/v1/images/edits",
+ json={
+ "image": b64_img,
+ "prompt": "Make it blue",
+ "num_inference_steps": 10,
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "data" in data
+ assert len(data["data"]) >= 1
+ assert data["data"][0]["b64_json"] is not None
+
+ def test_image_edit_with_list_images(self, image_client):
+ b64_img = _b64_white_png_1x1()
+ resp = image_client.post(
+ "/v1/images/edits",
+ json={
+ "image": [b64_img, b64_img],
+ "prompt": "Merge them",
+ "num_inference_steps": 10,
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_image_edit_with_mask(self, image_client):
+ b64_img = _b64_white_png_1x1()
+ b64_mask = _b64_white_png_1x1()
+ resp = image_client.post(
+ "/v1/images/edits",
+ json={
+ "image": b64_img,
+ "prompt": "Remove object",
+ "mask": b64_mask,
+ "num_inference_steps": 10,
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_image_edit_with_optional_params(self, image_client):
+ b64_img = _b64_white_png_1x1()
+ resp = image_client.post(
+ "/v1/images/edits",
+ json={
+ "image": b64_img,
+ "prompt": "Enhance colors",
+ "size": "128x128",
+ "guidance_scale": 8.0,
+ "num_inference_steps": 15,
+ "seed": 42,
+ "negative_prompt": "dark",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["size"] == "128x128"
+
+ def test_image_edit_failure(self, failing_client):
+ b64_img = _b64_white_png_1x1()
+ resp = failing_client.post(
+ "/v1/images/edits",
+ json={
+ "image": b64_img,
+ "prompt": "Edit this",
+ "num_inference_steps": 10,
+ },
+ )
+ assert resp.status_code == 500
+
+ def test_missing_image_for_edit(self, image_client):
+ """Missing required field ā RequestValidationError ā custom handler ā 400."""
+ resp = image_client.post(
+ "/v1/images/edits",
+ json={
+ "prompt": "Edit without image",
+ },
+ )
+ assert resp.status_code == 400
+
+
+# =========================================================================
+# POST /v1/videos/generations (synchronous)
+# =========================================================================
+
+
+@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads
+class TestVideoGenerationSync:
+ def test_basic_sync_video_generation(self, video_client):
+ resp = video_client.post(
+ "/v1/videos/generations",
+ json={
+ "prompt": "A rocket launching",
+ "size": "64x64",
+ "seconds": 1.0,
+ "fps": 8,
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 200
+ assert resp.headers["content-type"] == "video/mp4"
+ assert len(resp.content) > 0
+
+ def test_sync_video_generation_with_params(self, video_client):
+ resp = video_client.post(
+ "/v1/videos/generations",
+ json={
+ "prompt": "Ocean waves",
+ "size": "64x64",
+ "seconds": 2.0,
+ "fps": 8,
+ "num_inference_steps": 10,
+ "guidance_scale": 5.0,
+ "seed": 42,
+ "negative_prompt": "blurry",
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 200
+ assert len(resp.content) > 0
+
+ def test_sync_video_generation_multipart(self, video_client):
+ # Use files={} with a dummy file to ensure multipart/form-data
+ dummy_file = BytesIO(b"")
+ resp = video_client.post(
+ "/v1/videos/generations",
+ data={
+ "prompt": "Mountain sunrise",
+ "size": "64x64",
+ "seconds": "1.0",
+ "fps": "8",
+ },
+ files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
+ )
+ # The server will parse fields; _dummy is ignored since it's not "input_reference"
+ assert resp.status_code == 200
+ assert len(resp.content) > 0
+
+ def test_sync_video_generation_multipart_with_reference(self, video_client, tmp_path):
+ # Create a dummy reference image file
+ ref_path = tmp_path / "ref.png"
+ Image.new("RGB", (4, 4), (128, 128, 128)).save(str(ref_path))
+
+ with open(ref_path, "rb") as f:
+ resp = video_client.post(
+ "/v1/videos/generations",
+ data={
+ "prompt": "Animate this image",
+ "size": "64x64",
+ "seconds": "1.0",
+ "fps": "8",
+ },
+ files={"input_reference": ("ref.png", f, "image/png")},
+ )
+ assert resp.status_code == 200
+ assert len(resp.content) > 0
+
+ def test_sync_video_failure(self, failing_client):
+ resp = failing_client.post(
+ "/v1/videos/generations",
+ json={
+ "prompt": "Should fail",
+ "size": "64x64",
+ "seconds": 1.0,
+ "fps": 8,
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+ def test_sync_video_null_output(self, tmp_path):
+ """Generator returns MediaOutput with video=None."""
+ gen = MockVisualGen(video_output=None)
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+ resp = client.post(
+ "/v1/videos/generations",
+ json={"prompt": "null video", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 500
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+ def test_sync_video_unsupported_content_type(self, video_client):
+ resp = video_client.post(
+ "/v1/videos/generations",
+ content=b"some raw bytes",
+ headers={"content-type": "text/plain"},
+ )
+ assert resp.status_code == 400
+
+ def test_sync_video_missing_prompt_json(self, video_client):
+ """Missing required prompt ā Pydantic ValidationError ā 400."""
+ resp = video_client.post(
+ "/v1/videos/generations",
+ json={"size": "64x64"},
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+ def test_sync_video_missing_prompt_multipart(self, video_client):
+ """Missing prompt in multipart form ā ValueError ā 400."""
+ dummy_file = BytesIO(b"")
+ resp = video_client.post(
+ "/v1/videos/generations",
+ data={"size": "64x64"},
+ files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
+ )
+ assert resp.status_code == 400
+
+
+# =========================================================================
+# POST /v1/videos (asynchronous)
+# =========================================================================
+
+
+class TestVideoGenerationAsync:
+ def test_async_video_returns_202(self, video_client):
+ resp = video_client.post(
+ "/v1/videos",
+ json={
+ "prompt": "A dancing robot",
+ "size": "64x64",
+ "seconds": 1.0,
+ "fps": 8,
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 202
+ data = resp.json()
+ assert data["status"] == "queued"
+ assert data["object"] == "video"
+ assert data["prompt"] == "A dancing robot"
+ assert data["id"].startswith("video_")
+
+ def test_async_video_job_metadata_fields(self, video_client):
+ resp = video_client.post(
+ "/v1/videos",
+ json={
+ "prompt": "Starry night",
+ "size": "64x64",
+ "seconds": 2.0,
+ "fps": 12,
+ },
+ headers={"content-type": "application/json"},
+ )
+ data = resp.json()
+ assert "created_at" in data
+ assert data["duration"] == 2.0
+ assert data["fps"] == 12
+ assert data["size"] == "64x64"
+
+ def test_async_video_multipart(self, video_client):
+ """Multipart encoding requires a file field to trigger the correct content-type."""
+ dummy_file = BytesIO(b"")
+ resp = video_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "A sunset",
+ "size": "64x64",
+ "seconds": "1.0",
+ "fps": "8",
+ },
+ files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
+ )
+ assert resp.status_code == 202
+
+ def test_async_video_invalid_seconds(self, video_client):
+ """Seconds must be between 1.0 and 16.0. Validation error ā 400."""
+ resp = video_client.post(
+ "/v1/videos",
+ json={
+ "prompt": "Too short",
+ "seconds": 0.1,
+ "size": "64x64",
+ "fps": 8,
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+ def test_async_video_invalid_fps(self, video_client):
+ """Fps must be between 8 and 60. Validation error ā 400."""
+ resp = video_client.post(
+ "/v1/videos",
+ json={
+ "prompt": "Bad fps",
+ "seconds": 1.0,
+ "fps": 2,
+ "size": "64x64",
+ },
+ headers={"content-type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+
+# =========================================================================
+# GET /v1/videos (list)
+# =========================================================================
+
+
+class TestListVideos:
+ def test_list_videos_empty(self, video_client):
+ resp = video_client.get("/v1/videos")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["object"] == "list"
+ assert data["data"] == []
+
+ def test_list_videos_after_creation(self, video_client):
+ # Create two video jobs
+ video_client.post(
+ "/v1/videos",
+ json={"prompt": "First video", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ video_client.post(
+ "/v1/videos",
+ json={"prompt": "Second video", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+
+ resp = video_client.get("/v1/videos")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert len(data["data"]) == 2
+
+
+# =========================================================================
+# GET /v1/videos/{video_id} (metadata)
+# =========================================================================
+
+
+class TestGetVideoMetadata:
+ def test_get_video_metadata_success(self, video_client):
+ create_resp = video_client.post(
+ "/v1/videos",
+ json={"prompt": "Space walk", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ video_id = create_resp.json()["id"]
+
+ resp = video_client.get(f"/v1/videos/{video_id}")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["id"] == video_id
+ assert data["object"] == "video"
+ assert data["prompt"] == "Space walk"
+
+ def test_get_video_metadata_not_found(self, video_client):
+ resp = video_client.get("/v1/videos/video_nonexistent")
+ assert resp.status_code == 404
+
+
+# =========================================================================
+# GET /v1/videos/{video_id}/content (download)
+# =========================================================================
+
+
+@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads
+class TestGetVideoContent:
+ def _insert_video_job(self, video_id: str, status: str = "queued"):
+ import time as _time
+
+ job = VideoJob(
+ created_at=int(_time.time()),
+ id=video_id,
+ model="test-model",
+ prompt="test prompt",
+ status=status,
+ )
+ _run_async(VIDEO_STORE.upsert(video_id, job))
+
+ def test_get_video_content_success(self, tmp_path):
+ gen = MockVisualGen(video_output=_make_dummy_video_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+
+ video_id = "video_testcontent"
+ self._insert_video_job(video_id, status="completed")
+
+ # Write a dummy mp4 file so FileResponse can serve it
+ video_path = tmp_path / f"{video_id}.mp4"
+ video_path.write_bytes(b"\x00\x00\x00\x1cftyp" + b"\x00" * 16)
+
+ resp = client.get(f"/v1/videos/{video_id}/content")
+ assert resp.status_code == 200
+ assert "video/mp4" in resp.headers.get("content-type", "")
+ assert len(resp.content) > 0
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+ def test_get_video_content_not_found(self, video_client):
+ resp = video_client.get("/v1/videos/video_nonexistent/content")
+ assert resp.status_code == 404
+
+ def test_get_video_content_not_ready(self, tmp_path):
+ """A queued video should return 400 when its content is requested."""
+ gen = MockVisualGen(video_output=_make_dummy_video_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+
+ video_id = "video_notready"
+ self._insert_video_job(video_id, status="queued")
+
+ resp = client.get(f"/v1/videos/{video_id}/content")
+ assert resp.status_code == 400
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+ def test_get_video_content_completed_but_file_missing(self, tmp_path):
+ """Video marked completed but file deleted from disk ā 404."""
+ gen = MockVisualGen(video_output=_make_dummy_video_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+
+ video_id = "video_nofile"
+ self._insert_video_job(video_id, status="completed")
+ # Do NOT write a file
+
+ resp = client.get(f"/v1/videos/{video_id}/content")
+ assert resp.status_code == 404
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+
+# =========================================================================
+# DELETE /v1/videos/{video_id}
+# =========================================================================
+
+
+class TestDeleteVideo:
+ def test_delete_video_success(self, tmp_path):
+ gen = MockVisualGen(video_output=_make_dummy_video_tensor())
+ os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
+ client = _create_server(gen)
+
+ create_resp = client.post(
+ "/v1/videos",
+ json={"prompt": "Delete me", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ video_id = create_resp.json()["id"]
+
+ # Write a dummy video file
+ (tmp_path / f"{video_id}.mp4").write_bytes(b"\x00" * 32)
+
+ resp = client.delete(f"/v1/videos/{video_id}")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["deleted"] is True
+
+ # Verify it's gone from the store
+ resp = client.get(f"/v1/videos/{video_id}")
+ assert resp.status_code == 404
+
+ # Verify file is deleted
+ assert not (tmp_path / f"{video_id}.mp4").exists()
+ os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
+
+ def test_delete_video_not_found(self, video_client):
+ resp = video_client.delete("/v1/videos/video_nonexistent")
+ assert resp.status_code == 404
+
+ def test_delete_video_without_file_on_disk(self, video_client):
+ """Delete a video job that exists in the store but has no file on disk."""
+ create_resp = video_client.post(
+ "/v1/videos",
+ json={"prompt": "No file", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ video_id = create_resp.json()["id"]
+
+ resp = video_client.delete(f"/v1/videos/{video_id}")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["deleted"] is True
+
+ def test_delete_video_then_list_empty(self, video_client):
+ """After deleting the only video, the list should be empty."""
+ create_resp = video_client.post(
+ "/v1/videos",
+ json={"prompt": "Ephemeral", "size": "64x64", "seconds": 1.0, "fps": 8},
+ headers={"content-type": "application/json"},
+ )
+ video_id = create_resp.json()["id"]
+
+ video_client.delete(f"/v1/videos/{video_id}")
+
+ resp = video_client.get("/v1/videos")
+ assert resp.status_code == 200
+ assert resp.json()["data"] == []
diff --git a/tests/unittest/_torch/visual_gen/test_wan.py b/tests/unittest/_torch/visual_gen/test_wan.py
new file mode 100644
index 0000000000..998452cea2
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_wan.py
@@ -0,0 +1,3094 @@
+"""Comprehensive unit tests for the Wan model and pipeline."""
+
+import os
+
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import unittest
+from copy import deepcopy
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from diffusers import WanTransformer3DModel as HFWanTransformer3DModel
+from parameterized import parameterized
+
+from tensorrt_llm._torch.modules.linear import Linear
+from tensorrt_llm._torch.visual_gen.config import (
+ AttentionConfig,
+ DiffusionArgs,
+ DiffusionModelConfig,
+ ParallelConfig,
+ PipelineComponent,
+ TeaCacheConfig,
+)
+from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+ """Clean up TLLM_DISABLE_MPI env var after tests complete."""
+ yield
+ os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+def _llm_models_root() -> str:
+ """Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
+ root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+ if "LLM_MODELS_ROOT" in os.environ:
+ root = Path(os.environ["LLM_MODELS_ROOT"])
+ if not root.exists():
+ root = Path("/scratch.trt_llm_data/llm-models/")
+ assert root.exists(), (
+ "You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
+ )
+ return str(root)
+
+
+# Checkpoint paths for integration tests
+CHECKPOINT_PATH = os.environ.get(
+ "DIFFUSION_MODEL_PATH",
+ os.path.join(_llm_models_root(), "Wan2.1-T2V-1.3B-Diffusers"),
+)
+# Wan 2.2 TI2V-5B: BF16 base, FP8 pre-quantized, NVFP4 pre-quantized
+CHECKPOINT_PATH_WAN22_BF16 = os.environ.get(
+ "DIFFUSION_MODEL_PATH_WAN22_BF16",
+ os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers"),
+)
+CHECKPOINT_PATH_WAN22_FP8 = os.environ.get(
+ "DIFFUSION_MODEL_PATH_WAN22_FP8",
+ os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers-FP8"),
+)
+CHECKPOINT_PATH_WAN22_NVFP4 = os.environ.get(
+ "DIFFUSION_MODEL_PATH_WAN22_NVFP4",
+ os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers-NVFP4"),
+)
+# Wan 2.2 T2V (two-stage transformer)
+CHECKPOINT_PATH_WAN22_T2V = os.environ.get(
+ "DIFFUSION_MODEL_PATH_WAN22_T2V",
+ os.path.join(_llm_models_root(), "Wan2.2-T2V-A14B-Diffusers"),
+)
+SKIP_COMPONENTS = [
+ PipelineComponent.TEXT_ENCODER,
+ PipelineComponent.VAE,
+ PipelineComponent.TOKENIZER,
+ PipelineComponent.SCHEDULER,
+]
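+# Components skipped in most tests below so that only the transformer is built and exercised.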
+
+
+def is_wan21_checkpoint() -> bool:
+ """Check if DIFFUSION_MODEL_PATH is Wan 2.1 (contains '2.1' in path)."""
+ return "2.1" in CHECKPOINT_PATH
+
+
+def is_wan22_checkpoint() -> bool:
+ """Check if DIFFUSION_MODEL_PATH is Wan 2.2 (contains '2.2' in path)."""
+ return "2.2" in CHECKPOINT_PATH_WAN22_T2V
+
+
+WAN_1_3B_CONFIG = {
+ "attention_head_dim": 128,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 1024,
+ "text_dim": 4096,
+ "torch_dtype": "bfloat16",
+ "cross_attn_norm": True,
+}
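+# Note: hidden_size is derived in the tests as num_attention_heads * attention_head_dim (12 * 128 = 1536).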
+
+
+def reduce_wan_config(mem_for_full_model: int, config_dict: dict):
+ """Reduce model size if insufficient GPU memory."""
+ _, total_mem = torch.cuda.mem_get_info()
+ if total_mem < mem_for_full_model:
+ model_fraction = total_mem / mem_for_full_model
+ num_layers = max(1, int(config_dict["num_layers"] * model_fraction))
+ config_dict["num_layers"] = min(num_layers, 4)
+
+
+def setup_distributed(rank, world_size, backend="nccl"):
+ """Initialize distributed process group for multi-GPU tests."""
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+
+ dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+ torch.cuda.set_device(rank)
+
+
+def cleanup_distributed():
+ """Clean up distributed process group."""
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+
+def _run_cfg_worker(rank, world_size, checkpoint_path, inputs_list, return_dict):
+ """Worker function for CFG Parallelism multi-GPU test.
+
+ Must be module-level for multiprocessing.spawn() pickling.
+ """
+ try:
+ setup_distributed(rank, world_size)
+
+ from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, ParallelConfig
+ from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+
+ # Load pipeline with CFG parallel
+ args = DiffusionArgs(
+ checkpoint_path=checkpoint_path,
+ device=f"cuda:{rank}",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ parallel=ParallelConfig(dit_cfg_size=world_size),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify CFG parallel configuration
+ assert pipeline.model_config.parallel.dit_cfg_size == world_size, (
+ f"Expected cfg_size={world_size}, got {pipeline.model_config.parallel.dit_cfg_size}"
+ )
+
+ # Load inputs on this GPU
+ prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
+ neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
+ latents = inputs_list[2].to(f"cuda:{rank}")
+ timestep = inputs_list[3].to(f"cuda:{rank}")
+
+ # Setup CFG config
+ cfg_config = pipeline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ # Verify CFG parallel is enabled
+ assert cfg_config["enabled"], f"Rank {rank}: CFG parallel not enabled"
+ assert cfg_config["cfg_size"] == world_size, f"Rank {rank}: Wrong cfg_size"
+
+ expected_cfg_group = rank // cfg_config["ulysses_size"]
+ assert cfg_config["cfg_group"] == expected_cfg_group, (
+ f"Rank {rank}: Wrong cfg_group. Expected {expected_cfg_group}, got {cfg_config['cfg_group']}"
+ )
+
+ if rank == 0:
+ print(f"[CFG Rank {rank}] Loaded with cfg_size={world_size}")
+ print(f" cfg_group: {cfg_config['cfg_group']}")
+ print(f" local_embeds shape: {cfg_config['local_embeds'].shape}")
+ print(f" Using {'positive' if cfg_config['cfg_group'] == 0 else 'negative'} prompts")
+
+ # Verify prompt splitting - rank 0 gets positive, rank 1 gets negative
+ expected_embeds = prompt_embeds if cfg_config["cfg_group"] == 0 else neg_prompt_embeds
+ assert torch.allclose(cfg_config["local_embeds"], expected_embeds), (
+ f"Rank {rank}: local_embeds doesn't match expected"
+ f"{'positive' if cfg_config['cfg_group'] == 0 else 'negative'} embeds"
+ )
+
+ # Run single denoising step with CFG parallel
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return pipeline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ with torch.no_grad():
+ noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
+ latents=latents,
+ extra_stream_latents={},
+ timestep=timestep,
+ local_embeds=cfg_config["local_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ ulysses_size=cfg_config["ulysses_size"],
+ local_extras={},
+ )
+
+ # Validate output
+ assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
+ assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
+
+ # Return output from rank 0
+ if rank == 0:
+ return_dict["output"] = noise_pred.cpu()
+ print(f"[CFG Rank {rank}] ā Output shape: {noise_pred.shape}")
+ print(
+ f"[CFG Rank {rank}] ā Output range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ finally:
+ cleanup_distributed()
+
+
+def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_list, return_dict):
+ """Worker function for all optimizations combined test (FP8 + TeaCache + TRTLLM + CFG).
+
+ Must be module-level for multiprocessing.spawn() pickling.
+ """
+ try:
+ setup_distributed(rank, world_size)
+
+ # Load pipeline with ALL optimizations
+ args_full = DiffusionArgs(
+ checkpoint_path=checkpoint_path,
+ device=f"cuda:{rank}",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ attention=AttentionConfig(backend="TRTLLM"),
+ parallel=ParallelConfig(dit_cfg_size=world_size),
+ )
+ pipeline = PipelineLoader(args_full).load()
+ transformer = pipeline.transformer.eval()
+
+ # Verify all optimizations are enabled
+ assert pipeline.model_config.parallel.dit_cfg_size == world_size, "CFG parallel not enabled"
+ assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled"
+ assert hasattr(pipeline, "cache_backend"), "TeaCache not enabled"
+ assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", (
+ "TRTLLM not enabled for self-attn"
+ )
+
+ if rank == 0:
+ print(f" ā All optimizations verified on rank {rank}:")
+ print(f" - FP8 quantization: {transformer.model_config.quant_config.quant_algo}")
+ print(" - TeaCache: enabled")
+ print(f" - TRTLLM attention: {transformer.blocks[0].attn1.attn_backend}")
+ print(f" - CFG Parallelism: cfg_size={world_size}")
+
+ # Initialize TeaCache for single-step inference
+ if hasattr(pipeline, "cache_backend"):
+ pipeline.cache_backend.refresh(num_inference_steps=1)
+
+ # Load inputs on this GPU
+ prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
+ neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
+ latents = inputs_list[2].to(f"cuda:{rank}")
+ timestep = inputs_list[3].to(f"cuda:{rank}")
+
+ # Setup CFG config
+ cfg_config = pipeline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ assert cfg_config["enabled"], "CFG parallel not enabled"
+
+ # Run single denoising step with all optimizations
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ with torch.no_grad():
+ noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
+ latents=latents,
+ extra_stream_latents={},
+ timestep=timestep,
+ local_embeds=cfg_config["local_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ ulysses_size=cfg_config["ulysses_size"],
+ local_extras={},
+ )
+
+ # Validate output
+ assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
+ assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
+
+ # Return output from rank 0
+ if rank == 0:
+ return_dict["output"] = noise_pred.cpu()
+ print(f" ā Combined optimization output shape: {noise_pred.shape}")
+ print(
+ f" ā Combined optimization range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
+ )
+
+ del pipeline, transformer
+ torch.cuda.empty_cache()
+
+ finally:
+ cleanup_distributed()
+
+
+# =============================================================================
+# Basic Unit Tests
+# =============================================================================
+
+
+class TestWan(unittest.TestCase):
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def _create_model_config(self, config_dict):
+ """Helper to create DiffusionModelConfig from test config dict."""
+ # Create pretrained_config as SimpleNamespace
+ pretrained_config = SimpleNamespace(**config_dict)
+
+ # Use default quantization (no quantization for unit tests)
+ quant_config = QuantConfig()
+ dynamic_weight_quant = False
+ dynamic_activation_quant = False
+
+ # Create DiffusionModelConfig
+ model_config = DiffusionModelConfig(
+ pretrained_config=pretrained_config,
+ quant_config=quant_config,
+ quant_config_dict=None,
+ dynamic_weight_quant=dynamic_weight_quant,
+ force_dynamic_quantization=dynamic_activation_quant,
+ skip_create_weights_in_init=False, # Create weights immediately for testing
+ )
+ return model_config
+
+ def test_wan_model_structure(self):
+ """Test that model structure matches HuggingFace naming."""
+ config = deepcopy(WAN_1_3B_CONFIG)
+ config["num_layers"] = 1
+ hidden_size = config["num_attention_heads"] * config["attention_head_dim"]
+ config["hidden_size"] = hidden_size
+
+ model_config = self._create_model_config(config)
+
+ model = WanTransformer3DModel(model_config=model_config)
+
+ # Check FFN structure
+ param_names = [n for n in model.state_dict().keys() if "ffn" in n]
+ print("\n[DEBUG] FFN parameter names in TRT-LLM model:")
+ for pn in param_names[:5]:
+ print(f" - {pn}")
+
+ # Verify expected structure exists (MLP uses up_proj/down_proj)
+ assert any("ffn.up_proj" in n for n in param_names), "Missing ffn.up_proj structure"
+ assert any("ffn.down_proj" in n for n in param_names), "Missing ffn.down_proj structure"
+
+ def test_wan_sanity(self):
+ """Basic sanity test that the model can run forward pass."""
+ config = deepcopy(WAN_1_3B_CONFIG)
+ dtype = getattr(torch, config["torch_dtype"])
+ # Use fewer layers for sanity test
+ config["num_layers"] = 2
+
+ hidden_size = config["num_attention_heads"] * config["attention_head_dim"]
+ config["hidden_size"] = hidden_size
+
+ # Create model config
+ model_config = self._create_model_config(config)
+
+ # Create model with model_config
+ model = WanTransformer3DModel(model_config=model_config).to(self.DEVICE, dtype=dtype).eval()
+
+ batch_size = 1
+ num_frames = 1
+ height, width = 64, 64
+ seq_len = 128
+ generator = torch.Generator(device=self.DEVICE).manual_seed(42)
+
+ hidden_states = torch.randn(
+ batch_size,
+ config["in_channels"],
+ num_frames,
+ height,
+ width,
+ generator=generator,
+ device=self.DEVICE,
+ dtype=dtype,
+ )
+ timestep = torch.tensor([50], device=self.DEVICE, dtype=torch.long)
+ encoder_hidden_states = torch.randn(
+ batch_size,
+ seq_len,
+ config["text_dim"],
+ generator=generator,
+ device=self.DEVICE,
+ dtype=dtype,
+ )
+
+ with torch.inference_mode():
+ output = model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ self.assertEqual(output.shape, hidden_states.shape)
+
+ @parameterized.expand(
+ [
+ ("1_3b", WAN_1_3B_CONFIG),
+ ]
+ )
+ @torch.no_grad()
+ def test_wan_allclose_to_hf(self, name, config_template):
+ """Test TRT-LLM transformer matches HuggingFace output (BF16)."""
+ torch.random.manual_seed(42)
+ config = deepcopy(config_template)
+ dtype = getattr(torch, config["torch_dtype"])
+
+ mem_for_full_model = (2 + 1) * 1.3 * 2**30
+ reduce_wan_config(mem_for_full_model, config)
+
+ if config["num_layers"] <= 0:
+ self.skipTest("Insufficient memory for a single Wan layer")
+
+ hidden_size = config["num_attention_heads"] * config["attention_head_dim"]
+
+ # Create HuggingFace model (random weights)
+ hf_model = (
+ HFWanTransformer3DModel(
+ patch_size=config["patch_size"],
+ num_attention_heads=config["num_attention_heads"],
+ attention_head_dim=config["attention_head_dim"],
+ in_channels=config["in_channels"],
+ out_channels=config["out_channels"],
+ text_dim=config["text_dim"],
+ freq_dim=config["freq_dim"],
+ ffn_dim=config["ffn_dim"],
+ num_layers=config["num_layers"],
+ cross_attn_norm=config["cross_attn_norm"],
+ qk_norm=config["qk_norm"],
+ eps=config["eps"],
+ )
+ .to(self.DEVICE, dtype=dtype)
+ .eval()
+ )
+
+ # Create TRT-LLM model with model_config
+ config["hidden_size"] = hidden_size
+ model_config = self._create_model_config(config)
+
+ trtllm_model = (
+ WanTransformer3DModel(model_config=model_config).to(self.DEVICE, dtype=dtype).eval()
+ )
+
+ # Copy weights from HF to TRT-LLM
+ loaded_count = self._load_weights_from_hf(trtllm_model, hf_model.state_dict())
+ print(f"[DEBUG] Loaded {loaded_count} weight tensors from HF to TRT-LLM")
+
+ # Create test inputs
+ batch_size = 1
+ num_frames = 1
+ height, width = 64, 64
+ seq_len = 128
+ generator = torch.Generator(device=self.DEVICE).manual_seed(42)
+
+ hidden_states = torch.randn(
+ batch_size,
+ config["in_channels"],
+ num_frames,
+ height,
+ width,
+ generator=generator,
+ device=self.DEVICE,
+ dtype=dtype,
+ )
+ timestep = torch.tensor([50], device=self.DEVICE, dtype=torch.long)
+ encoder_hidden_states = torch.randn(
+ batch_size,
+ seq_len,
+ config["text_dim"],
+ generator=generator,
+ device=self.DEVICE,
+ dtype=dtype,
+ )
+
+ # Run both models
+ with (
+ torch.inference_mode(),
+ torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_math=True, enable_mem_efficient=False
+ ),
+ ):
+ hf_output = hf_model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )[0]
+
+ trtllm_output = trtllm_model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # Compare outputs
+ hf_output = hf_output.float()
+ trtllm_output = trtllm_output.float()
+
+ # Debug: Check for NaN/Inf
+ hf_has_nan = torch.isnan(hf_output).any().item()
+ trtllm_has_nan = torch.isnan(trtllm_output).any().item()
+ hf_has_inf = torch.isinf(hf_output).any().item()
+ trtllm_has_inf = torch.isinf(trtllm_output).any().item()
+
+ print("\n[DEBUG] Output validation:")
+ print(f" HF has NaN: {hf_has_nan}, Inf: {hf_has_inf}")
+ print(f" TRT-LLM has NaN: {trtllm_has_nan}, Inf: {trtllm_has_inf}")
+
+ if not (hf_has_nan or trtllm_has_nan or hf_has_inf or trtllm_has_inf):
+ # Compute detailed comparison metrics
+ diff = (trtllm_output - hf_output).abs()
+ max_diff = diff.max().item()
+ mean_diff = diff.mean().item()
+
+ cos_sim = torch.nn.functional.cosine_similarity(
+ trtllm_output.flatten(), hf_output.flatten(), dim=0
+ ).item()
+
+ print("\n[DEBUG] Comparison metrics:")
+ print(f" Max absolute diff: {max_diff:.6f}")
+ print(f" Mean absolute diff: {mean_diff:.6f}")
+ print(f" Cosine similarity: {cos_sim:.6f}")
+ print(f" HF output range: [{hf_output.min():.4f}, {hf_output.max():.4f}]")
+ print(f" TRT-LLM output range: [{trtllm_output.min():.4f}, {trtllm_output.max():.4f}]")
+
+ torch.testing.assert_close(
+ trtllm_output, hf_output, atol=0.4, rtol=0.4, msg=f"Output mismatch for {name} config"
+ )
+
+ def _load_weights_from_hf(self, trtllm_model, hf_state_dict):
+ """Load weights from HuggingFace model to TRT-LLM model.
+
+ TRT-LLM structure:
+ - blocks.0.attn1.qkv_proj (fused QKV for self-attention)
+ - blocks.0.attn2.to_q/to_k/to_v (separate for cross-attention)
+ - blocks.0.attn1.to_out.0 and blocks.0.attn2.to_out.0
+
+ HuggingFace structure:
+ - blocks.0.attn1.to_q/to_k/to_v (separate Q/K/V)
+ - blocks.0.attn2.to_q/to_k/to_v (separate Q/K/V)
+ - blocks.0.attn1.to_out.0 and blocks.0.attn2.to_out.0
+ """
+ loaded_count = 0
+ missing_weights = []
+
+ def load_linear(module, trtllm_key, hf_key, sd):
+ """Load weights from HF key into TRT-LLM module."""
+ if f"{hf_key}.weight" in sd:
+ weight_dict = {"weight": sd[f"{hf_key}.weight"]}
+ if f"{hf_key}.bias" in sd:
+ weight_dict["bias"] = sd[f"{hf_key}.bias"]
+ module.load_weights([weight_dict])
+ return 1
+ else:
+ missing_weights.append(hf_key)
+ return 0
+
+ for name, module in trtllm_model.named_modules():
+ if isinstance(module, Linear):
+ # Self-attention fused QKV: blocks.0.attn1.qkv_proj
+ # Load from HF separate Q/K/V: blocks.0.attn1.to_q/to_k/to_v
+ if "attn1.qkv_proj" in name:
+ base = name.replace(".qkv_proj", "")
+ q_key, k_key, v_key = f"{base}.to_q", f"{base}.to_k", f"{base}.to_v"
+ if f"{q_key}.weight" in hf_state_dict:
+ q_dict = {"weight": hf_state_dict[f"{q_key}.weight"]}
+ k_dict = {"weight": hf_state_dict[f"{k_key}.weight"]}
+ v_dict = {"weight": hf_state_dict[f"{v_key}.weight"]}
+ if f"{q_key}.bias" in hf_state_dict:
+ q_dict["bias"] = hf_state_dict[f"{q_key}.bias"]
+ k_dict["bias"] = hf_state_dict[f"{k_key}.bias"]
+ v_dict["bias"] = hf_state_dict[f"{v_key}.bias"]
+ module.load_weights([q_dict, k_dict, v_dict])
+ loaded_count += 1
+
+ # Cross-attention separate Q/K/V: blocks.0.attn2.to_q (same path as HF)
+ elif "attn2.to_q" in name or "attn2.to_k" in name or "attn2.to_v" in name:
+ # Direct mapping - TRT-LLM and HF use same paths for cross-attention
+ loaded_count += load_linear(module, name, name, hf_state_dict)
+
+ # Output projections: blocks.0.attn1.to_out.0 (same path as HF)
+ elif ".to_out" in name:
+ # Direct mapping - TRT-LLM and HF use same paths for output projections
+ loaded_count += load_linear(module, name, name, hf_state_dict)
+
+ # FFN layers: TRT-LLM uses up_proj/down_proj, HF uses net.0.proj/net.2
+ elif "ffn.up_proj" in name:
+ hf_key = name.replace(".ffn.up_proj", ".ffn.net.0.proj")
+ loaded_count += load_linear(module, name, hf_key, hf_state_dict)
+ elif "ffn.down_proj" in name:
+ hf_key = name.replace(".ffn.down_proj", ".ffn.net.2")
+ loaded_count += load_linear(module, name, hf_key, hf_state_dict)
+
+ # Other layers: direct mapping
+ elif "condition_embedder" in name or "proj_out" in name:
+ loaded_count += load_linear(module, name, name, hf_state_dict)
+
+ else:
+ # Direct mapping for any other Linear modules
+ loaded_count += load_linear(module, name, name, hf_state_dict)
+
+ elif hasattr(module, "weight") and f"{name}.weight" in hf_state_dict:
+ # Norms & embeddings
+ with torch.no_grad():
+ module.weight.copy_(hf_state_dict[f"{name}.weight"])
+ if (
+ getattr(module, "bias", None) is not None
+ and f"{name}.bias" in hf_state_dict
+ ):
+ module.bias.copy_(hf_state_dict[f"{name}.bias"])
+ loaded_count += 1
+
+ # Load scale_shift_table parameters
+ for name, param in trtllm_model.named_parameters():
+ if "scale_shift_table" in name and name in hf_state_dict:
+ with torch.no_grad():
+ param.copy_(hf_state_dict[name].view(param.shape))
+ loaded_count += 1
+
+ if missing_weights:
+ print(f"[DEBUG] Missing {len(missing_weights)} weights:")
+ for mw in missing_weights[:10]: # Show first 10
+ print(f" - {mw}")
+
+ return loaded_count
+
+ def _load_weights_from_state_dict(self, model, state_dict):
+ """Load weights from state_dict into model (same structure)."""
+ for name, module in model.named_modules():
+ if isinstance(module, Linear):
+ weight_key = f"{name}.weight"
+ if weight_key in state_dict:
+ weight_dict = {"weight": state_dict[weight_key]}
+ bias_key = f"{name}.bias"
+ if bias_key in state_dict:
+ weight_dict["bias"] = state_dict[bias_key]
+ module.load_weights([weight_dict])
+
+ elif hasattr(module, "weight") and f"{name}.weight" in state_dict:
+ with torch.no_grad():
+ module.weight.copy_(state_dict[f"{name}.weight"])
+ if getattr(module, "bias", None) is not None and f"{name}.bias" in state_dict:
+ module.bias.copy_(state_dict[f"{name}.bias"])
+
+ # Load parameters
+ for name, param in model.named_parameters():
+ if name in state_dict:
+ with torch.no_grad():
+ param.copy_(state_dict[name].view(param.shape))
+
+
+# =============================================================================
+# Pipeline Tests - Require a Real Checkpoint
+# =============================================================================
+
+
+@pytest.fixture
+def checkpoint_exists():
+ """Check if checkpoint path is set and exists."""
+ return CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH)
+
+
+@pytest.fixture(autouse=True)
+def cleanup_gpu_memory():
+ """Automatically cleanup GPU memory after each test to prevent OOM errors.
+
+ This fixture runs automatically after every test in this file.
+ It performs garbage collection and clears CUDA cache to free up GPU memory.
+ """
+ yield # Test runs here
+ # Cleanup after test completes
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+def _is_fp32_layernorm_param(param_name: str) -> bool:
+ """True if param is a LayerNorm weight/bias we keep in float32. Only LayerNorm (norm1/norm2/norm3/norm_out)."""
+ if not param_name.endswith((".weight", ".bias")):
+ return False
+ # blocks.<i>.norm1 / norm2 / norm3 (LayerNorm only; attn norm_q/norm_k are RMSNorm)
+ if ".norm" in param_name and "blocks." in param_name:
+ parts = param_name.split(".")
+ for p in parts:
+ if p in ("norm1", "norm2", "norm3"):
+ return True
+ return False
+ # top-level norm_out (LayerNorm)
+ if param_name == "norm_out.weight" or param_name == "norm_out.bias":
+ return True
+ # condition_embedder.norm1, norm2 (LayerNorm)
+ if param_name.startswith("condition_embedder.") and ".norm" in param_name:
+ return True
+ return False
+
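+# Illustrative expectations for _is_fp32_layernorm_param (example parameter names):
+#   "blocks.0.norm2.weight"        -> True  (LayerNorm, kept in float32)
+#   "blocks.0.attn1.norm_q.weight" -> False (RMSNorm, stays bf16)
+#   "norm_out.weight"              -> True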
+
+class TestWanPipeline:
+ """Pipeline tests for Wan pipeline loading with PipelineLoader.
+
+ These tests require a real checkpoint (set DIFFUSION_MODEL_PATH env var).
+ They test the full loading flow: config -> model -> weight loading -> inference.
+ """
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def test_load_wan_pipeline_basic(self, checkpoint_exists):
+ """Test loading Wan pipeline without quantization via PipelineLoader."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint (single-stage). Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify pipeline loaded correctly
+ assert pipeline.transformer is not None
+ assert len(pipeline.transformer.blocks) > 0
+
+ # Verify weights are loaded
+ # Check that non-scale parameters are bfloat16
+ bf16_count = 0
+ f32_scale_count = 0
+ for name, param in pipeline.transformer.named_parameters():
+ assert param.device.type == "cuda", f"Parameter {name} not on CUDA"
+ if "scale" in name.lower():
+ # Scale parameters can stay float32 for FP8 kernels
+ assert param.dtype in [torch.float32, torch.bfloat16], (
+ f"Scale param {name} has unexpected dtype {param.dtype}"
+ )
+ if param.dtype == torch.float32:
+ f32_scale_count += 1
+ elif _is_fp32_layernorm_param(name):
+ # LayerNorm (norm1/norm2/norm3/norm_out) use float32; RMSNorm (norm_q, norm_k, etc.) stay bf16
+ assert param.dtype == torch.float32, (
+ f"LayerNorm param {name} expected float32 but got {param.dtype}"
+ )
+ else:
+ # Non-scale parameters should be bfloat16
+ assert param.dtype == torch.bfloat16, (
+ f"Parameter {name} expected bfloat16 but got {param.dtype}"
+ )
+ bf16_count += 1
+
+ assert bf16_count > 0, "Should have at least some bfloat16 parameters"
+ print(
+ f"\n[Pipeline] BF16 pipeline loaded: {bf16_count} bf16 params"
+ f"\n{f32_scale_count} f32 scale params, {len(pipeline.transformer.blocks)} blocks"
+ )
+
+ @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"])
+ def test_load_wan_pipeline_with_quantization(self, checkpoint_exists, quant_algo):
+ """Test loading Wan with FP8 quantization (per-tensor or blockwise)."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={"quant_algo": quant_algo, "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify FP8 weights in transformer blocks
+ found_fp8 = False
+ for name, module in pipeline.transformer.named_modules():
+ if isinstance(module, Linear):
+ if "blocks." in name and hasattr(module, "weight") and module.weight is not None:
+ assert module.weight.dtype == torch.float8_e4m3fn, (
+ f"Linear {name} should have FP8 weight, got {module.weight.dtype}"
+ )
+ assert hasattr(module, "weight_scale"), f"Linear {name} missing weight_scale"
+ found_fp8 = True
+ print(f"[{quant_algo}] FP8 layer {name}: weight {module.weight.shape}")
+ break
+
+ assert found_fp8, f"No FP8 Linear modules found for {quant_algo}"
+
+ @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"])
+ def test_fp8_vs_bf16_numerical_correctness(self, checkpoint_exists, quant_algo):
+ """Test FP8 vs BF16 numerical accuracy on real checkpoint weights.
+
+ Pattern (similar to that in test_pipeline_dynamic_quant.py):
+ 1. Use F.linear() with BF16 weights as ground truth reference
+ 2. Verify BF16 layer matches F.linear exactly
+ 3. Compare FP8 layer output against reference
+ 4. Check max_diff, cosine_similarity, mse_loss
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint (loads 2 full models and "
+ "Needs single transformer). Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ # =====================================================================
+ # Load BF16 Pipeline (Reference)
+ # =====================================================================
+ print(f"\n[Compare {quant_algo}] Loading BF16 pipeline...")
+
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ # =====================================================================
+ # Load FP8 Pipeline
+ # =====================================================================
+ print(f"[Compare {quant_algo}] Loading {quant_algo} pipeline...")
+
+ args_fp8 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ quant_config={"quant_algo": quant_algo, "dynamic": True},
+ )
+ pipeline_fp8 = PipelineLoader(args_fp8).load()
+
+ # =====================================================================
+ # Get Linear Layers from Both Pipelines
+ # =====================================================================
+ attn_bf16 = pipeline_bf16.transformer.blocks[0].attn1
+ attn_fp8 = pipeline_fp8.transformer.blocks[0].attn1
+
+ # Get linear layer - try fused qkv_proj first, fallback to qkv_proj on attention module
+ if hasattr(attn_bf16, "qkv_proj"):
+ linear_bf16 = attn_bf16.qkv_proj
+ linear_fp8 = attn_fp8.qkv_proj
+ layer_name = "blocks.0.attn1.qkv_proj"
+ elif hasattr(attn_bf16, "attn") and hasattr(attn_bf16.attn, "qkv_proj"):
+ linear_bf16 = attn_bf16.attn.qkv_proj
+ linear_fp8 = attn_fp8.attn.qkv_proj
+ layer_name = "blocks.0.attn1.attn.qkv_proj"
+ else:
+ # Use FFN linear instead (always available)
+ linear_bf16 = pipeline_bf16.transformer.blocks[0].ffn.net[0]["proj"]
+ linear_fp8 = pipeline_fp8.transformer.blocks[0].ffn.net[0]["proj"]
+ layer_name = "blocks.0.ffn.net.0.proj"
+
+ # =====================================================================
+ # Get BF16 weights and bias for F.linear reference
+ # =====================================================================
+ weight_bf16 = linear_bf16.weight.data.clone()
+ bias_bf16 = linear_bf16.bias.data.clone() if linear_bf16.bias is not None else None
+
+ # =====================================================================
+ # Create Test Input
+ # =====================================================================
+ torch.manual_seed(42)
+ hidden_size = linear_bf16.in_features
+ batch_size = 1
+ seq_len = 14040
+
+ # 2D input for FP8 kernel compatibility
+ input_tensor = torch.randn(
+ batch_size * seq_len, hidden_size, dtype=torch.bfloat16, device="cuda"
+ )
+ print(f"[Compare] Input shape: {input_tensor.shape}")
+
+ # =====================================================================
+ # Compute Reference Output: F.linear (ground truth)
+ # =====================================================================
+ with torch.no_grad():
+ expected = F.linear(input_tensor, weight_bf16, bias_bf16)
+
+ # =====================================================================
+ # Compute FP8 Output
+ # =====================================================================
+ with torch.no_grad():
+ result_fp8 = linear_fp8(input_tensor)
+
+ # =====================================================================
+ # Compute BF16 Layer Output
+ # =====================================================================
+ with torch.no_grad():
+ result_bf16 = linear_bf16(input_tensor)
+
+ # Verify BF16 layer matches F.linear reference
+ assert torch.allclose(result_bf16, expected, rtol=1e-5, atol=1e-6), (
+ "BF16 layer should match F.linear reference exactly"
+ )
+
+ # Compare FP8 vs Reference
+ max_diff = torch.max(torch.abs(result_fp8 - expected)).item()
+ cos_sim = F.cosine_similarity(
+ result_fp8.flatten().float(), expected.flatten().float(), dim=0
+ )
+ mse = F.mse_loss(result_fp8.flatten().float(), expected.flatten().float())
+
+ print(
+ f"\n[{layer_name}] max_diff={max_diff:.6f}, cos_sim={cos_sim.item():.6f}, mse={mse.item():.6f}"
+ )
+
+ assert cos_sim > 0.99, f"Cosine similarity too low: {cos_sim.item()}"
+ assert mse < 1.0, f"MSE too high: {mse.item()}"
+
+ # Cleanup
+ del pipeline_bf16, pipeline_fp8
+ torch.cuda.empty_cache()
+
+ def test_fp8_vs_bf16_memory_comparison(self, checkpoint_exists):
+ """Test FP8 uses ~2x less memory than BF16."""
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ def get_module_memory_gb(module):
+ return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3
+
+ # Load BF16
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ bf16_model_mem = get_module_memory_gb(pipeline_bf16.transformer)
+ bf16_peak_mem = torch.cuda.max_memory_allocated() / 1024**3
+
+ print(f"\n[BF16] Transformer memory: {bf16_model_mem:.2f} GB")
+ print(f"[BF16] Peak memory: {bf16_peak_mem:.2f} GB")
+
+ del pipeline_bf16
+ torch.cuda.empty_cache()
+
+ # Load FP8
+ torch.cuda.reset_peak_memory_stats()
+
+ args_fp8 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ )
+ pipeline_fp8 = PipelineLoader(args_fp8).load()
+
+ fp8_model_mem = get_module_memory_gb(pipeline_fp8.transformer)
+ fp8_peak_mem = torch.cuda.max_memory_allocated() / 1024**3
+
+ print(f"\n[FP8] Transformer memory: {fp8_model_mem:.2f} GB")
+ print(f"[FP8] Peak memory: {fp8_peak_mem:.2f} GB")
+
+ # Verify memory savings
+ model_mem_ratio = bf16_model_mem / fp8_model_mem
+ peak_mem_ratio = bf16_peak_mem / fp8_peak_mem
+
+ print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x")
+ print(f"[Comparison] Peak memory ratio (BF16/FP8): {peak_mem_ratio:.2f}x")
+
+ # FP8 should use ~2x less memory
+ assert model_mem_ratio > 1.8, f"FP8 should use ~2x less memory, got {model_mem_ratio:.2f}x"
+
+ del pipeline_fp8
+ torch.cuda.empty_cache()
+
+ @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"])
+ def test_fp8_vs_bf16_full_transformer_e2e(self, checkpoint_exists, quant_algo):
+ """End-to-end test: Compare full Wan transformer FP8 vs BF16 output.
+
+ Unlike test_fp8_vs_bf16_numerical_correctness which tests a single Linear layer,
+ this test runs the ENTIRE transformer (all 30 blocks) and compares outputs.
+
+ Expectations:
+ - Errors accumulate across 30 layers, so use relaxed tolerances
+ - Cosine similarity should be high (>0.95) but lower than single-layer test (>0.99)
+ - This validates that FP8 quantization doesn't degrade quality too much end-to-end
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ # =====================================================================
+ # Load BF16 Transformer (Reference)
+ # =====================================================================
+ print("\n[E2E] Loading BF16 transformer...")
+
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+ transformer_bf16 = pipeline_bf16.transformer
+
+ # =====================================================================
+ # Load FP8 Transformer
+ # =====================================================================
+ print(f"[E2E] Loading {quant_algo} transformer...")
+
+ args_fp8 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ quant_config={"quant_algo": quant_algo, "dynamic": True},
+ )
+ pipeline_fp8 = PipelineLoader(args_fp8).load()
+ transformer_fp8 = pipeline_fp8.transformer
+
+ # =====================================================================
+ # Create Realistic Inputs
+ # =====================================================================
+ torch.manual_seed(42)
+
+ # Use smaller size for faster testing (still realistic)
+ batch_size = 1
+ num_frames = 1
+ height, width = 64, 64 # Smaller than full 720x1280
+ in_channels = 16
+ text_seq_len = 128
+ text_dim = 4096
+
+ # Create inputs
+ hidden_states = torch.randn(
+ batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda"
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda")
+ encoder_hidden_states = torch.randn(
+ batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda"
+ )
+
+ print("[E2E] Input shapes:")
+ print(f" hidden_states: {hidden_states.shape}")
+ print(f" timestep: {timestep.shape}")
+ print(f" encoder_hidden_states: {encoder_hidden_states.shape}")
+
+ # =====================================================================
+ # Run Full Transformer Forward Pass
+ # =====================================================================
+ print("[E2E] Running BF16 transformer forward...")
+ with torch.no_grad():
+ output_bf16 = transformer_bf16(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ print(f"[E2E] Running {quant_algo} transformer forward...")
+ with torch.no_grad():
+ output_fp8 = transformer_fp8(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # =====================================================================
+ # Verify Outputs
+ # =====================================================================
+ assert output_bf16.shape == output_fp8.shape, (
+ f"Output shape mismatch: BF16={output_bf16.shape}, FP8={output_fp8.shape}"
+ )
+ print(f"[E2E] Output shape: {output_bf16.shape}")
+
+ # Check for NaN/Inf
+ bf16_has_nan = torch.isnan(output_bf16).any().item()
+ fp8_has_nan = torch.isnan(output_fp8).any().item()
+ bf16_has_inf = torch.isinf(output_bf16).any().item()
+ fp8_has_inf = torch.isinf(output_fp8).any().item()
+
+ assert not bf16_has_nan, "BF16 output contains NaN"
+ assert not bf16_has_inf, "BF16 output contains Inf"
+ assert not fp8_has_nan, f"{quant_algo} output contains NaN"
+ assert not fp8_has_inf, f"{quant_algo} output contains Inf"
+
+ # =====================================================================
+ # Compare Numerical Accuracy
+ # =====================================================================
+ output_bf16_float = output_bf16.float()
+ output_fp8_float = output_fp8.float()
+
+ max_diff = torch.max(torch.abs(output_fp8_float - output_bf16_float)).item()
+ mean_diff = torch.mean(torch.abs(output_fp8_float - output_bf16_float)).item()
+
+ cos_sim = F.cosine_similarity(
+ output_fp8_float.flatten(), output_bf16_float.flatten(), dim=0
+ ).item()
+
+ mse = F.mse_loss(output_fp8_float, output_bf16_float).item()
+
+ # Relative error
+ rel_error = mean_diff / (output_bf16_float.abs().mean().item() + 1e-8)
+
+ print(f"\n{'=' * 60}")
+ print(f"END-TO-END TRANSFORMER COMPARISON ({quant_algo} vs BF16)")
+ print(f"{'=' * 60}")
+ print(f"Number of layers: {len(transformer_bf16.blocks)}")
+ print(f"Output shape: {output_bf16.shape}")
+ print("")
+ print(f"Max absolute difference: {max_diff:.6f}")
+ print(f"Mean absolute difference: {mean_diff:.6f}")
+ print(f"Relative error: {rel_error:.6f}")
+ print(f"Cosine similarity: {cos_sim:.6f}")
+ print(f"MSE loss: {mse:.6f}")
+ print("")
+ print(f"BF16 output range: [{output_bf16_float.min():.4f}, {output_bf16_float.max():.4f}]")
+ print(
+ f"{quant_algo} output range: [{output_fp8_float.min():.4f}, {output_fp8_float.max():.4f}]"
+ )
+ print(f"{'=' * 60}")
+
+ # =====================================================================
+ # Assert Numerical Correctness (Relaxed Tolerances)
+ # =====================================================================
+ # Cosine similarity should be high, but lower than single-layer test
+ # due to error accumulation across 30 layers
+ assert cos_sim > 0.95, (
+ f"Cosine similarity too low for full transformer: {cos_sim:.6f} (expected >0.95)"
+ )
+
+ # Relative error should be reasonable
+ # Note: Error accumulates across 30 layers, so we use a relaxed tolerance
+ assert rel_error < 0.15, f"Relative error too high: {rel_error:.6f} (expected <0.15)"
+
+ print(f"\n[PASS] {quant_algo} full transformer output matches BF16 within tolerance!")
+ print(f" ā Cosine similarity: {cos_sim:.4f} (>0.95)")
+ print(f" ā Relative error: {rel_error:.4f} (<0.15)")
+
+ # Cleanup
+ del pipeline_bf16, pipeline_fp8, transformer_bf16, transformer_fp8
+ torch.cuda.empty_cache()
+
+ def test_attention_backend_comparison(self, checkpoint_exists):
+ """Test accuracy of full Wan forward pass with attention backend comparison.
+
+ Wan uses both self-attention (attn1) and cross-attention (attn2). TRTLLM backend
+ doesn't support cross-attention (seq_len != kv_seq_len), but WanAttention
+ automatically falls back to VANILLA for cross-attention when TRTLLM is configured.
+
+ This test verifies:
+ 1. VANILLA backend works correctly
+ 2. TRTLLM backend with automatic VANILLA fallback for cross-attention produces
+ numerically similar results to pure VANILLA
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ # =====================================================================
+ # Load Baseline Transformer (Default VANILLA)
+ # =====================================================================
+ print("\n[Attention Backend Test] Loading baseline transformer (default VANILLA)...")
+
+ from tensorrt_llm._torch.visual_gen.config import AttentionConfig
+
+ args_baseline = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ # Default attention backend is VANILLA
+ pipeline_baseline = PipelineLoader(args_baseline).load()
+ transformer_baseline = pipeline_baseline.transformer
+
+ # =====================================================================
+ # Load VANILLA Transformer
+ # =====================================================================
+ print("[Attention Backend Test] Loading VANILLA transformer (explicit)...")
+
+ args_vanilla = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ args_vanilla.attention = AttentionConfig(backend="VANILLA")
+ pipeline_vanilla = PipelineLoader(args_vanilla).load()
+ transformer_vanilla = pipeline_vanilla.transformer
+
+ # =====================================================================
+ # Create Fixed Test Inputs
+ # =====================================================================
+ torch.manual_seed(42)
+
+ # Smaller size for faster testing
+ batch_size = 1
+ num_frames = 1
+ height, width = 64, 64
+ in_channels = 16
+ text_seq_len = 128
+ text_dim = 4096
+
+ # Create inputs
+ hidden_states = torch.randn(
+ batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda"
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda")
+ encoder_hidden_states = torch.randn(
+ batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda"
+ )
+
+ print("[Attention Backend Test] Input shapes:")
+ print(f" hidden_states: {hidden_states.shape}")
+ print(f" timestep: {timestep.shape}")
+ print(f" encoder_hidden_states: {encoder_hidden_states.shape}")
+
+ # =====================================================================
+ # Run Full Transformer Forward Pass
+ # =====================================================================
+ print("[Attention Backend Test] Running baseline transformer forward...")
+ with torch.no_grad():
+ output_baseline = transformer_baseline(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ print("[Attention Backend Test] Running VANILLA transformer forward...")
+ with torch.no_grad():
+ output_vanilla = transformer_vanilla(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # =====================================================================
+ # Verify Output Shapes
+ # =====================================================================
+ assert output_baseline.shape == output_vanilla.shape, (
+ f"Output shape mismatch: baseline={output_baseline.shape}, "
+ f"VANILLA={output_vanilla.shape}"
+ )
+ print(f"[Attention Backend Test] Output shape: {output_baseline.shape}")
+
+ # =====================================================================
+ # Check for NaN/Inf in All Outputs
+ # =====================================================================
+ for name, output in [("baseline", output_baseline), ("VANILLA", output_vanilla)]:
+ has_nan = torch.isnan(output).any().item()
+ has_inf = torch.isinf(output).any().item()
+ assert not has_nan, f"{name} output contains NaN"
+ assert not has_inf, f"{name} output contains Inf"
+ print(f"[Attention Backend Test] {name} output: NaN={has_nan}, Inf={has_inf}")
+
+ # =====================================================================
+ # Compare VANILLA (Explicit) vs Baseline
+ # =====================================================================
+ output_baseline_float = output_baseline.float()
+ output_vanilla_float = output_vanilla.float()
+
+ # VANILLA explicit vs baseline (should be identical)
+ max_diff_vanilla = torch.max(torch.abs(output_vanilla_float - output_baseline_float)).item()
+ mean_diff_vanilla = torch.mean(
+ torch.abs(output_vanilla_float - output_baseline_float)
+ ).item()
+ cos_sim_vanilla = F.cosine_similarity(
+ output_vanilla_float.flatten(), output_baseline_float.flatten(), dim=0
+ ).item()
+ mse_vanilla = F.mse_loss(output_vanilla_float, output_baseline_float).item()
+
+ print(f"\n{'=' * 60}")
+ print("VANILLA (Explicit) vs Baseline Comparison")
+ print(f"{'=' * 60}")
+ print(f"Max absolute difference: {max_diff_vanilla:.6f}")
+ print(f"Mean absolute difference: {mean_diff_vanilla:.6f}")
+ print(f"Cosine similarity: {cos_sim_vanilla:.6f}")
+ print(f"MSE loss: {mse_vanilla:.6f}")
+ print(f"{'=' * 60}")
+
+ # VANILLA explicit should match baseline closely (same backend)
+ # Note: results may not be bit-identical across separately loaded pipelines, so compare via cosine similarity
+ assert cos_sim_vanilla > 0.995, (
+ f"VANILLA explicit should match baseline closely: cos_sim={cos_sim_vanilla:.6f}"
+ )
+
+ print("\n[PASS] VANILLA backend produces consistent outputs!")
+ print(f" ā VANILLA (explicit) matches baseline: cos_sim={cos_sim_vanilla:.6f} (>0.995)")
+
+ # =====================================================================
+ # Load TRTLLM Transformer (with automatic VANILLA fallback for cross-attention)
+ # =====================================================================
+ print("\n[Attention Backend Test] Loading TRTLLM transformer...")
+
+ args_trtllm = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ )
+ args_trtllm.attention = AttentionConfig(backend="TRTLLM")
+ pipeline_trtllm = PipelineLoader(args_trtllm).load()
+ transformer_trtllm = pipeline_trtllm.transformer
+
+ # Verify automatic backend override for cross-attention
+ print("[Attention Backend Test] Verifying backend configuration...")
+ first_block = transformer_trtllm.blocks[0]
+ attn1_backend = first_block.attn1.attn_backend
+ attn2_backend = first_block.attn2.attn_backend
+ print(f" attn1 (self-attention) backend: {attn1_backend}")
+ print(f" attn2 (cross-attention) backend: {attn2_backend}")
+ assert attn1_backend == "TRTLLM", f"Expected attn1 to use TRTLLM, got {attn1_backend}"
+ assert attn2_backend == "VANILLA", f"Expected attn2 to use VANILLA, got {attn2_backend}"
+ print(" ā Automatic backend override working correctly!")
+
+ # =====================================================================
+ # Run TRTLLM Transformer Forward Pass
+ # =====================================================================
+ print("[Attention Backend Test] Running TRTLLM transformer forward...")
+ with torch.no_grad():
+ output_trtllm = transformer_trtllm(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # =====================================================================
+ # Check for NaN/Inf in TRTLLM Output
+ # =====================================================================
+ has_nan = torch.isnan(output_trtllm).any().item()
+ has_inf = torch.isinf(output_trtllm).any().item()
+ assert not has_nan, "TRTLLM output contains NaN"
+ assert not has_inf, "TRTLLM output contains Inf"
+ print(f"[Attention Backend Test] TRTLLM output: NaN={has_nan}, Inf={has_inf}")
+
+ # =====================================================================
+ # Compare TRTLLM vs Baseline
+ # =====================================================================
+ output_trtllm_float = output_trtllm.float()
+
+ max_diff_trtllm = torch.max(torch.abs(output_trtllm_float - output_baseline_float)).item()
+ mean_diff_trtllm = torch.mean(torch.abs(output_trtllm_float - output_baseline_float)).item()
+ cos_sim_trtllm = F.cosine_similarity(
+ output_trtllm_float.flatten(), output_baseline_float.flatten(), dim=0
+ ).item()
+ mse_trtllm = F.mse_loss(output_trtllm_float, output_baseline_float).item()
+
+ print(f"\n{'=' * 60}")
+ print("TRTLLM (with auto VANILLA fallback) vs Baseline Comparison")
+ print(f"{'=' * 60}")
+ print(f"Max absolute difference: {max_diff_trtllm:.6f}")
+ print(f"Mean absolute difference: {mean_diff_trtllm:.6f}")
+ print(f"Cosine similarity: {cos_sim_trtllm:.6f}")
+ print(f"MSE loss: {mse_trtllm:.6f}")
+ print(f"{'=' * 60}")
+
+ # TRTLLM should produce similar results (attn1 uses TRTLLM, attn2 uses VANILLA)
+ # Allow slightly more tolerance since the attention implementations differ
+ assert cos_sim_trtllm > 0.99, (
+ f"TRTLLM should produce similar results to baseline: cos_sim={cos_sim_trtllm:.6f}"
+ )
+
+ print("\n[PASS] TRTLLM backend with automatic fallback works correctly!")
+ print(f" ā TRTLLM matches baseline: cos_sim={cos_sim_trtllm:.6f} (>0.99)")
+
+ # Cleanup
+ del pipeline_baseline, pipeline_vanilla, pipeline_trtllm
+ del transformer_baseline, transformer_vanilla, transformer_trtllm
+ torch.cuda.empty_cache()
+
+ @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"])
+ def test_fp8_mixed_quant_numerical_correctness(self, checkpoint_exists, quant_algo):
+ """Test numerical correctness with mixed quantization (some layers excluded).
+
+ Compares outputs between:
+ 1. Full BF16 model (reference)
+ 2. Mixed FP8 model (selected layers excluded from quantization via ignore patterns)
+
+ Expected behavior:
+ - The mixed-quantized model should stay numerically close to the BF16 reference
+ - Excluding sensitive layers (like the first/last blocks and the output projection) helps preserve accuracy
+ """
+ if not checkpoint_exists:
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ # =====================================================================
+ # Define Mixed Quant Config
+ # =====================================================================
+ # Exclude first block and output projection (often sensitive layers)
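+ # These patterns are matched against transformer module names; the wildcard entries keep every sub-layer of the listed blocks unquantized (they stay in BF16).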
+ mixed_ignore_patterns = [
+ "proj_out",
+ "condition_embedder.*",
+ "blocks.0.*",
+ "blocks.29.*", # Last block (if exists)
+ ]
+
+ # =====================================================================
+ # Load Models
+ # =====================================================================
+ print("\n[Mixed Quant Accuracy] Loading BF16 model (reference)...")
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ print(f"[Mixed Quant Accuracy] Loading mixed {quant_algo} model...")
+ args_fp8_mixed = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={
+ "quant_algo": quant_algo,
+ "dynamic": True,
+ "ignore": mixed_ignore_patterns,
+ },
+ )
+ pipeline_fp8_mixed = PipelineLoader(args_fp8_mixed).load()
+
+ # =====================================================================
+ # Create Test Inputs
+ # =====================================================================
+ torch.manual_seed(42)
+
+ batch_size = 1
+ num_frames = 1
+ height, width = 64, 64
+ in_channels = 16
+ text_seq_len = 128
+ text_dim = 4096
+
+ hidden_states = torch.randn(
+ batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda"
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda")
+ encoder_hidden_states = torch.randn(
+ batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda"
+ )
+
+ # =====================================================================
+ # Run Forward Pass
+ # =====================================================================
+ print("[Mixed Quant Accuracy] Running forward passes...")
+
+ with torch.no_grad():
+ output_bf16 = pipeline_bf16.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ output_fp8_mixed = pipeline_fp8_mixed.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # =====================================================================
+ # Compute Metrics
+ # =====================================================================
+ output_bf16_float = output_bf16.float()
+ output_fp8_mixed_float = output_fp8_mixed.float()
+
+ # Mixed FP8 vs BF16
+ cos_sim_mixed = F.cosine_similarity(
+ output_fp8_mixed_float.flatten(), output_bf16_float.flatten(), dim=0
+ ).item()
+ mse_mixed = F.mse_loss(output_fp8_mixed_float, output_bf16_float).item()
+
+ print(f"\n{'=' * 60}")
+ print(f"MIXED QUANTIZATION ACCURACY TEST ({quant_algo})")
+ print(f"{'=' * 60}")
+ print(f"Ignored patterns: {mixed_ignore_patterns}")
+ print("")
+ print(f"Mixed {quant_algo} vs BF16:")
+ print(f" Cosine similarity: {cos_sim_mixed:.6f}")
+ print(f" MSE: {mse_mixed:.6f}")
+ print(f"{'=' * 60}")
+
+ # =====================================================================
+ # Assertions
+ # =====================================================================
+ # The mixed-quantized model should maintain reasonable accuracy vs the BF16 reference
+ assert cos_sim_mixed > 0.99, (
+ f"Mixed {quant_algo} cosine similarity too low: {cos_sim_mixed}"
+ )
+ assert mse_mixed < 1.0, f"Mixed {quant_algo} MSE too high: {mse_mixed}"
+
+ print("\n[PASS] Mixed quantization numerical correctness verified!")
+ print(f" ā Mixed {quant_algo}: cos_sim={cos_sim_mixed:.4f}")
+
+ # Cleanup
+ del pipeline_bf16, pipeline_fp8_mixed
+ torch.cuda.empty_cache()
+
+ def test_fp8_static_vs_bf16_accuracy(self, wan22_both_checkpoints_exist):
+ """Test FP8 static and dynamic quantization accuracy against BF16 reference.
+
+ Compares outputs from:
+ 1. TRT-LLM BF16 model (reference checkpoint)
+ 2. TRT-LLM FP8 static quantized model (pre-quantized checkpoint)
+ 3. TRT-LLM FP8 dynamic quantized model (BF16 checkpoint + on-the-fly quant)
+
+ Uses spatially-correlated inputs that mimic real VAE latent patterns,
+ which achieves much higher accuracy than random noise inputs.
+ """
+ if not wan22_both_checkpoints_exist:
+ pytest.skip(
+ f"Both checkpoints required. FP8: {CHECKPOINT_PATH_WAN22_FP8}, "
+ f"BF16: {CHECKPOINT_PATH_WAN22_BF16}"
+ )
+
+ # Reset dynamo cache to avoid recompile-limit errors from prior
+ # tests that compiled kernels with different dtypes (e.g. Float32).
+ torch._dynamo.reset()
+
+ print("\n" + "=" * 70)
+ print("FP8 STATIC & DYNAMIC QUANT vs BF16 ACCURACY TEST")
+ print("=" * 70)
+
+ # Load BF16 reference model
+ print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}")
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ # Load FP8 static quantized model (from pre-quantized checkpoint)
+ print(f"\n[FP8 Static] Loading from {CHECKPOINT_PATH_WAN22_FP8}")
+ args_fp8_static = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_FP8,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_fp8_static = PipelineLoader(args_fp8_static).load()
+
+ # Load FP8 dynamic quantized model (from BF16 checkpoint with on-the-fly quant)
+ print(f"\n[FP8 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_BF16} with dynamic quant")
+ args_fp8_dynamic = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={
+ "quant_algo": "FP8",
+ "dynamic": True,
+ },
+ )
+ pipeline_fp8_dynamic = PipelineLoader(args_fp8_dynamic).load()
+
+ # Verify FP8 static model has calibrated scales
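+ # A pre-quantized (static) checkpoint ships input scales calibrated offline, so counting Linear modules that carry input_scale confirms the static path was used (the dynamic model below is identified via weight_scale instead).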
+ static_quant_modules = 0
+ for name, module in pipeline_fp8_static.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "input_scale") and module.input_scale is not None:
+ static_quant_modules += 1
+ print(f"[FP8 Static] Quantized Linear modules with input_scale: {static_quant_modules}")
+ assert static_quant_modules > 0, "FP8 static model should have calibrated scales"
+
+ # Verify FP8 dynamic model has quantized weights
+ dynamic_quant_modules = 0
+ for name, module in pipeline_fp8_dynamic.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "weight_scale") and module.weight_scale is not None:
+ dynamic_quant_modules += 1
+ print(f"[FP8 Dynamic] Quantized Linear modules: {dynamic_quant_modules}")
+
+ # Create spatially-correlated test inputs (mimics real VAE latent patterns)
+ # Wan 2.2 TI2V-5B specs:
+ # - VAE compression: 16x16x4 (spatial x spatial x temporal)
+ # - Latent channels: 48 (z_dim=48)
+ # - 720P resolution: 1280x704 -> latent: 80x44
+ # - Text encoder: UMT5, max_length=512, dim=4096
+ torch.manual_seed(42)
+
+ batch_size = 2 # For CFG (positive + negative)
+ in_channels = 48 # Wan 2.2 TI2V-5B uses 48 latent channels
+ time_dim = 1 # Single frame for unit test
+
+ # 720P latent dimensions: 1280/16=80 width, 704/16=44 height
+ height = 44 # 720P latent height (704 / 16)
+ width = 80 # 720P latent width (1280 / 16)
+
+ # Text encoder: UMT5 with 4096 dim, typical sequence length ~226
+ text_seq_len = 226 # Default max_sequence_length for Wan
+ text_dim = 4096
+
+ # Create structured latent (not purely random - simulate real VAE output)
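+ # Bilinear upsampling of a coarse random field (below) yields smooth, spatially-correlated values, which resembles real VAE latents far more than i.i.d. noise and keeps the quantization accuracy comparison realistic.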
+ base_pattern = torch.randn(
+ 1, in_channels, time_dim, height // 4, width // 4, device="cuda", dtype=torch.bfloat16
+ )
+ hidden_states = F.interpolate(
+ base_pattern.view(1, in_channels, height // 4, width // 4),
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ ).view(1, in_channels, time_dim, height, width)
+ hidden_states = hidden_states * 2.0
+ hidden_states = hidden_states.expand(batch_size, -1, -1, -1, -1).contiguous()
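+ # The two batch entries are identical copies; batch_size=2 only reproduces the CFG layout (positive + negative stream) so tensor shapes match real inference.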
+
+ timestep = torch.tensor([500.0, 500.0], device="cuda", dtype=torch.bfloat16)
+
+ text_base = (
+ torch.randn(1, text_seq_len, text_dim, device="cuda", dtype=torch.bfloat16) * 0.1
+ )
+ encoder_hidden_states = text_base.expand(batch_size, -1, -1).contiguous()
+
+ print(
+ f"\n[Input] 720P latent: {hidden_states.shape} "
+ f"(batch={batch_size}, ch={in_channels}, t={time_dim}, h={height}, w={width})"
+ )
+ print(f"[Input] range: [{hidden_states.min():.2f}, {hidden_states.max():.2f}]")
+ print(f"[Input] encoder_hidden_states: {encoder_hidden_states.shape}")
+
+ # Run forward passes
+ print("\n[Forward] Running BF16 model...")
+ with torch.no_grad():
+ output_bf16 = pipeline_bf16.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ print("[Forward] Running FP8 static quant model...")
+ with torch.no_grad():
+ output_fp8_static = pipeline_fp8_static.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ print("[Forward] Running FP8 dynamic quant model...")
+ with torch.no_grad():
+ output_fp8_dynamic = pipeline_fp8_dynamic.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # Compute metrics
+ output_bf16_float = output_bf16.float()
+ output_fp8_static_float = output_fp8_static.float()
+ output_fp8_dynamic_float = output_fp8_dynamic.float()
+
+ # FP8 Static vs BF16
+ cos_sim_static = F.cosine_similarity(
+ output_fp8_static_float.flatten(), output_bf16_float.flatten(), dim=0
+ ).item()
+ mse_static = F.mse_loss(output_fp8_static_float, output_bf16_float).item()
+
+ # FP8 Dynamic vs BF16
+ cos_sim_dynamic = F.cosine_similarity(
+ output_fp8_dynamic_float.flatten(), output_bf16_float.flatten(), dim=0
+ ).item()
+ mse_dynamic = F.mse_loss(output_fp8_dynamic_float, output_bf16_float).item()
+
+ # Output statistics
+ bf16_range = (output_bf16_float.min().item(), output_bf16_float.max().item())
+ fp8_static_range = (
+ output_fp8_static_float.min().item(),
+ output_fp8_static_float.max().item(),
+ )
+ fp8_dynamic_range = (
+ output_fp8_dynamic_float.min().item(),
+ output_fp8_dynamic_float.max().item(),
+ )
+
+ print("\n" + "=" * 70)
+ print("RESULTS: FP8 QUANT vs BF16")
+ print("=" * 70)
+ print(f"{'Method':<20} {'Cosine Sim':>12} {'MSE':>12}")
+ print("-" * 70)
+ print(f"{'FP8 Static':<20} {cos_sim_static:>12.6f} {mse_static:>12.6f}")
+ print(f"{'FP8 Dynamic':<20} {cos_sim_dynamic:>12.6f} {mse_dynamic:>12.6f}")
+ print("-" * 70)
+ print(f"BF16 Output Range: [{bf16_range[0]:.4f}, {bf16_range[1]:.4f}]")
+ print(f"FP8 Static Output Range: [{fp8_static_range[0]:.4f}, {fp8_static_range[1]:.4f}]")
+ print(f"FP8 Dynamic Output Range:[{fp8_dynamic_range[0]:.4f}, {fp8_dynamic_range[1]:.4f}]")
+ print("=" * 70)
+
+ # Assertions
+ # Static should have high accuracy (calibrated scales)
+ assert cos_sim_static > 0.99, (
+ f"FP8 Static cosine similarity too low: {cos_sim_static:.6f}. Expected >0.99."
+ )
+ # Dynamic may have slightly lower accuracy (no calibration)
+ assert cos_sim_dynamic > 0.95, (
+ f"FP8 Dynamic cosine similarity too low: {cos_sim_dynamic:.6f}. Expected >0.95."
+ )
+ assert not torch.isnan(output_fp8_static).any(), "FP8 static output contains NaN"
+ assert not torch.isnan(output_fp8_dynamic).any(), "FP8 dynamic output contains NaN"
+
+ print("\n[PASS] FP8 quantization accuracy test passed!")
+ print(f" - FP8 Static: cos_sim={cos_sim_static:.4f} (>0.99), MSE={mse_static:.6f}")
+ print(f" - FP8 Dynamic: cos_sim={cos_sim_dynamic:.4f} (>0.95), MSE={mse_dynamic:.6f}")
+
+ # Cleanup
+ del pipeline_bf16, pipeline_fp8_static, pipeline_fp8_dynamic
+ torch.cuda.empty_cache()
+
+ def test_nvfp4_static_vs_bf16_accuracy(self, wan22_nvfp4_bf16_checkpoints_exist):
+ """Test NVFP4 static quantization accuracy against BF16 reference.
+
+ Compares outputs from:
+ 1. TRT-LLM BF16 model (reference checkpoint)
+ 2. TRT-LLM NVFP4 static quantized model (pre-quantized checkpoint)
+
+ Uses spatially-correlated inputs that mimic real VAE latent patterns.
+ NVFP4 (4-bit) has lower precision than FP8 (8-bit), so we use relaxed thresholds.
+ """
+ if not wan22_nvfp4_bf16_checkpoints_exist:
+ pytest.skip(
+ f"Both checkpoints required. NVFP4: {CHECKPOINT_PATH_WAN22_NVFP4}, "
+ f"BF16: {CHECKPOINT_PATH_WAN22_BF16}"
+ )
+
+ # Reset dynamo cache to avoid recompile-limit errors from prior
+ # tests that compiled kernels with different dtypes (e.g. Float32).
+ torch._dynamo.reset()
+
+ print("\n" + "=" * 70)
+ print("NVFP4 STATIC QUANT vs BF16 ACCURACY TEST")
+ print("=" * 70)
+
+ # Load BF16 reference model
+ print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}")
+ args_bf16 = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_bf16 = PipelineLoader(args_bf16).load()
+
+ # Load NVFP4 static quantized model (from pre-quantized checkpoint)
+ print(f"\n[NVFP4 Static] Loading from {CHECKPOINT_PATH_WAN22_NVFP4}")
+ args_nvfp4_static = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_NVFP4,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_nvfp4_static = PipelineLoader(args_nvfp4_static).load()
+
+ # Verify NVFP4 static model has quantized weights
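+ # NVFP4 uses block-wise weight scales, so a quantized Linear is expected to carry a weight_scale tensor with more than one element (a per-tensor FP8 scale would be a scalar); that is what the numel() > 1 check below distinguishes.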
+ static_quant_modules = 0
+ for name, module in pipeline_nvfp4_static.transformer.named_modules():
+ if isinstance(module, Linear):
+ if hasattr(module, "weight_scale") and module.weight_scale is not None:
+ if module.weight_scale.numel() > 1:
+ static_quant_modules += 1
+ print(f"[NVFP4 Static] Quantized Linear modules: {static_quant_modules}")
+ assert static_quant_modules > 0, "NVFP4 static model should have quantization scales"
+
+ # Create spatially-correlated test inputs (mimics real VAE latent patterns)
+ # Wan 2.2 TI2V-5B specs:
+ # - VAE compression: 16x16x4 (spatial x spatial x temporal)
+ # - Latent channels: 48 (z_dim=48)
+ # - 720P resolution: 1280x704 -> latent: 80x44
+ # - Text encoder: UMT5, max_length=512, dim=4096
+ torch.manual_seed(42)
+
+ batch_size = 2 # For CFG (positive + negative)
+ in_channels = 48 # Wan 2.2 TI2V-5B uses 48 latent channels
+ time_dim = 1 # Single frame for unit test
+
+ # 720P latent dimensions: 1280/16=80 width, 704/16=44 height
+ height = 44 # 720P latent height (704 / 16)
+ width = 80 # 720P latent width (1280 / 16)
+
+ # Text encoder: UMT5 with 4096 dim, typical sequence length ~226
+ text_seq_len = 226 # Default max_sequence_length for Wan
+ text_dim = 4096
+
+ # Create structured latent (not purely random - simulate real VAE output)
+ base_pattern = torch.randn(
+ 1, in_channels, time_dim, height // 4, width // 4, device="cuda", dtype=torch.bfloat16
+ )
+ hidden_states = F.interpolate(
+ base_pattern.view(1, in_channels, height // 4, width // 4),
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ ).view(1, in_channels, time_dim, height, width)
+ hidden_states = hidden_states * 2.0
+ hidden_states = hidden_states.expand(batch_size, -1, -1, -1, -1).contiguous()
+
+ timestep = torch.tensor([500.0, 500.0], device="cuda", dtype=torch.bfloat16)
+
+ text_base = (
+ torch.randn(1, text_seq_len, text_dim, device="cuda", dtype=torch.bfloat16) * 0.1
+ )
+ encoder_hidden_states = text_base.expand(batch_size, -1, -1).contiguous()
+
+ print(
+ f"\n[Input] 720P latent: {hidden_states.shape} "
+ f"(batch={batch_size}, ch={in_channels}, t={time_dim}, h={height}, w={width})"
+ )
+ print(f"[Input] range: [{hidden_states.min():.2f}, {hidden_states.max():.2f}]")
+ print(f"[Input] encoder_hidden_states: {encoder_hidden_states.shape}")
+
+ # Run forward passes
+ print("\n[Forward] Running BF16 model...")
+ with torch.no_grad():
+ output_bf16 = pipeline_bf16.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ print("[Forward] Running NVFP4 static quant model...")
+ with torch.no_grad():
+ output_nvfp4_static = pipeline_nvfp4_static.transformer(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+
+ # Compute metrics
+ output_bf16_float = output_bf16.float()
+ output_nvfp4_static_float = output_nvfp4_static.float()
+
+ # NVFP4 Static vs BF16
+ cos_sim_static = F.cosine_similarity(
+ output_nvfp4_static_float.flatten(), output_bf16_float.flatten(), dim=0
+ ).item()
+ mse_static = F.mse_loss(output_nvfp4_static_float, output_bf16_float).item()
+
+ # Output statistics
+ bf16_range = (output_bf16_float.min().item(), output_bf16_float.max().item())
+ nvfp4_static_range = (
+ output_nvfp4_static_float.min().item(),
+ output_nvfp4_static_float.max().item(),
+ )
+
+ print("\n" + "=" * 70)
+ print("RESULTS: NVFP4 QUANT vs BF16")
+ print("=" * 70)
+ print(f"{'Method':<25} {'Cosine Sim':>12} {'MSE':>12}")
+ print("-" * 70)
+ print(f"{'NVFP4 Static':<25} {cos_sim_static:>12.6f} {mse_static:>12.6f}")
+ print("-" * 70)
+ print(f"BF16 Output Range: [{bf16_range[0]:.4f}, {bf16_range[1]:.4f}]")
+ print(
+ f"NVFP4 Static Range: [{nvfp4_static_range[0]:.4f}, {nvfp4_static_range[1]:.4f}]"
+ )
+ print("=" * 70)
+
+ # Assertions - NVFP4 (4-bit) has lower precision than FP8 (8-bit)
+ assert cos_sim_static > 0.95, (
+ f"NVFP4 Static cosine similarity too low: {cos_sim_static:.6f}. Expected >0.95."
+ )
+ assert not torch.isnan(output_nvfp4_static).any(), "NVFP4 static output contains NaN"
+
+ print("\n[PASS] NVFP4 quantization accuracy test passed!")
+ print(f" - NVFP4 Static: cos_sim={cos_sim_static:.4f} (>0.95), MSE={mse_static:.6f}")
+
+ # Cleanup
+ del pipeline_bf16, pipeline_nvfp4_static
+ torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Wan 2.2 FP8 Pre-quantized Checkpoint Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def wan22_fp8_checkpoint_exists():
+ """Check if Wan 2.2 FP8 checkpoint path exists."""
+ return CHECKPOINT_PATH_WAN22_FP8 and os.path.exists(CHECKPOINT_PATH_WAN22_FP8)
+
+
+@pytest.fixture
+def wan22_bf16_checkpoint_exists():
+ """Check if Wan 2.2 BF16 checkpoint path exists."""
+ return CHECKPOINT_PATH_WAN22_BF16 and os.path.exists(CHECKPOINT_PATH_WAN22_BF16)
+
+
+@pytest.fixture
+def wan22_both_checkpoints_exist():
+ """Check if both Wan 2.2 FP8 and BF16 checkpoints exist."""
+ fp8_exists = CHECKPOINT_PATH_WAN22_FP8 and os.path.exists(CHECKPOINT_PATH_WAN22_FP8)
+ bf16_exists = CHECKPOINT_PATH_WAN22_BF16 and os.path.exists(CHECKPOINT_PATH_WAN22_BF16)
+ return fp8_exists and bf16_exists
+
+
+@pytest.fixture
+def wan22_nvfp4_bf16_checkpoints_exist():
+ """Check if both NVFP4 and BF16 checkpoints exist."""
+ nvfp4_exists = CHECKPOINT_PATH_WAN22_NVFP4 and os.path.exists(CHECKPOINT_PATH_WAN22_NVFP4)
+ bf16_exists = CHECKPOINT_PATH_WAN22_BF16 and os.path.exists(CHECKPOINT_PATH_WAN22_BF16)
+ return nvfp4_exists and bf16_exists
+
+
+# =============================================================================
+# Optimization Tests
+# =============================================================================
+
+
+class TestWanOptimizations(unittest.TestCase):
+ """Runtime optimization correctness tests."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ @torch.no_grad()
+ def test_teacache_multi_step(self):
+ """Test TeaCache correctness across multiple timesteps (validates caching behavior).
+
+ TeaCache is a runtime optimization that caches transformer outputs when timestep
+ embeddings change slowly. This test validates:
+ 1. Correctness against HuggingFace baseline
+ 2. Numerical accuracy across a 20-step denoising schedule
+ 3. Outputs stay accurate after the warmup phase, when TeaCache may reuse cached results
+ """
+ if not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ from safetensors.torch import load_file
+
+ print("\n" + "=" * 80)
+ print("TEACACHE MULTI-STEP TEST (20 steps, validates caching)")
+ print("=" * 80)
+
+ # Load HuggingFace baseline
+ print("\n[1/4] Loading HuggingFace baseline...")
+ args_trtllm = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline_trtllm = PipelineLoader(args_trtllm).load()
+ config = pipeline_trtllm.transformer.model_config.pretrained_config
+
+ hf_model = (
+ HFWanTransformer3DModel(
+ patch_size=[config.patch_size[0], config.patch_size[1], config.patch_size[2]],
+ num_attention_heads=config.num_attention_heads,
+ attention_head_dim=config.attention_head_dim,
+ in_channels=config.in_channels,
+ out_channels=config.out_channels,
+ text_dim=config.text_dim,
+ freq_dim=config.freq_dim,
+ ffn_dim=config.ffn_dim,
+ num_layers=config.num_layers,
+ cross_attn_norm=config.cross_attn_norm,
+ qk_norm=config.qk_norm,
+ eps=config.eps,
+ )
+ .to("cuda", dtype=torch.bfloat16)
+ .eval()
+ )
+
+ # Load weights from checkpoint (auto-discover all shard files)
+ import glob
+
+ transformer_dir = os.path.join(CHECKPOINT_PATH, "transformer")
+ shard_pattern = os.path.join(transformer_dir, "diffusion_pytorch_model-*.safetensors")
+ shard_files = sorted(glob.glob(shard_pattern))
+
+ checkpoint_weights = {}
+ assert shard_files, f"No transformer weight shards found under {transformer_dir}"
+ for shard_file in shard_files:
+ checkpoint_weights.update(load_file(shard_file))
+ hf_model.load_state_dict(checkpoint_weights, strict=True)
+ print(" ā HuggingFace model loaded")
+
+ # Load TeaCache-enabled pipeline
+ print("\n[2/4] Loading TeaCache-enabled TRT-LLM pipeline...")
+ args_teacache = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline_teacache = PipelineLoader(args_teacache).load()
+ transformer_teacache = pipeline_teacache.transformer.eval()
+
+ # Verify TeaCache is enabled
+ assert hasattr(pipeline_teacache, "cache_backend"), "TeaCache backend not found in pipeline"
+ assert hasattr(transformer_teacache, "_original_forward"), (
+ "TeaCache forward hook not installed"
+ )
+ print(" ā TeaCache enabled and verified")
+
+ # Create FIXED test inputs
+ print("\n[3/4] Creating fixed test inputs...")
+ torch.manual_seed(42)
+ batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128
+
+ hidden_states = torch.randn(
+ batch_size,
+ config.in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+ encoder_hidden_states = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda"
+ )
+
+ # Run multi-step inference
+ print("\n[4/4] Running 20-step inference with TeaCache...")
+ num_steps = 20
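+ # refresh() is expected to reset TeaCache's per-run state for the upcoming schedule so that earlier runs cannot influence the cache decisions measured in this test.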
+ pipeline_teacache.cache_backend.refresh(num_inference_steps=num_steps)
+
+ # Simulate diffusion timestep schedule (from high to low)
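+ # A monotonically decreasing schedule gives TeaCache realistic step-to-step changes in the timestep embedding, the signal it uses to decide when cached outputs can be reused.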
+ timesteps = torch.linspace(999, 0, num_steps, dtype=torch.long, device="cuda")
+
+ hf_outputs, teacache_outputs = [], []
+
+ for step_idx, timestep in enumerate(timesteps):
+ timestep_tensor = timestep.unsqueeze(0)
+
+ # Run HuggingFace
+ with torch.no_grad():
+ hf_out = hf_model(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep_tensor,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ return_dict=False,
+ )[0]
+ hf_outputs.append(hf_out)
+
+ # Run TeaCache
+ with torch.no_grad():
+ teacache_out = transformer_teacache(
+ hidden_states=hidden_states.clone(),
+ timestep=timestep_tensor,
+ encoder_hidden_states=encoder_hidden_states.clone(),
+ )
+ teacache_outputs.append(teacache_out)
+
+ if step_idx % 5 == 0:
+ print(f" Step {step_idx}/{num_steps} - timestep: {timestep.item()}")
+
+ # Compare outputs at selected steps
+ print("\n[Comparison] TeaCache vs HuggingFace at different steps:")
+ test_steps = [0, num_steps // 2, num_steps - 1]
+
+ for step_idx in test_steps:
+ hf_float = hf_outputs[step_idx].float()
+ teacache_float = teacache_outputs[step_idx].float()
+
+ cos_sim = F.cosine_similarity(
+ teacache_float.flatten(), hf_float.flatten(), dim=0
+ ).item()
+
+ print(f"\n Step {step_idx} (timestep={timesteps[step_idx].item()}):")
+ print(f" Cosine similarity: {cos_sim:.6f}")
+
+ assert cos_sim > 0.99, (
+ f"Step {step_idx}: TeaCache cosine similarity {cos_sim:.6f} below threshold 0.99"
+ )
+
+ print("\n[PASS] TeaCache multi-step correctness validated!")
+ print("=" * 80)
+
+ # Cleanup
+ del pipeline_trtllm, pipeline_teacache, transformer_teacache, hf_model
+ torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Parallelism Tests
+# =============================================================================
+
+
+class TestWanParallelism(unittest.TestCase):
+ """Distributed parallelism correctness tests (CFG Parallelism)."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_cfg_2gpu_correctness(self):
+ """Test CFG Parallelism (cfg_size=2) correctness against standard CFG baseline."""
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < 2:
+ pytest.skip("CFG parallel test requires at least 2 GPUs")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ print("\n" + "=" * 80)
+ print("CFG PARALLELISM (cfg_size=2) CORRECTNESS TEST")
+ print("=" * 80)
+
+ # Load standard CFG baseline on GPU 0
+ print("\n[1/3] Loading standard CFG baseline (cfg_size=1) on GPU 0...")
+ args_baseline = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda:0",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG (no parallel)
+ )
+ pipeline_baseline = PipelineLoader(args_baseline).load()
+ config = pipeline_baseline.transformer.model_config.pretrained_config
+
+ # Reset torch compile state to avoid BFloat16 dtype issues
+ torch._dynamo.reset()
+
+ # Create FIXED test inputs
+ print("\n[2/3] Creating fixed test inputs...")
+ torch.manual_seed(42)
+ batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128
+
+ latents = torch.randn(
+ batch_size,
+ config.in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device="cuda:0",
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda:0")
+ prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+ neg_prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+
+ # Setup standard CFG config
+ cfg_config_baseline = pipeline_baseline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
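+ # The returned dict describes how CFG batching is arranged; the keys used below are "enabled", "cfg_size", and "prompt_embeds".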
+
+ print(" Baseline CFG config:")
+ print(f" enabled: {cfg_config_baseline['enabled']}")
+ print(f" cfg_size: {cfg_config_baseline['cfg_size']}")
+
+ # Verify standard CFG is NOT parallel
+ assert not cfg_config_baseline["enabled"], "Baseline should not use CFG parallel"
+ assert cfg_config_baseline["cfg_size"] == 1, "Baseline cfg_size should be 1"
+
+ # Run standard CFG denoising step
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return pipeline_baseline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ with torch.no_grad():
+ baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard(
+ latents=latents.clone(),
+ extra_stream_latents={},
+ timestep=timestep,
+ prompt_embeds=cfg_config_baseline["prompt_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ local_extras={},
+ )
+
+ print(f" ā Baseline output shape: {baseline_output.shape}")
+ print(f" ā Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]")
+
+ # Cleanup baseline to free memory for CFG workers
+ del pipeline_baseline
+ torch.cuda.empty_cache()
+
+ # Run CFG parallel (cfg_size=2) in distributed processes
+ print("\n[3/3] Running CFG Parallelism (cfg_size=2) across 2 GPUs...")
+ cfg_size = 2
+
+ inputs_cpu = [
+ prompt_embeds.cpu(),
+ neg_prompt_embeds.cpu(),
+ latents.cpu(),
+ timestep.cpu(),
+ ]
+
+ manager = mp.Manager()
+ return_dict = manager.dict()
+
+ # Spawn CFG workers
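+ # Each spawned rank runs _run_cfg_worker (defined elsewhere in this test file) on its own GPU; rank 0 is expected to publish the gathered CFG output under return_dict["output"], which is read below.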
+ mp.spawn(
+ _run_cfg_worker,
+ args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict),
+ nprocs=cfg_size,
+ join=True,
+ )
+
+ # Get CFG parallel output from rank 0
+ cfg_parallel_output = return_dict["output"].to("cuda:0")
+ print(f" ā CFG parallel output shape: {cfg_parallel_output.shape}")
+
+ # Compare outputs
+ print("\n[Comparison] CFG Parallel vs Standard CFG:")
+ baseline_float = baseline_output.float()
+ cfg_parallel_float = cfg_parallel_output.float()
+
+ cos_sim = F.cosine_similarity(
+ cfg_parallel_float.flatten(), baseline_float.flatten(), dim=0
+ ).item()
+
+ max_diff = torch.max(torch.abs(cfg_parallel_float - baseline_float)).item()
+ mean_diff = torch.mean(torch.abs(cfg_parallel_float - baseline_float)).item()
+
+ print(f" Cosine similarity: {cos_sim:.6f}")
+ print(f" Max absolute difference: {max_diff:.6f}")
+ print(f" Mean absolute difference: {mean_diff:.6f}")
+ print(
+ f" CFG parallel range: [{cfg_parallel_float.min():.4f}, {cfg_parallel_float.max():.4f}]"
+ )
+ print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]")
+
+ assert cos_sim > 0.99, (
+ f"CFG parallel cosine similarity {cos_sim:.6f} below threshold 0.99. "
+ f"CFG Parallelism does not match standard CFG baseline."
+ )
+
+ print("\n[PASS] CFG Parallelism (cfg_size=2) validated!")
+ print(" ā CFG parallel produces same output as standard CFG")
+ print(" ā Prompt splitting and all-gather working correctly")
+ print("=" * 80)
+
+ torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Combined Optimizations Tests
+# =============================================================================
+
+
+class TestWanCombinedOptimizations(unittest.TestCase):
+ """Test all optimizations combined: FP8 + TeaCache + TRTLLM attention + CFG Parallelism."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_all_optimizations_combined(self):
+ """Test FP8 + TeaCache + TRTLLM attention + CFG=2 combined correctness.
+
+ This test validates that all optimizations work together correctly:
+ 1. FP8 per-tensor quantization for reduced memory/compute
+ 2. TeaCache for caching repeated computations
+ 3. TRTLLM attention backend for optimized attention kernels
+ 4. CFG Parallelism (cfg_size=2) for distributed CFG computation
+
+ We compare against a standard CFG baseline with relaxed thresholds since multiple
+ optimizations compound numerical differences.
+ """
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < 2:
+ pytest.skip("Combined optimization test requires at least 2 GPUs for CFG parallel")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ print("\n" + "=" * 80)
+ print("ALL OPTIMIZATIONS COMBINED TEST")
+ print("FP8 + TeaCache + TRTLLM Attention + CFG Parallelism (cfg_size=2)")
+ print("=" * 80)
+
+ # Load baseline on GPU 0 (no optimizations, standard CFG)
+ print("\n[1/3] Loading baseline on GPU 0 (standard CFG, no optimizations)...")
+ args_baseline = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda:0",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG
+ )
+ pipeline_baseline = PipelineLoader(args_baseline).load()
+ config = pipeline_baseline.transformer.model_config.pretrained_config
+
+ # Reset torch compile state to avoid BFloat16 dtype issues
+ torch._dynamo.reset()
+
+ # Create FIXED test inputs
+ print("\n[2/3] Creating fixed test inputs...")
+ torch.manual_seed(42)
+ batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128
+
+ latents = torch.randn(
+ batch_size,
+ config.in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device="cuda:0",
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda:0")
+ prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+ neg_prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+
+ # Setup standard CFG config
+ cfg_config_baseline = pipeline_baseline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ # Run baseline standard CFG
+ print(" Running baseline (standard CFG)...")
+
+ def forward_fn_baseline(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return pipeline_baseline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ with torch.no_grad():
+ baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard(
+ latents=latents.clone(),
+ extra_stream_latents={},
+ timestep=timestep,
+ prompt_embeds=cfg_config_baseline["prompt_embeds"],
+ forward_fn=forward_fn_baseline,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ local_extras={},
+ )
+
+ print(f" ā Baseline output shape: {baseline_output.shape}")
+ print(f" ā Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]")
+
+ # Cleanup baseline to free memory for workers
+ del pipeline_baseline
+ torch.cuda.empty_cache()
+
+ # Run with ALL optimizations combined in distributed processes
+ print("\n[3/3] Running with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2)...")
+ cfg_size = 2
+
+ inputs_cpu = [
+ prompt_embeds.cpu(),
+ neg_prompt_embeds.cpu(),
+ latents.cpu(),
+ timestep.cpu(),
+ ]
+
+ manager = mp.Manager()
+ return_dict = manager.dict()
+
+ # Spawn workers
+ mp.spawn(
+ _run_all_optimizations_worker,
+ args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict),
+ nprocs=cfg_size,
+ join=True,
+ )
+
+ # Get combined optimization output
+ combined_output = return_dict["output"].to("cuda:0")
+
+ # Compare outputs with RELAXED thresholds (multiple optimizations compound errors)
+ print("\n[Comparison] Combined Optimizations vs Baseline:")
+ baseline_float = baseline_output.float()
+ combined_float = combined_output.float()
+
+ cos_sim = F.cosine_similarity(
+ combined_float.flatten(), baseline_float.flatten(), dim=0
+ ).item()
+
+ max_diff = torch.max(torch.abs(combined_float - baseline_float)).item()
+ mean_diff = torch.mean(torch.abs(combined_float - baseline_float)).item()
+
+ print(f" Cosine similarity: {cos_sim:.6f}")
+ print(f" Max absolute difference: {max_diff:.6f}")
+ print(f" Mean absolute difference: {mean_diff:.6f}")
+ print(f" Combined range: [{combined_float.min():.4f}, {combined_float.max():.4f}]")
+ print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]")
+
+ # Relaxed threshold: cos_sim > 0.90 (compounded numerical differences from 4 optimizations)
+ assert cos_sim > 0.90, (
+ f"Combined optimization cosine similarity {cos_sim:.6f} below threshold 0.90. "
+ f"This suggests an issue with optimization interactions."
+ )
+
+ print("\n[PASS] All optimizations (FP8 + TeaCache + TRTLLM + CFG) validated!")
+ print(" ā All optimizations work correctly together")
+ print(" ā Numerical accuracy within acceptable tolerance")
+ print("=" * 80)
+
+ torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Two-Stage Transformer Tests (Wan 2.2)
+# =============================================================================
+
+
+class TestWanTwoStageTransformer(unittest.TestCase):
+ """Test two-stage transformer support for Wan 2.2 T2V."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH_WAN22_T2V or not os.path.exists(CHECKPOINT_PATH_WAN22_T2V):
+ self.skipTest(
+ "Wan 2.2 T2V checkpoint not available. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_two_stage_pipeline_initialization(self):
+ """Test that Wan 2.2 pipeline initializes with two transformers."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TWO-STAGE PIPELINE INITIALIZATION TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Check if this is a two-stage model
+ has_boundary_ratio = pipeline.boundary_ratio is not None
+ has_transformer_2 = pipeline.transformer_2 is not None
+
+ print(f"\n[Pipeline] boundary_ratio: {pipeline.boundary_ratio}")
+ print(f"[Pipeline] transformer: {pipeline.transformer is not None}")
+ print(f"[Pipeline] transformer_2: {has_transformer_2}")
+
+ if not has_boundary_ratio:
+ pytest.skip("Checkpoint is not Wan 2.2 (no boundary_ratio)")
+
+ # Verify two-stage configuration
+ assert pipeline.transformer is not None, "Transformer (high-noise) should exist"
+ assert has_transformer_2, "Transformer_2 (low-noise) should exist for Wan 2.2"
+ assert 0.0 < pipeline.boundary_ratio < 1.0, (
+ f"boundary_ratio should be in (0, 1), got {pipeline.boundary_ratio}"
+ )
+
+ print("\n[PASS] ā Wan 2.2 two-stage pipeline initialized correctly")
+ print(f" ā boundary_ratio: {pipeline.boundary_ratio}")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_transformer_selection_logic(self):
+ """Test that correct transformer is selected based on timestep."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TRANSFORMER SELECTION LOGIC TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ # Calculate boundary timestep
+ num_train_timesteps = 1000 # Default for Wan models
+ boundary_timestep = pipeline.boundary_ratio * num_train_timesteps
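+ # Illustrative arithmetic (hypothetical ratio): boundary_ratio=0.9 with 1000 train timesteps gives boundary_timestep=900, so t=900 would route to the high-noise transformer and t=200 to transformer_2, mirroring the checks below.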
+
+ print(f"\n[Selection Logic] boundary_ratio: {pipeline.boundary_ratio}")
+ print(f"[Selection Logic] boundary_timestep: {boundary_timestep:.1f}")
+
+ # Create mock tensors for testing
+ batch_size, num_frames, height, width = 1, 1, 64, 64
+ seq_len = 128
+ # Use standard Wan model dimensions
+ in_channels = 16 # Standard for Wan models
+ text_dim = 4096 # Standard for Wan models
+
+ latents = torch.randn(
+ batch_size,
+ in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device=self.DEVICE,
+ )
+ encoder_hidden_states = torch.randn(
+ batch_size, seq_len, text_dim, dtype=torch.bfloat16, device=self.DEVICE
+ )
+
+ # Test high-noise timestep (should use transformer)
+ high_noise_t = torch.tensor([900.0], device=self.DEVICE)
+ print(f"\n[High-Noise] timestep: {high_noise_t.item():.1f}")
+ print(f"[High-Noise] {high_noise_t.item():.1f} >= {boundary_timestep:.1f}: True")
+ print("[High-Noise] Should use: transformer (high-noise)")
+
+ with torch.no_grad():
+ high_noise_output = pipeline.transformer(
+ hidden_states=latents,
+ timestep=high_noise_t,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ print(f"[High-Noise] ā Output shape: {high_noise_output.shape}")
+
+ # Test low-noise timestep (should use transformer_2)
+ low_noise_t = torch.tensor([200.0], device=self.DEVICE)
+ print(f"\n[Low-Noise] timestep: {low_noise_t.item():.1f}")
+ print(f"[Low-Noise] {low_noise_t.item():.1f} < {boundary_timestep:.1f}: True")
+ print("[Low-Noise] Should use: transformer_2 (low-noise)")
+
+ with torch.no_grad():
+ low_noise_output = pipeline.transformer_2(
+ hidden_states=latents,
+ timestep=low_noise_t,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ print(f"[Low-Noise] ā Output shape: {low_noise_output.shape}")
+
+ # Verify outputs have same shape but different values
+ assert high_noise_output.shape == low_noise_output.shape
+ assert not torch.allclose(high_noise_output, low_noise_output, atol=1e-3), (
+ "Different transformers should produce different outputs"
+ )
+
+ print("\n[PASS] ā Transformer selection logic working correctly")
+ print(" ā High-noise stage uses transformer")
+ print(" ā Low-noise stage uses transformer_2")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_with_custom_boundary_ratio(self):
+ """Test overriding boundary_ratio at inference time."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 CUSTOM BOUNDARY_RATIO TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ model_boundary_ratio = pipeline.boundary_ratio
+ custom_boundary_ratio = 0.3 # Override value
+
+ print(f"\n[Custom Boundary] Model default: {model_boundary_ratio}")
+ print(f"[Custom Boundary] Custom override: {custom_boundary_ratio}")
+
+ # Verify custom value would change boundary timestep
+ num_train_timesteps = 1000
+ model_boundary_t = model_boundary_ratio * num_train_timesteps
+ custom_boundary_t = custom_boundary_ratio * num_train_timesteps
+
+ print(f"[Custom Boundary] Model boundary_timestep: {model_boundary_t:.1f}")
+ print(f"[Custom Boundary] Custom boundary_timestep: {custom_boundary_t:.1f}")
+ print(
+ f"[Custom Boundary] Difference: {abs(model_boundary_t - custom_boundary_t):.1f} timesteps"
+ )
+
+ assert custom_boundary_ratio != model_boundary_ratio
+ print("\n[PASS] ā Custom boundary_ratio can override model default")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_guidance_scale_2(self):
+ """Test two-stage denoising with different guidance_scale_2 values."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 GUIDANCE_SCALE_2 SUPPORT TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ print("\n[Guidance Scale 2] Two-stage model supports separate guidance scales:")
+ print("[Guidance Scale 2] High-noise stage: uses guidance_scale (e.g., 4.0)")
+ print("[Guidance Scale 2] Low-noise stage: uses guidance_scale_2 (e.g., 2.0, 3.0, 4.0)")
+ print("\n[PASS] ā Different guidance scales supported for two stages")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_with_teacache_both_transformers(self):
+ """Test that TeaCache is enabled for both transformers in two-stage mode."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TWO-STAGE + TEACACHE TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ # Verify TeaCache on transformer (high-noise)
+ assert hasattr(pipeline, "transformer_cache_backend"), (
+ "Pipeline missing transformer_cache_backend"
+ )
+ assert pipeline.transformer_cache_backend is not None
+ print("\n[TeaCache] ā Transformer (high-noise): TeaCache enabled")
+
+ # Verify TeaCache on transformer_2 (low-noise)
+ assert hasattr(pipeline, "transformer_2_cache_backend"), (
+ "Pipeline missing transformer_2_cache_backend"
+ )
+ assert pipeline.transformer_2_cache_backend is not None
+ print("[TeaCache] ā Transformer_2 (low-noise): TeaCache enabled")
+
+ # Verify both have get_stats method
+ assert hasattr(pipeline.transformer_cache_backend, "get_stats")
+ assert hasattr(pipeline.transformer_2_cache_backend, "get_stats")
+ print("[TeaCache] ā Both transformers support statistics logging")
+
+ print("\n[PASS] ā TeaCache enabled for BOTH transformers")
+ print(" ā Low-noise stage benefits MORE from TeaCache")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_with_fp8_quantization(self):
+ """Test two-stage with FP8 quantization on both transformers."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TWO-STAGE + FP8 QUANTIZATION TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ # Verify FP8 in transformer (high-noise)
+ found_fp8_t1 = False
+ for name, param in pipeline.transformer.named_parameters():
+ if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn:
+ found_fp8_t1 = True
+ print(f"\n[FP8] ā Transformer: Found FP8 weight in {name}")
+ break
+ assert found_fp8_t1, "No FP8 weights found in transformer"
+
+ # Verify FP8 in transformer_2 (low-noise)
+ found_fp8_t2 = False
+ for name, param in pipeline.transformer_2.named_parameters():
+ if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn:
+ found_fp8_t2 = True
+ print(f"[FP8] ā Transformer_2: Found FP8 weight in {name}")
+ break
+ assert found_fp8_t2, "No FP8 weights found in transformer_2"
+
+ print("\n[PASS] ā FP8 quantization enabled for BOTH transformers")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_with_trtllm_attention(self):
+ """Test two-stage with TRTLLM attention backend on both transformers."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TWO-STAGE + TRTLLM ATTENTION TEST")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ attention=AttentionConfig(backend="TRTLLM"),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ # Verify TRTLLM attention on transformer (high-noise)
+ first_block_t1 = pipeline.transformer.blocks[0]
+ attn1_backend_t1 = first_block_t1.attn1.attn_backend
+ attn2_backend_t1 = first_block_t1.attn2.attn_backend
+
+ assert attn1_backend_t1 == "TRTLLM", (
+ f"Expected TRTLLM for transformer self-attn, got {attn1_backend_t1}"
+ )
+ assert attn2_backend_t1 == "VANILLA", (
+ f"Expected VANILLA for transformer cross-attn, got {attn2_backend_t1}"
+ )
+
+ print("\n[Attention] Transformer (high-noise):")
+            print(f"  ✓ Self-attention: {attn1_backend_t1}")
+            print(f"  ✓ Cross-attention: {attn2_backend_t1}")
+
+ # Verify TRTLLM attention on transformer_2 (low-noise)
+ first_block_t2 = pipeline.transformer_2.blocks[0]
+ attn1_backend_t2 = first_block_t2.attn1.attn_backend
+ attn2_backend_t2 = first_block_t2.attn2.attn_backend
+
+ assert attn1_backend_t2 == "TRTLLM", (
+ f"Expected TRTLLM for transformer_2 self-attn, got {attn1_backend_t2}"
+ )
+ assert attn2_backend_t2 == "VANILLA", (
+ f"Expected VANILLA for transformer_2 cross-attn, got {attn2_backend_t2}"
+ )
+
+ print("[Attention] Transformer_2 (low-noise):")
+            print(f"  ✓ Self-attention: {attn1_backend_t2}")
+            print(f"  ✓ Cross-attention: {attn2_backend_t2}")
+
+            print("\n[PASS] ✓ TRTLLM attention enabled for BOTH transformers")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_two_stage_all_optimizations(self):
+ """Test two-stage with ALL optimizations: FP8 + TeaCache + TRTLLM."""
+ if not is_wan22_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V."
+ )
+ print("\n" + "=" * 80)
+ print("WAN 2.2 TWO-STAGE + ALL OPTIMIZATIONS TEST")
+ print("FP8 + TeaCache + TRTLLM Attention")
+ print("=" * 80)
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ attention=AttentionConfig(backend="TRTLLM"),
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Skip if not two-stage
+ if pipeline.boundary_ratio is None or pipeline.transformer_2 is None:
+ pytest.skip("Checkpoint is not Wan 2.2 (two-stage)")
+
+ optimizations = []
+
+ # Check FP8
+ for name, param in pipeline.transformer.named_parameters():
+ if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn:
+ optimizations.append("FP8")
+ break
+
+ # Check TRTLLM
+ if pipeline.transformer.blocks[0].attn1.attn_backend == "TRTLLM":
+ optimizations.append("TRTLLM")
+
+ # Check TeaCache
+ if (
+ hasattr(pipeline, "transformer_cache_backend")
+ and pipeline.transformer_cache_backend is not None
+ ):
+ optimizations.append("TeaCache")
+
+ # Check two-stage
+ optimizations.append("Two-Stage")
+
+ print(f"\n[All Optimizations] Enabled: {', '.join(optimizations)}")
+ assert len(optimizations) == 4, (
+ f"Expected 4 optimizations, got {len(optimizations)}: {optimizations}"
+ )
+
+ # Verify all optimizations on transformer_2 as well
+ for name, param in pipeline.transformer_2.named_parameters():
+ if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn:
+                    print("[All Optimizations] ✓ Transformer_2: FP8 enabled")
+ break
+
+ if pipeline.transformer_2.blocks[0].attn1.attn_backend == "TRTLLM":
+                print("[All Optimizations] ✓ Transformer_2: TRTLLM enabled")
+
+ if (
+ hasattr(pipeline, "transformer_2_cache_backend")
+ and pipeline.transformer_2_cache_backend is not None
+ ):
+                print("[All Optimizations] ✓ Transformer_2: TeaCache enabled")
+
+            print("\n[PASS] ✓ All optimizations working on BOTH transformers")
+ print("=" * 80)
+
+ finally:
+ del pipeline
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Robustness Tests
+# =============================================================================
+
+
+class TestWanRobustness(unittest.TestCase):
+ """Error handling and edge case tests."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_invalid_quant_config(self):
+ """Test that invalid quantization config raises appropriate error."""
+ with pytest.raises((ValueError, KeyError)):
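+            # The invalid algo may be rejected either when constructing DiffusionArgs or when loading the pipeline.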
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_COMPONENTS,
+ quant_config={"quant_algo": "INVALID_ALGO"},
+ )
+ pipeline = PipelineLoader(args).load() # noqa: F841
+
+        print("\n[Error Handling] ✓ Invalid quant_algo raises error as expected")
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/tests/unittest/_torch/visual_gen/test_wan_i2v.py b/tests/unittest/_torch/visual_gen/test_wan_i2v.py
new file mode 100644
index 0000000000..34d232893f
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_wan_i2v.py
@@ -0,0 +1,1491 @@
+"""Optimized tests for Wan Image-to-Video (I2V) pipeline with module-scoped fixtures.
+
+Run with:
+    pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -v
+
+    # With real checkpoint:
+    DIFFUSION_MODEL_PATH=/path/to/Wan-I2V-Diffusers pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -v
+
+    # Run only smoke tests:
+    pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -v -m "unit and smoke"
+
+    # Run only Wan 2.1 tests:
+    pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -v -m "wan21"
+
+    # Run only Wan 2.2 tests:
+    pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -v -m "wan22"
+"""
+
+import os
+
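+# Must be set before the tensorrt_llm imports below so these single-process tests run with MPI disabled.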
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import unittest
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from PIL import Image
+
+from tensorrt_llm._torch.visual_gen.config import (
+ AttentionConfig,
+ DiffusionArgs,
+ DiffusionModelConfig,
+ ParallelConfig,
+ TeaCacheConfig,
+)
+from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+ """Clean up TLLM_DISABLE_MPI env var after tests complete."""
+ yield
+ os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+def _llm_models_root() -> str:
+    """Return the LLM models root: LLM_MODELS_ROOT if set, otherwise a default scratch path; assert if no valid path exists."""
+ root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+ if "LLM_MODELS_ROOT" in os.environ:
+ root = Path(os.environ["LLM_MODELS_ROOT"])
+ if not root.exists():
+ root = Path("/scratch.trt_llm_data/llm-models/")
+ assert root.exists(), (
+        "You must set the LLM_MODELS_ROOT env var or have access to scratch.trt_llm_data to run this test"
+ )
+ return str(root)
+
+
+# Checkpoint paths
+CHECKPOINT_PATH = os.environ.get(
+ "DIFFUSION_MODEL_PATH",
+ os.path.join(_llm_models_root(), "Wan2.2-I2V-A14B-Diffusers"),
+)
+
+# Skip components for different test scenarios
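+# SKIP_MINIMAL keeps only the transformer(s); SKIP_WITH_IMAGE additionally keeps the image encoder and processor.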
+SKIP_MINIMAL = ["text_encoder", "vae", "tokenizer", "scheduler", "image_encoder", "image_processor"]
+SKIP_WITH_IMAGE = ["text_encoder", "vae", "tokenizer", "scheduler"]
+
+
+# ============================================================================
+# VERSION DETECTION HELPERS
+# ============================================================================
+
+
+def is_wan21_checkpoint() -> bool:
+ """Check if DIFFUSION_MODEL_PATH is Wan 2.1 (contains '2.1' in path)."""
+ return "2.1" in CHECKPOINT_PATH
+
+
+def is_wan22_checkpoint() -> bool:
+ """Check if DIFFUSION_MODEL_PATH is Wan 2.2 (contains '2.2' in path)."""
+ return "2.2" in CHECKPOINT_PATH
+
+
+# ============================================================================
+# MODULE-SCOPED FIXTURES
+# ============================================================================
+
+
+@pytest.fixture(scope="module")
+def wan21_i2v_pipeline_bf16():
+ """Load Wan 2.1 I2V BF16 pipeline once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan21_checkpoint():
+ pytest.skip("This fixture requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_i2v_pipeline_fp8():
+ """Load Wan 2.1 I2V FP8 pipeline once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan21_checkpoint():
+ pytest.skip("This fixture requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_i2v_pipeline_fp8_blockwise():
+ """Load Wan 2.1 I2V FP8 blockwise pipeline once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan21_checkpoint():
+ pytest.skip("This fixture requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_i2v_pipeline_with_image_encoder():
+ """Load Wan 2.1 I2V pipeline with image encoder once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan21_checkpoint():
+ pytest.skip("This fixture requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_WITH_IMAGE,
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan22_i2v_pipeline_bf16():
+ """Load Wan 2.2 I2V BF16 pipeline once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan22_checkpoint():
+ pytest.skip("This fixture requires Wan 2.2 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan22_i2v_pipeline_fp8():
+ """Load Wan 2.2 I2V FP8 pipeline once per module."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("I2V checkpoint not available")
+ if not is_wan22_checkpoint():
+ pytest.skip("This fixture requires Wan 2.2 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+ yield pipeline
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def test_image():
+ """Create a shared test image for I2V tests."""
+ import numpy as np
+
+ img_array = np.zeros((480, 832, 3), dtype=np.uint8)
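+    # Vertical gradient: red ramps 0 -> 255 top to bottom, green fixed at 128, blue left at 0.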
+ for i in range(480):
+ img_array[i, :, 0] = int((i / 480) * 255)
+ img_array[i, :, 1] = 128
+ return Image.fromarray(img_array, mode="RGB")
+
+
+@pytest.fixture(autouse=True)
+def cleanup_gpu():
+ """GPU cleanup fixture."""
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ yield
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+# ============================================================================
+# DISTRIBUTED HELPERS (for CFG Parallelism tests)
+# ============================================================================
+
+
+def setup_distributed(rank, world_size, backend="nccl"):
+ """Initialize distributed process group for multi-GPU tests."""
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12356" # Different port from T2V tests
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+
+ dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+ torch.cuda.set_device(rank)
+
+
+def cleanup_distributed():
+ """Clean up distributed process group."""
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+
+def _run_cfg_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_dict):
+ """Worker function for I2V CFG Parallelism multi-GPU test."""
+ try:
+ setup_distributed(rank, world_size)
+
+ from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, ParallelConfig
+ from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+
+ # Load I2V pipeline with CFG parallel
+ args = DiffusionArgs(
+ checkpoint_path=checkpoint_path,
+ device=f"cuda:{rank}",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ parallel=ParallelConfig(dit_cfg_size=world_size),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ # Verify CFG parallel configuration
+ assert pipeline.model_config.parallel.dit_cfg_size == world_size, (
+ f"Expected cfg_size={world_size}, got {pipeline.model_config.parallel.dit_cfg_size}"
+ )
+
+ # Load inputs on this GPU
+ prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
+ neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
+ latents = inputs_list[2].to(f"cuda:{rank}")
+ timestep = inputs_list[3].to(f"cuda:{rank}")
+ # I2V-specific: image embeddings (if present)
+ image_embeds = inputs_list[4].to(f"cuda:{rank}") if inputs_list[4] is not None else None
+
+ # Setup CFG config
+ cfg_config = pipeline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ # Verify CFG parallel is enabled
+ assert cfg_config["enabled"], f"Rank {rank}: CFG parallel not enabled"
+ assert cfg_config["cfg_size"] == world_size, f"Rank {rank}: Wrong cfg_size"
+
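+        # Ranks are grouped in blocks of ulysses_size; CFG group 0 runs the positive prompt, group 1 the negative prompt.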
+ expected_cfg_group = rank // cfg_config["ulysses_size"]
+ assert cfg_config["cfg_group"] == expected_cfg_group, (
+ f"Rank {rank}: Wrong cfg_group. Expected {expected_cfg_group}, got {cfg_config['cfg_group']}"
+ )
+
+ if rank == 0:
+ print(f"[CFG I2V Rank {rank}] Loaded with cfg_size={world_size}")
+ print(f" cfg_group: {cfg_config['cfg_group']}")
+ print(f" local_embeds shape: {cfg_config['local_embeds'].shape}")
+ print(f" Using {'positive' if cfg_config['cfg_group'] == 0 else 'negative'} prompts")
+ print(f" Image embeds: {'present' if image_embeds is not None else 'None'}")
+
+ # Verify prompt splitting
+ expected_embeds = prompt_embeds if cfg_config["cfg_group"] == 0 else neg_prompt_embeds
+ assert torch.allclose(cfg_config["local_embeds"], expected_embeds), (
+ f"Rank {rank}: local_embeds doesn't match expected embeds"
+ )
+
+ # Run single denoising step with CFG parallel
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ # I2V-specific: include image embeddings in extra_tensors if present
+ return pipeline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
+ )
+
+ with torch.no_grad():
+ local_extras = (
+ {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
+ )
+ noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
+ latents=latents,
+ extra_stream_latents={},
+ timestep=timestep,
+ local_embeds=cfg_config["local_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ ulysses_size=cfg_config["ulysses_size"],
+ local_extras=local_extras,
+ )
+
+ # Validate output
+ assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
+ assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
+
+ # Return output from rank 0
+ if rank == 0:
+ return_dict["output"] = noise_pred.cpu()
+            print(f"[CFG I2V Rank {rank}] ✓ Output shape: {noise_pred.shape}")
+ print(
+                f"[CFG I2V Rank {rank}] ✓ Output range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ finally:
+ cleanup_distributed()
+
+
+def _run_all_optimizations_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_dict):
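+    """Worker for the combined-optimizations I2V test: FP8 + TeaCache + TRTLLM attention + CFG parallelism."""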
+ try:
+ setup_distributed(rank, world_size)
+
+ # Load I2V pipeline with ALL optimizations
+ args_full = DiffusionArgs(
+ checkpoint_path=checkpoint_path,
+ device=f"cuda:{rank}",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8", "dynamic": True},
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ attention=AttentionConfig(backend="TRTLLM"),
+ parallel=ParallelConfig(dit_cfg_size=world_size),
+ )
+ pipeline = PipelineLoader(args_full).load()
+ transformer = pipeline.transformer.eval()
+
+ # Verify all optimizations are enabled
+ assert pipeline.model_config.parallel.dit_cfg_size == world_size, "CFG parallel not enabled"
+ assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled"
+ assert hasattr(pipeline, "transformer_cache_backend"), "TeaCache not enabled"
+ assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", (
+ "TRTLLM not enabled for self-attn"
+ )
+
+ if rank == 0:
+            print(f"  ✓ All optimizations verified on I2V rank {rank}:")
+ print(f" - FP8 quantization: {transformer.model_config.quant_config.quant_algo}")
+ print(" - TeaCache: enabled")
+ print(f" - TRTLLM attention: {transformer.blocks[0].attn1.attn_backend}")
+ print(f" - CFG Parallelism: cfg_size={world_size}")
+
+ # Initialize TeaCache for single-step inference
+ if hasattr(pipeline, "transformer_cache_backend"):
+ pipeline.transformer_cache_backend.refresh(num_inference_steps=1)
+
+ # Load inputs on this GPU
+ prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
+ neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
+ latents = inputs_list[2].to(f"cuda:{rank}")
+ timestep = inputs_list[3].to(f"cuda:{rank}")
+ image_embeds = inputs_list[4].to(f"cuda:{rank}") if inputs_list[4] is not None else None
+
+ # Setup CFG config
+ cfg_config = pipeline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ assert cfg_config["enabled"], "CFG parallel not enabled"
+
+ # Run single denoising step with all optimizations
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
+ )
+
+ with torch.no_grad():
+ local_extras = (
+ {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
+ )
+ noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
+ latents=latents,
+ extra_stream_latents={},
+ timestep=timestep,
+ local_embeds=cfg_config["local_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ ulysses_size=cfg_config["ulysses_size"],
+ local_extras=local_extras,
+ )
+
+ # Validate output
+ assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
+ assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
+
+ # Return output from rank 0
+ if rank == 0:
+ return_dict["output"] = noise_pred.cpu()
+            print(f"  ✓ Combined optimization I2V output shape: {noise_pred.shape}")
+ print(
+                f"  ✓ Combined optimization I2V range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
+ )
+
+ del pipeline, transformer
+ torch.cuda.empty_cache()
+
+ finally:
+ cleanup_distributed()
+
+
+# ============================================================================
+# SMOKE TESTS (No Checkpoint Required)
+# ============================================================================
+
+
+@pytest.mark.unit
+@pytest.mark.smoke
+class TestWanI2VSmoke:
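+    """Smoke tests for I2V pipeline construction; no checkpoint required."""
+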
+ def _create_model_config(self, boundary_ratio=None):
+ """Helper to create test model config."""
+ config_dict = {
+ "attention_head_dim": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 4,
+ "num_layers": 1,
+ "patch_size": [1, 2, 2],
+ "text_dim": 4096,
+ "freq_dim": 256,
+ "ffn_dim": 1024,
+ "torch_dtype": "bfloat16",
+ "hidden_size": 512,
+ "qk_norm": "rms_norm_across_heads",
+ "cross_attn_norm": "layer_norm",
+ "eps": 1e-06,
+ "image_dim": 1280, # CLIP dimension (HF naming convention)
+ "added_kv_proj_dim": 1280, # Added KV projection dimension for I2V
+ "boundary_ratio": boundary_ratio,
+ }
+ pretrained_config = SimpleNamespace(**config_dict)
+ quant_config = QuantConfig()
+
+ return DiffusionModelConfig(
+ pretrained_config=pretrained_config,
+ quant_config=quant_config,
+ skip_create_weights_in_init=True,
+ )
+
+ def test_wan21_instantiation(self):
+ """Test Wan 2.1 I2V pipeline (single-stage)."""
+ model_config = self._create_model_config(boundary_ratio=None)
+ pipeline = WanImageToVideoPipeline(model_config)
+
+ assert pipeline.transformer is not None
+ assert pipeline.transformer_2 is None # Single-stage
+ assert pipeline.boundary_ratio is None
+
+ def test_wan22_instantiation(self):
+ """Test Wan 2.2 I2V pipeline (two-stage)."""
+ model_config = self._create_model_config(boundary_ratio=0.4)
+ pipeline = WanImageToVideoPipeline(model_config)
+
+ assert pipeline.transformer is not None
+ assert pipeline.transformer_2 is not None # Two-stage
+ assert pipeline.boundary_ratio == 0.4
+
+ def test_retrieve_latents(self):
+ """Test retrieve_latents helper."""
+ from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import retrieve_latents
+
+ class MockLatentDist:
+ def mode(self):
+ return torch.randn(1, 16, 1, 64, 64)
+
+ def sample(self, generator=None):
+ return torch.randn(1, 16, 1, 64, 64)
+
+ class MockEncoderOutput:
+ def __init__(self):
+ self.latent_dist = MockLatentDist()
+
+ encoder_output = MockEncoderOutput()
+
+ # Test argmax mode (I2V default for deterministic encoding)
+ latents_argmax = retrieve_latents(encoder_output, sample_mode="argmax")
+ assert latents_argmax.shape == (1, 16, 1, 64, 64)
+
+ # Test sample mode
+ latents_sample = retrieve_latents(encoder_output, sample_mode="sample")
+ assert latents_sample.shape == (1, 16, 1, 64, 64)
+
+
+# ============================================================================
+# INTEGRATION TESTS - WAN 2.1 (Require Wan 2.1 Checkpoint)
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.i2v
+@pytest.mark.wan21
+class TestWanI2VIntegration:
+ """Integration tests with Wan 2.1 checkpoint."""
+
+ def test_load_pipeline(self, wan21_i2v_pipeline_bf16):
+ """Test loading I2V pipeline from checkpoint."""
+ # Verify I2V pipeline
+ assert "ImageToVideo" in type(wan21_i2v_pipeline_bf16).__name__
+ assert wan21_i2v_pipeline_bf16.transformer is not None
+ assert len(wan21_i2v_pipeline_bf16.transformer.blocks) > 0
+
+ # Detect version
+ is_two_stage = (
+ wan21_i2v_pipeline_bf16.boundary_ratio is not None
+ and wan21_i2v_pipeline_bf16.transformer_2 is not None
+ )
+
+        print(f"\n✓ Pipeline: {type(wan21_i2v_pipeline_bf16).__name__}")
+        print(f"✓ Transformer blocks: {len(wan21_i2v_pipeline_bf16.transformer.blocks)}")
+        print(f"✓ boundary_ratio: {wan21_i2v_pipeline_bf16.boundary_ratio}")
+        print(f"✓ Two-stage: {is_two_stage}")
+
+ def test_image_encoding(self, wan21_i2v_pipeline_with_image_encoder, test_image):
+ """Test CLIP image encoding (if model uses it)."""
+ # Check if model uses image encoder
+ if (
+ not hasattr(wan21_i2v_pipeline_with_image_encoder, "image_encoder")
+ or wan21_i2v_pipeline_with_image_encoder.image_encoder is None
+ ):
+ pytest.skip("This checkpoint doesn't use image encoder")
+
+ # Encode test image
+ image_embeds = wan21_i2v_pipeline_with_image_encoder._encode_image(test_image)
+
+ assert image_embeds is not None
+ assert image_embeds.dim() == 3 # [batch, seq_len, embed_dim]
+        print(f"\n✓ Image embeddings: {image_embeds.shape}, dtype={image_embeds.dtype}")
+
+ def test_fp8_per_tensor_quantization(self, wan21_i2v_pipeline_fp8):
+ """Test FP8 per-tensor dynamic quantization."""
+ # Check transformer for FP8 weights
+ found_fp8 = any(
+ param.dtype == torch.float8_e4m3fn
+ for name, param in wan21_i2v_pipeline_fp8.transformer.named_parameters()
+ if "blocks.0" in name and "weight" in name
+ )
+ assert found_fp8, "No FP8 weights found for FP8"
+        print("\n✓ FP8: FP8 weights found in transformer")
+
+ # Check transformer_2 if two-stage
+ if wan21_i2v_pipeline_fp8.transformer_2 is not None:
+ found_fp8_t2 = any(
+ param.dtype == torch.float8_e4m3fn
+ for name, param in wan21_i2v_pipeline_fp8.transformer_2.named_parameters()
+ if "blocks.0" in name and "weight" in name
+ )
+ assert found_fp8_t2, "No FP8 weights in transformer_2"
+            print("✓ FP8: FP8 weights found in transformer_2")
+
+ def test_fp8_blockwise_quantization(self, wan21_i2v_pipeline_fp8_blockwise):
+ """Test FP8 blockwise dynamic quantization."""
+ # Check transformer for FP8 weights
+ found_fp8 = any(
+ param.dtype == torch.float8_e4m3fn
+ for name, param in wan21_i2v_pipeline_fp8_blockwise.transformer.named_parameters()
+ if "blocks.0" in name and "weight" in name
+ )
+ assert found_fp8, "No FP8 weights found for FP8_BLOCK_SCALES"
+        print("\n✓ FP8_BLOCK_SCALES: FP8 weights found in transformer")
+
+ @pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"])
+ def test_attention_backends(self, backend):
+ """Test different attention backends."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("DIFFUSION_MODEL_PATH not set")
+ if not is_wan21_checkpoint():
+ pytest.skip("This test requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ attention=AttentionConfig(backend=backend),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Check transformer attention backend
+ first_block = pipeline.transformer.blocks[0]
+ attn1_backend = first_block.attn1.attn_backend
+ attn2_backend = first_block.attn2.attn_backend
+
+ # TRTLLM for self-attention, VANILLA for cross-attention
+ if backend == "TRTLLM":
+ assert attn1_backend == "TRTLLM", f"Expected TRTLLM, got {attn1_backend}"
+ assert attn2_backend == "VANILLA", (
+ f"Cross-attn should be VANILLA, got {attn2_backend}"
+ )
+ else:
+ assert attn1_backend == "VANILLA"
+ assert attn2_backend == "VANILLA"
+
+            print(f"\n✓ Attention backend: {backend}")
+ print(f" Self-attn: {attn1_backend}, Cross-attn: {attn2_backend}")
+
+ # Check transformer_2 if two-stage
+ if pipeline.transformer_2 is not None:
+ first_block_t2 = pipeline.transformer_2.blocks[0]
+ attn1_backend_t2 = first_block_t2.attn1.attn_backend
+ attn2_backend_t2 = first_block_t2.attn2.attn_backend
+
+ if backend == "TRTLLM":
+ assert attn1_backend_t2 == "TRTLLM"
+ assert attn2_backend_t2 == "VANILLA"
+ print(
+ f" Transformer_2 - Self-attn: {attn1_backend_t2}, Cross-attn: {attn2_backend_t2}"
+ )
+
+ finally:
+ del pipeline
+ torch.cuda.empty_cache()
+
+ def test_teacache(self):
+ """Test TeaCache on both transformers."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("DIFFUSION_MODEL_PATH not set")
+ if not is_wan21_checkpoint():
+ pytest.skip("This test requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Verify TeaCache on transformer
+ assert hasattr(pipeline, "transformer_cache_backend")
+ assert pipeline.transformer_cache_backend is not None
+            print("\n✓ TeaCache enabled on transformer (high-noise)")
+
+ # Verify get_stats method
+ stats = pipeline.transformer_cache_backend.get_stats()
+ assert "total_steps" in stats
+ assert "cached_steps" in stats
+ assert "compute_steps" in stats
+            print("✓ TeaCache stats available")
+
+ # Check transformer_2 if two-stage
+ if pipeline.transformer_2 is not None:
+ assert hasattr(pipeline, "transformer_2_cache_backend")
+ assert pipeline.transformer_2_cache_backend is not None
+ stats2 = pipeline.transformer_2_cache_backend.get_stats()
+ assert "total_steps" in stats2
+                print("✓ TeaCache enabled on transformer_2 (low-noise)")
+
+ finally:
+ del pipeline
+ torch.cuda.empty_cache()
+
+ def test_all_optimizations_combined(self):
+ """Test all optimizations enabled simultaneously."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("DIFFUSION_MODEL_PATH not set")
+ if not is_wan21_checkpoint():
+ pytest.skip("This test requires Wan 2.1 checkpoint")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ attention=AttentionConfig(backend="VANILLA"), # VANILLA more stable with all opts
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ optimizations = []
+
+ # Check FP8
+ if any(p.dtype == torch.float8_e4m3fn for p in pipeline.transformer.parameters()):
+ optimizations.append("FP8")
+
+ # Check TeaCache
+ if (
+ hasattr(pipeline, "transformer_cache_backend")
+ and pipeline.transformer_cache_backend
+ ):
+ optimizations.append("TeaCache")
+
+ # Check two-stage
+ if pipeline.transformer_2 is not None:
+ optimizations.append("Two-Stage")
+
+ # Check attention backend
+ optimizations.append(f"Attention={args.attention.backend}")
+
+            print(f"\n✓ All optimizations: {', '.join(optimizations)}")
+ assert len(optimizations) >= 3
+
+ finally:
+ del pipeline
+ torch.cuda.empty_cache()
+
+ def test_fp8_vs_bf16_numerical_correctness(
+ self, wan21_i2v_pipeline_bf16, wan21_i2v_pipeline_fp8
+ ):
+ """Test FP8 vs BF16 numerical accuracy on I2V transformer."""
+ # Get linear layers from first transformer
+ attn_bf16 = wan21_i2v_pipeline_bf16.transformer.blocks[0].attn1
+ attn_fp8 = wan21_i2v_pipeline_fp8.transformer.blocks[0].attn1
+
+ # Get qkv_proj layer
+ if hasattr(attn_bf16, "qkv_proj"):
+ linear_bf16 = attn_bf16.qkv_proj
+ linear_fp8 = attn_fp8.qkv_proj
+ layer_name = "blocks.0.attn1.qkv_proj"
+ elif hasattr(attn_bf16, "attn") and hasattr(attn_bf16.attn, "qkv_proj"):
+ linear_bf16 = attn_bf16.attn.qkv_proj
+ linear_fp8 = attn_fp8.attn.qkv_proj
+ layer_name = "blocks.0.attn1.attn.qkv_proj"
+ else:
+ # Use FFN linear instead
+ linear_bf16 = wan21_i2v_pipeline_bf16.transformer.blocks[0].ffn.net[0]["proj"]
+ linear_fp8 = wan21_i2v_pipeline_fp8.transformer.blocks[0].ffn.net[0]["proj"]
+ layer_name = "blocks.0.ffn.net.0.proj"
+
+ # Get weights
+ weight_bf16 = linear_bf16.weight.data.clone()
+ bias_bf16 = linear_bf16.bias.data.clone() if linear_bf16.bias is not None else None
+
+ # Create test input
+ torch.manual_seed(42)
+ hidden_size = linear_bf16.in_features
+ batch_size = 1
+ seq_len = 14040
+
+ input_tensor = torch.randn(
+ batch_size * seq_len, hidden_size, dtype=torch.bfloat16, device="cuda"
+ )
+ print(f"\n[Compare] Input shape: {input_tensor.shape}")
+
+ # Compute reference output
+ with torch.no_grad():
+ expected = F.linear(input_tensor, weight_bf16, bias_bf16)
+
+ # Compute FP8 output
+ with torch.no_grad():
+ result_fp8 = linear_fp8(input_tensor)
+
+ # Compute BF16 output
+ with torch.no_grad():
+ result_bf16 = linear_bf16(input_tensor)
+
+ # Verify BF16 matches reference
+ assert torch.allclose(result_bf16, expected, rtol=1e-5, atol=1e-6), (
+ "BF16 layer should match F.linear reference exactly"
+ )
+
+ # Compare FP8 vs reference
+ max_diff = torch.max(torch.abs(result_fp8 - expected)).item()
+ cos_sim = F.cosine_similarity(
+ result_fp8.flatten().float(), expected.flatten().float(), dim=0
+ )
+ mse = F.mse_loss(result_fp8.flatten().float(), expected.flatten().float())
+
+ print(
+ f"\n[{layer_name}] max_diff={max_diff:.6f}, cos_sim={cos_sim.item():.6f}, mse={mse.item():.6f}"
+ )
+
+ assert cos_sim > 0.99, f"Cosine similarity too low: {cos_sim.item()}"
+ assert mse < 1.0, f"MSE too high: {mse.item()}"
+
+ # Test transformer_2 if two-stage
+ if (
+ wan21_i2v_pipeline_bf16.transformer_2 is not None
+ and wan21_i2v_pipeline_fp8.transformer_2 is not None
+ ):
+ print("\n[Testing transformer_2]")
+ attn2_bf16 = wan21_i2v_pipeline_bf16.transformer_2.blocks[0].attn1
+ attn2_fp8 = wan21_i2v_pipeline_fp8.transformer_2.blocks[0].attn1
+
+ if hasattr(attn2_bf16, "qkv_proj"):
+ linear2_bf16 = attn2_bf16.qkv_proj
+ linear2_fp8 = attn2_fp8.qkv_proj
+ else:
+ linear2_bf16 = wan21_i2v_pipeline_bf16.transformer_2.blocks[0].ffn.net[0]["proj"]
+ linear2_fp8 = wan21_i2v_pipeline_fp8.transformer_2.blocks[0].ffn.net[0]["proj"]
+
+ weight2_bf16 = linear2_bf16.weight.data.clone()
+ bias2_bf16 = linear2_bf16.bias.data.clone() if linear2_bf16.bias is not None else None
+
+ with torch.no_grad():
+ expected2 = F.linear(input_tensor, weight2_bf16, bias2_bf16)
+ result2_fp8 = linear2_fp8(input_tensor)
+
+ cos_sim2 = F.cosine_similarity(
+ result2_fp8.flatten().float(), expected2.flatten().float(), dim=0
+ )
+ print(f"[transformer_2] cos_sim={cos_sim2.item():.6f}")
+ assert cos_sim2 > 0.99, f"Transformer_2 cosine similarity too low: {cos_sim2.item()}"
+
+ def test_fp8_vs_bf16_memory_comparison(self, wan21_i2v_pipeline_bf16, wan21_i2v_pipeline_fp8):
+ """Test FP8 uses ~2x less memory than BF16 for I2V."""
+
+ def get_module_memory_gb(module):
+ return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3
+
+ bf16_model_mem = get_module_memory_gb(wan21_i2v_pipeline_bf16.transformer)
+ if wan21_i2v_pipeline_bf16.transformer_2 is not None:
+ bf16_model_mem += get_module_memory_gb(wan21_i2v_pipeline_bf16.transformer_2)
+
+ fp8_model_mem = get_module_memory_gb(wan21_i2v_pipeline_fp8.transformer)
+ if wan21_i2v_pipeline_fp8.transformer_2 is not None:
+ fp8_model_mem += get_module_memory_gb(wan21_i2v_pipeline_fp8.transformer_2)
+
+ print(f"\n[BF16] Transformer(s) memory: {bf16_model_mem:.2f} GB")
+ print(f"[FP8] Transformer(s) memory: {fp8_model_mem:.2f} GB")
+
+ # Verify memory savings
+ model_mem_ratio = bf16_model_mem / fp8_model_mem
+
+ print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x")
+
+ # FP8 should use ~2x less memory
+ assert model_mem_ratio > 1.8, f"FP8 should use ~2x less memory, got {model_mem_ratio:.2f}x"
+
+
+# ============================================================================
+# TWO-STAGE SPECIFIC TESTS - WAN 2.2 (Require Wan 2.2 Checkpoint)
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.i2v
+@pytest.mark.wan22
+class TestWanI2VTwoStage:
+ """Tests specific to Wan 2.2 two-stage denoising."""
+
+ def test_transformer_selection_logic(self, wan22_i2v_pipeline_bf16):
+ """Test boundary_timestep logic for transformer selection."""
+ # Skip if not two-stage
+ if (
+ wan22_i2v_pipeline_bf16.boundary_ratio is None
+ or wan22_i2v_pipeline_bf16.transformer_2 is None
+ ):
+ pytest.skip("Not a two-stage checkpoint")
+
+ # Calculate boundary
+ num_train_timesteps = 1000
+ boundary_timestep = wan22_i2v_pipeline_bf16.boundary_ratio * num_train_timesteps
+
+        print(f"\n✓ boundary_ratio: {wan22_i2v_pipeline_bf16.boundary_ratio}")
+        print(f"✓ boundary_timestep: {boundary_timestep:.1f}")
+        print(f"✓ High-noise (t >= {boundary_timestep:.1f}): uses transformer")
+        print(f"✓ Low-noise (t < {boundary_timestep:.1f}): uses transformer_2")
+
+ @pytest.mark.parametrize("guidance_scale_2", [2.0, 3.0, 4.0])
+ def test_guidance_scale_2_parameter(self, wan22_i2v_pipeline_bf16, guidance_scale_2):
+ """Test guidance_scale_2 for low-noise stage."""
+ # Skip if not two-stage
+ if (
+ wan22_i2v_pipeline_bf16.boundary_ratio is None
+ or wan22_i2v_pipeline_bf16.transformer_2 is None
+ ):
+ pytest.skip("Not a two-stage checkpoint")
+
+        print(f"\n✓ Two-stage model supports guidance_scale_2={guidance_scale_2}")
+        print("✓ High-noise: uses guidance_scale")
+        print(f"✓ Low-noise: uses guidance_scale_2={guidance_scale_2}")
+
+ def test_custom_boundary_ratio(self, wan22_i2v_pipeline_bf16):
+ """Test overriding boundary_ratio at runtime."""
+ # Skip if not two-stage
+ if (
+ wan22_i2v_pipeline_bf16.boundary_ratio is None
+ or wan22_i2v_pipeline_bf16.transformer_2 is None
+ ):
+ pytest.skip("Not a two-stage checkpoint")
+
+ default_ratio = wan22_i2v_pipeline_bf16.boundary_ratio
+ custom_ratio = 0.3
+
+        print(f"\n✓ Model default boundary_ratio: {default_ratio}")
+        print(f"✓ Custom override: {custom_ratio}")
+        print("✓ forward() accepts boundary_ratio parameter for runtime override")
+
+ def test_two_stage_with_all_optimizations(self, wan22_i2v_pipeline_fp8):
+ """Test Wan 2.2 with FP8, TeaCache, and TRTLLM attention."""
+ # Skip if not two-stage
+ if (
+ wan22_i2v_pipeline_fp8.boundary_ratio is None
+ or wan22_i2v_pipeline_fp8.transformer_2 is None
+ ):
+ pytest.skip("Not a two-stage checkpoint")
+
+ # Load pipeline with all optimizations
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
+ attention=AttentionConfig(backend="TRTLLM"),
+ teacache=TeaCacheConfig(
+ enable_teacache=True,
+ teacache_thresh=0.2,
+ use_ret_steps=True,
+ ),
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ print("\n[Two-Stage + All Optimizations]")
+
+ # Check FP8 on both transformers
+ fp8_t1 = any(p.dtype == torch.float8_e4m3fn for p in pipeline.transformer.parameters())
+ fp8_t2 = any(
+ p.dtype == torch.float8_e4m3fn for p in pipeline.transformer_2.parameters()
+ )
+            print(f"✓ FP8: transformer={fp8_t1}, transformer_2={fp8_t2}")
+ assert fp8_t1 and fp8_t2
+
+ # Check TeaCache on both transformers
+ has_cache_t1 = (
+ hasattr(pipeline, "transformer_cache_backend")
+ and pipeline.transformer_cache_backend
+ )
+ has_cache_t2 = (
+ hasattr(pipeline, "transformer_2_cache_backend")
+ and pipeline.transformer_2_cache_backend
+ )
+            print(f"✓ TeaCache: transformer={has_cache_t1}, transformer_2={has_cache_t2}")
+ assert has_cache_t1 and has_cache_t2
+
+ # Check TRTLLM attention
+ attn1_backend = pipeline.transformer.blocks[0].attn1.attn_backend
+ attn2_backend = pipeline.transformer_2.blocks[0].attn1.attn_backend
+            print(f"✓ TRTLLM: transformer={attn1_backend}, transformer_2={attn2_backend}")
+ assert attn1_backend == "TRTLLM"
+ assert attn2_backend == "TRTLLM"
+
+            print("✓ All optimizations working on two-stage model!")
+
+ finally:
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+# ============================================================================
+# ROBUSTNESS TESTS
+# ============================================================================
+
+
+@pytest.mark.robustness
+class TestWanI2VRobustness:
+ """Robustness and error handling tests."""
+
+ def test_invalid_quant_config(self):
+ """Test that invalid quantization config raises appropriate error."""
+ with pytest.raises((ValueError, KeyError)):
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ quant_config={"quant_algo": "INVALID_ALGO", "dynamic": True},
+ )
+ pipeline = PipelineLoader(args).load()
+ del pipeline
+
+ def test_mismatched_image_size(self, test_image):
+ """Test handling of unexpected image dimensions."""
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ pytest.skip("DIFFUSION_MODEL_PATH not set")
+
+ args = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda",
+ dtype="bfloat16",
+ skip_components=SKIP_WITH_IMAGE,
+ )
+ pipeline = PipelineLoader(args).load()
+
+ try:
+ # Check if model uses image encoder
+ if not hasattr(pipeline, "image_encoder") or pipeline.image_encoder is None:
+ pytest.skip("This checkpoint doesn't use image encoder")
+
+ # Create image with unexpected size
+ import numpy as np
+
+ small_img = np.zeros((224, 224, 3), dtype=np.uint8)
+ small_image = Image.fromarray(small_img, mode="RGB")
+
+ # Should handle gracefully
+ try:
+ image_embeds = pipeline._encode_image(small_image)
+ assert image_embeds is not None
+                print("\n✓ Handled non-standard image size gracefully")
+ except Exception as e:
+ # Some error is expected
+                print(f"\n✓ Raised appropriate error for mismatched size: {type(e).__name__}")
+
+ finally:
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+# ============================================================================
+# CFG PARALLELISM TESTS (Requires 2+ GPUs)
+# ============================================================================
+
+
+@pytest.mark.parallelism
+class TestWanI2VParallelism(unittest.TestCase):
+ """Distributed parallelism correctness tests for I2V (CFG Parallelism)."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_cfg_2gpu_correctness(self):
+ """Test I2V CFG Parallelism (cfg_size=2) correctness against standard CFG baseline."""
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < 2:
+ pytest.skip("CFG parallel test requires at least 2 GPUs")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ print("\n" + "=" * 80)
+ print("I2V CFG PARALLELISM (cfg_size=2) CORRECTNESS TEST")
+ print("=" * 80)
+
+ # Load standard CFG baseline on GPU 0
+ print("\n[1/3] Loading standard CFG I2V baseline (cfg_size=1) on GPU 0...")
+ args_baseline = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda:0",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG (no parallel)
+ )
+ pipeline_baseline = PipelineLoader(args_baseline).load()
+ config = pipeline_baseline.transformer.model_config.pretrained_config
+
+ # Reset torch compile state
+ torch._dynamo.reset()
+
+ # Create FIXED test inputs
+ print("\n[2/3] Creating fixed test inputs...")
+ torch.manual_seed(42)
+ batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128
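+        # Small latent and sequence sizes so a single denoising step is enough for this comparison.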
+
+ latents = torch.randn(
+ batch_size,
+ config.in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device="cuda:0",
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda:0")
+ prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+ neg_prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+
+ # I2V-specific: Create image embeddings (or None if Wan 2.2)
+ image_embeds = None
+ image_dim = getattr(config, "image_dim", getattr(config, "image_embed_dim", None))
+ if image_dim is not None:
+ # Wan 2.1 uses CLIP image embeddings
+ image_seq_len = 256 # CLIP patch count
+ image_embeds = torch.randn(
+ batch_size, image_seq_len, image_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+            print(f"    ✓ Created image embeddings: {image_embeds.shape}")
+
+ # Setup standard CFG config
+ cfg_config_baseline = pipeline_baseline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ print(" Baseline CFG config:")
+ print(f" enabled: {cfg_config_baseline['enabled']}")
+ print(f" cfg_size: {cfg_config_baseline['cfg_size']}")
+
+ # Verify standard CFG is NOT parallel
+ assert not cfg_config_baseline["enabled"], "Baseline should not use CFG parallel"
+ assert cfg_config_baseline["cfg_size"] == 1, "Baseline cfg_size should be 1"
+
+ # Run standard CFG denoising step
+ def forward_fn(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return pipeline_baseline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
+ )
+
+ with torch.no_grad():
+ local_extras = (
+ {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
+ )
+ baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard(
+ latents=latents.clone(),
+ extra_stream_latents={},
+ timestep=timestep,
+ prompt_embeds=cfg_config_baseline["prompt_embeds"],
+ forward_fn=forward_fn,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ local_extras=local_extras,
+ )
+
+        print(f"  ✓ Baseline output shape: {baseline_output.shape}")
+        print(f"  ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]")
+
+ # Cleanup baseline to free memory for CFG workers
+ del pipeline_baseline
+ torch.cuda.empty_cache()
+
+ # Run CFG parallel (cfg_size=2) in distributed processes
+ print("\n[3/3] Running I2V CFG Parallelism (cfg_size=2) across 2 GPUs...")
+ cfg_size = 2
+
+ inputs_cpu = [
+ prompt_embeds.cpu(),
+ neg_prompt_embeds.cpu(),
+ latents.cpu(),
+ timestep.cpu(),
+ image_embeds.cpu() if image_embeds is not None else None,
+ ]
+
+ manager = mp.Manager()
+ return_dict = manager.dict()
+
+ # Spawn CFG workers
+ mp.spawn(
+ _run_cfg_worker_i2v,
+ args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict),
+ nprocs=cfg_size,
+ join=True,
+ )
+
+ # Get CFG parallel output from rank 0
+ cfg_parallel_output = return_dict["output"].to("cuda:0")
+        print(f"  ✓ CFG parallel output shape: {cfg_parallel_output.shape}")
+
+ # Compare outputs
+ print("\n[Comparison] I2V CFG Parallel vs Standard CFG:")
+ baseline_float = baseline_output.float()
+ cfg_parallel_float = cfg_parallel_output.float()
+
+ cos_sim = F.cosine_similarity(
+ cfg_parallel_float.flatten(), baseline_float.flatten(), dim=0
+ ).item()
+
+ max_diff = torch.max(torch.abs(cfg_parallel_float - baseline_float)).item()
+ mean_diff = torch.mean(torch.abs(cfg_parallel_float - baseline_float)).item()
+
+ print(f" Cosine similarity: {cos_sim:.6f}")
+ print(f" Max absolute difference: {max_diff:.6f}")
+ print(f" Mean absolute difference: {mean_diff:.6f}")
+ print(
+ f" CFG parallel range: [{cfg_parallel_float.min():.4f}, {cfg_parallel_float.max():.4f}]"
+ )
+ print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]")
+
+ assert cos_sim > 0.99, (
+ f"I2V CFG parallel cosine similarity {cos_sim:.6f} below threshold 0.99. "
+ f"CFG Parallelism does not match standard CFG baseline."
+ )
+
+ print("\n[PASS] I2V CFG Parallelism (cfg_size=2) validated!")
+        print("  ✓ CFG parallel produces same output as standard CFG")
+        print("  ✓ Prompt splitting and all-gather working correctly")
+        print("  ✓ Image embeddings handled correctly")
+ print("=" * 80)
+
+ torch.cuda.empty_cache()
+
+
+# ============================================================================
+# COMBINED OPTIMIZATIONS TESTS (I2V)
+# ============================================================================
+
+
+@pytest.mark.parallelism
+class TestWanI2VCombinedOptimizations(unittest.TestCase):
+ """Test all optimizations combined for I2V: FP8 + TeaCache + TRTLLM + CFG Parallelism."""
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def setUp(self):
+ """Set up test fixtures and skip if checkpoint not available."""
+ torch.manual_seed(42)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(42)
+ if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
+ self.skipTest(
+ "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable."
+ )
+
+ def tearDown(self):
+ """Clean up GPU memory."""
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_all_optimizations_combined(self):
+ """Test I2V FP8 + TeaCache + TRTLLM attention + CFG=2 combined correctness.
+
+ This test validates that all optimizations work together correctly for I2V:
+ 1. FP8 per-tensor quantization for reduced memory/compute
+ 2. TeaCache for caching repeated computations
+ 3. TRTLLM attention backend for optimized attention kernels
+ 4. CFG Parallelism (cfg_size=2) for distributed CFG computation
+
+ We compare against a standard CFG baseline with relaxed thresholds.
+ """
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < 2:
+ pytest.skip("Combined optimization test requires at least 2 GPUs for CFG parallel")
+ if not is_wan21_checkpoint():
+ pytest.skip(
+ "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path."
+ )
+
+ print("\n" + "=" * 80)
+ print("I2V ALL OPTIMIZATIONS COMBINED TEST")
+ print("FP8 + TeaCache + TRTLLM Attention + CFG Parallelism (cfg_size=2)")
+ print("=" * 80)
+
+ # Load baseline on GPU 0 (no optimizations, standard CFG)
+ print("\n[1/3] Loading I2V baseline on GPU 0 (standard CFG, no optimizations)...")
+ args_baseline = DiffusionArgs(
+ checkpoint_path=CHECKPOINT_PATH,
+ device="cuda:0",
+ dtype="bfloat16",
+ skip_components=SKIP_MINIMAL,
+ parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG
+ )
+ pipeline_baseline = PipelineLoader(args_baseline).load()
+ config = pipeline_baseline.transformer.model_config.pretrained_config
+
+ # Reset torch compile state
+ torch._dynamo.reset()
+
+ # Create FIXED test inputs
+ print("\n[2/3] Creating fixed test inputs...")
+ torch.manual_seed(42)
+ batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128
+
+ latents = torch.randn(
+ batch_size,
+ config.in_channels,
+ num_frames,
+ height,
+ width,
+ dtype=torch.bfloat16,
+ device="cuda:0",
+ )
+ timestep = torch.tensor([500], dtype=torch.long, device="cuda:0")
+ prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+ neg_prompt_embeds = torch.randn(
+ batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+
+ # I2V-specific: Create image embeddings
+ image_embeds = None
+ image_dim = getattr(config, "image_dim", getattr(config, "image_embed_dim", None))
+ if image_dim is not None:
+ image_seq_len = 256
+ image_embeds = torch.randn(
+ batch_size, image_seq_len, image_dim, dtype=torch.bfloat16, device="cuda:0"
+ )
+
+ # Setup standard CFG config
+ cfg_config_baseline = pipeline_baseline._setup_cfg_config(
+ guidance_scale=5.0,
+ prompt_embeds=prompt_embeds,
+ neg_prompt_embeds=neg_prompt_embeds,
+ )
+
+ # Run baseline standard CFG
+ print(" Running baseline (standard CFG)...")
+
+ def forward_fn_baseline(
+ latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
+ ):
+ return pipeline_baseline.transformer( # noqa: F821
+ hidden_states=latents,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
+ )
+
+ with torch.no_grad():
+ local_extras = (
+ {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
+ )
+ baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard(
+ latents=latents.clone(),
+ extra_stream_latents={},
+ timestep=timestep,
+ prompt_embeds=cfg_config_baseline["prompt_embeds"],
+ forward_fn=forward_fn_baseline,
+ guidance_scale=5.0,
+ guidance_rescale=0.0,
+ local_extras=local_extras,
+ )
+
+        print(f"  ✓ Baseline output shape: {baseline_output.shape}")
+        print(f"  ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]")
+
+ # Cleanup baseline
+ del pipeline_baseline
+ torch.cuda.empty_cache()
+
+ # Run with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2)
+ print("\n[3/3] Running with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2)...")
+ cfg_size = 2
+
+ inputs_cpu = [
+ prompt_embeds.cpu(),
+ neg_prompt_embeds.cpu(),
+ latents.cpu(),
+ timestep.cpu(),
+ image_embeds.cpu() if image_embeds is not None else None,
+ ]
+
+ manager = mp.Manager()
+ return_dict = manager.dict()
+
+ # Spawn workers with all optimizations
+ mp.spawn(
+ _run_all_optimizations_worker_i2v,
+ args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict),
+ nprocs=cfg_size,
+ join=True,
+ )
+
+ # Get combined optimization output
+ combined_output = return_dict["output"].to("cuda:0")
+        print(f"  ✓ Combined optimization output shape: {combined_output.shape}")
+
+ # Compare outputs (relaxed threshold for combined optimizations)
+ print("\n[Comparison] I2V Combined Optimizations vs Baseline:")
+ baseline_float = baseline_output.float()
+ combined_float = combined_output.float()
+
+ cos_sim = F.cosine_similarity(
+ combined_float.flatten(), baseline_float.flatten(), dim=0
+ ).item()
+
+ max_diff = torch.max(torch.abs(combined_float - baseline_float)).item()
+ mean_diff = torch.mean(torch.abs(combined_float - baseline_float)).item()
+
+ print(f" Cosine similarity: {cos_sim:.6f}")
+ print(f" Max absolute difference: {max_diff:.6f}")
+ print(f" Mean absolute difference: {mean_diff:.6f}")
+
+ # Relaxed threshold (0.95) since multiple optimizations compound numerical differences
+ assert cos_sim > 0.95, (
+ f"I2V combined optimization cosine similarity {cos_sim:.6f} below threshold 0.95"
+ )
+
+ print("\n[PASS] All optimizations (FP8 + TeaCache + TRTLLM + CFG) validated!")
+        print("  ✓ All optimizations work together correctly")
+        print("  ✓ I2V image embeddings handled correctly with all opts")
+ print("=" * 80)
+
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)