[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)

Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Chang Liu 2026-02-13 14:11:11 -08:00 committed by GitHub
parent 19a3031ecb
commit 26901e4aa0
75 changed files with 19366 additions and 195 deletions


@@ -5,9 +5,6 @@ TensorRT LLM
<h4>TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports
state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.</h4>
🌟 TensorRT LLM is experimenting with image and video generation models in the [TensorRT-LLM/feat/visual_gen](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch.
This branch is a prototype and is not stable for production use. PRs are not accepted.
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/NVIDIA/TensorRT-LLM)
[![python](https://img.shields.io/badge/python-3.12-green)](https://www.python.org/downloads/release/python-3123/)


@@ -0,0 +1,172 @@
# Visual Generation Examples
Quick reference for running visual generation models (WAN).
## Prerequisites
```bash
# Install dependencies (from repository root)
pip install -r requirements-dev.txt
pip install git+https://github.com/huggingface/diffusers.git
pip install av
```
## Quick Start
```bash
# Set MODEL_ROOT to your model directory (required for examples)
export MODEL_ROOT=/llm-models
# Optional: PROJECT_ROOT defaults to repo root when run from examples/visual_gen
# Run all examples (auto-detects GPUs)
cd examples/visual_gen
./visual_gen_examples.sh
```
## Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `PROJECT_ROOT` | Auto-detected | Path to the repository root (auto-detected when run from `examples/visual_gen`; set it explicitly otherwise) |
| `MODEL_ROOT` | `/llm-models` | Path to model directory |
| `TLLM_LOG_LEVEL` | `INFO` | Logging level |
---
## WAN (Text-to-Video)
### Basic Usage
**Single GPU:**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--output_path output.mp4
```
**With TeaCache:**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--enable_teacache \
--output_path output.mp4
```
### Multi-GPU Parallelism
WAN supports two parallelism modes that can be combined:
- **CFG Parallelism**: Split positive/negative prompts across GPUs
- **Ulysses Parallelism**: Split sequence across GPUs for longer sequences
**Ulysses Only (2 GPUs):**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--attention_backend TRTLLM \
--cfg_size 1 --ulysses_size 2 \
--output_path output.mp4
```
GPU Layout: GPU 0-1 share sequence (6 heads each)
**CFG Only (2 GPUs):**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--attention_backend TRTLLM \
--cfg_size 2 --ulysses_size 1 \
--output_path output.mp4
```
GPU Layout: GPU 0 (positive) | GPU 1 (negative)
**CFG + Ulysses (4 GPUs):**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--attention_backend TRTLLM \
--cfg_size 2 --ulysses_size 2 \
--output_path output.mp4
```
GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses)
**Large-Scale (8 GPUs):**
```bash
python visual_gen_wan_t2v.py \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--attention_backend TRTLLM \
--cfg_size 2 --ulysses_size 4 \
--output_path output.mp4
```
GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative)
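The rank layouts above follow a CFG-major pattern. As a rough illustration (a sketch of the documented layout, not the library's internal rank-assignment code), a global rank can be mapped to its CFG branch and Ulysses position like this:

```python
# Sketch: map a global rank to (cfg_rank, ulysses_rank) under the CFG-major
# layout described above, e.g. cfg_size=2, ulysses_size=4 -> GPUs 0-3 positive,
# GPUs 4-7 negative. The actual library may assign ranks differently.
def rank_layout(global_rank: int, cfg_size: int, ulysses_size: int) -> tuple[int, int]:
    assert 0 <= global_rank < cfg_size * ulysses_size
    cfg_rank = global_rank // ulysses_size      # 0 = positive branch, 1 = negative branch
    ulysses_rank = global_rank % ulysses_size   # position within the sequence-parallel group
    return cfg_rank, ulysses_rank

# Example: 8 GPUs with cfg_size=2, ulysses_size=4
for rank in range(8):
    print(rank, rank_layout(rank, cfg_size=2, ulysses_size=4))
```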
---
## Common Arguments
| Argument | WAN | Default | Description |
|----------|-----|---------|-------------|
| `--height` | ✓ | 720 | Output height |
| `--width` | ✓ | 1280 | Output width |
| `--num_frames` | ✓ | 81 | Number of frames |
| `--steps` | ✓ | 50 | Denoising steps |
| `--guidance_scale` | ✓ | 5.0 | CFG guidance strength |
| `--seed` | ✓ | 42 | Random seed |
| `--enable_teacache` | ✓ | False | Cache optimization |
| `--teacache_thresh` | ✓ | 0.2 | TeaCache similarity threshold |
| `--attention_backend` | ✓ | VANILLA | VANILLA or TRTLLM |
| `--cfg_size` | ✓ | 1 | CFG parallelism |
| `--ulysses_size` | ✓ | 1 | Sequence parallelism |
| `--linear_type` | ✓ | default | Quantization type |
## Troubleshooting
**Out of Memory:**
- Use quantization: `--linear_type trtllm-fp8-blockwise`
- Reduce resolution or frames
- Enable TeaCache: `--enable_teacache`
- Use Ulysses parallelism with more GPUs
**Slow Inference:**
- Enable TeaCache: `--enable_teacache`
- Use TRTLLM backend: `--attention_backend TRTLLM`
- Use multi-GPU: `--cfg_size 2` or `--ulysses_size 2`
**Import Errors:**
- Run from repository root
- Install necessary dependencies, e.g., `pip install -r requirements-dev.txt`
**Ulysses Errors** (see the sanity-check sketch below):
- `ulysses_size` must evenly divide 12 (the WAN attention head count)
- Total GPUs must equal `cfg_size × ulysses_size`
- Sequence length must be divisible by `ulysses_size`
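These constraints can be checked before launching a multi-GPU run. A minimal sanity-check sketch (the head count of 12 applies to the Wan2.1-T2V-1.3B transformer used in these examples):

```python
# Sanity check for the parallelism constraints listed above.
# NUM_HEADS = 12 matches Wan2.1-T2V-1.3B; adjust for other checkpoints.
NUM_HEADS = 12

def validate_parallel_config(cfg_size: int, ulysses_size: int, num_gpus: int,
                             seq_len: int | None = None) -> None:
    if NUM_HEADS % ulysses_size != 0:
        raise ValueError(f"ulysses_size={ulysses_size} must evenly divide {NUM_HEADS} heads")
    if cfg_size * ulysses_size != num_gpus:
        raise ValueError(f"cfg_size * ulysses_size = {cfg_size * ulysses_size} "
                         f"does not match available GPUs ({num_gpus})")
    if seq_len is not None and seq_len % ulysses_size != 0:
        raise ValueError(f"sequence length {seq_len} is not divisible by ulysses_size={ulysses_size}")

validate_parallel_config(cfg_size=2, ulysses_size=2, num_gpus=4)  # OK: 4 GPUs, 12 % 2 == 0
```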
## Output Formats
- **WAN**: `.mp4` (video), `.gif` (animated), `.png` (single frame)
## Baseline Validation
Compare with official HuggingFace Diffusers implementation:
```bash
# Run HuggingFace baselines
./hf_examples.sh
# Or run individual models
python hf_wan.py --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers
```
Compare outputs generated with the same seed to verify correctness; a minimal comparison sketch follows.
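For a quick numeric check, the decoded frames of both clips can be compared directly. A minimal sketch, assuming both outputs have already been decoded (e.g., with `av` or `imageio`) into uint8 arrays of identical shape `(T, H, W, C)`:

```python
import numpy as np

def compare_videos(ref: np.ndarray, test: np.ndarray) -> None:
    """Report mean absolute error and PSNR between two decoded clips.

    Diffusion outputs are generally not bit-exact across implementations,
    so expect a small-but-nonzero error even with identical seeds.
    """
    assert ref.shape == test.shape, f"shape mismatch: {ref.shape} vs {test.shape}"
    diff = ref.astype(np.float32) - test.astype(np.float32)
    mae = float(np.abs(diff).mean())
    mse = float(np.maximum((diff ** 2).mean(), 1e-8))
    psnr = 10.0 * np.log10(255.0 ** 2 / mse)
    print(f"MAE: {mae:.3f}  PSNR: {psnr:.2f} dB")
```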

(New binary image file, 445 KiB — preview not shown)


@@ -0,0 +1,128 @@
#!/bin/bash
# HuggingFace Baseline Tests - Official Diffusers Implementation
#
# Usage:
# export PROJECT_ROOT=/path/to/tekit
# export MODEL_ROOT=/path/to/models
# ./hf_examples.sh
#
# Or inline:
# PROJECT_ROOT=/workspace/gitlab/tekit-b200 MODEL_ROOT=/llm-models ./hf_examples.sh
set -e # Exit on error
# Environment variables with defaults
PROJECT_ROOT=${PROJECT_ROOT:-"/workspace/gitlab/tekit-b200"}
MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
# Log configuration
export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
echo "============================================"
echo "HuggingFace Diffusers Baseline Tests"
echo "============================================"
echo "PROJECT_ROOT: $PROJECT_ROOT"
echo "MODEL_ROOT: $MODEL_ROOT"
echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
echo ""
echo "Purpose: Establish baseline results using"
echo " official diffusers implementations"
echo "============================================"
echo ""
# Check Python dependencies
echo "Checking dependencies..."
MISSING_DEPS=""
if ! python -c "import diffusers" 2>/dev/null; then
echo "❌ ERROR: diffusers not found"
MISSING_DEPS="$MISSING_DEPS diffusers"
fi
if ! python -c "import torch" 2>/dev/null; then
echo "❌ ERROR: torch not found"
MISSING_DEPS="$MISSING_DEPS torch"
fi
if [ -n "$MISSING_DEPS" ]; then
echo ""
echo "❌ Missing required dependencies:$MISSING_DEPS"
echo "Install with: pip install$MISSING_DEPS"
exit 1
fi
echo "✅ All required dependencies found"
echo ""
# Detect GPU
if command -v nvidia-smi &> /dev/null; then
GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
echo "Detected $GPU_COUNT GPU(s)"
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
echo "GPU: $GPU_NAME"
else
echo "⚠️ WARNING: nvidia-smi not found"
echo " Continuing with CPU (very slow!)"
GPU_COUNT=0
fi
echo ""
# Create output directory (in current directory)
OUTPUT_DIR="./baseline_outputs"
mkdir -p "$OUTPUT_DIR"
echo "Output directory: $OUTPUT_DIR ($(pwd)/baseline_outputs)"
echo ""
#############################################
# WAN (Wan2.1) Baseline Test
#############################################
echo "============================================"
echo "1/1: WAN Baseline Test"
echo "============================================"
echo ""
WAN_MODEL="${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/"
WAN_OUTPUT="${OUTPUT_DIR}/wan_baseline.gif"
if [ -d "$WAN_MODEL" ]; then
echo "Testing WAN with official diffusers..."
python ${PROJECT_ROOT}/examples/visual_gen/hf_wan.py \
--model_path "$WAN_MODEL" \
--output_path "$WAN_OUTPUT" \
--prompt "A cute cat playing piano" \
--height 480 \
--width 832 \
--num_frames 33 \
--steps 50 \
--guidance_scale 7.0 \
--seed 42
echo ""
echo "✅ WAN baseline test completed"
echo " Output: $WAN_OUTPUT"
else
echo "⚠️ SKIPPED: WAN model not found at $WAN_MODEL"
fi
echo ""
#############################################
# Summary
#############################################
echo "============================================"
echo "Baseline Tests Complete!"
echo "============================================"
echo ""
echo "Output files saved to: $OUTPUT_DIR"
echo ""
ls -lh "$OUTPUT_DIR" 2>/dev/null || echo "No outputs generated"
echo ""
echo "Next Steps:"
echo " 1. Verify outputs are correct (images/videos generated)"
echo " 2. Compare with custom implementation outputs"
echo " 3. Use these as reference/baseline for debugging"
echo ""
echo "Comparison command:"
echo " diff -r $OUTPUT_DIR <custom_implementation_outputs>"
echo "============================================"

examples/visual_gen/hf_wan.py (new executable file, 141 lines)

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""Baseline test for WAN using official diffusers library."""
import sys
import torch
from output_handler import OutputHandler, postprocess_hf_video_tensor
from tensorrt_llm._torch.visual_gen import MediaOutput
def test_wan_baseline(
model_path: str,
output_path: str,
prompt: str = "A cute cat playing piano",
height: int = 480,
width: int = 832,
num_frames: int = 33,
num_inference_steps: int = 50,
guidance_scale: float = 7.0,
seed: int = 42,
):
"""Test WAN video generation with official diffusers."""
from diffusers import WanPipeline
print("=" * 80)
print("WAN Baseline Test (Official Diffusers)")
print("=" * 80)
print()
# Load pipeline
print(f"Loading WAN pipeline from {model_path}...")
pipe = WanPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.to("cuda")
print("✅ Pipeline loaded")
print()
# Check model states
print("Model Training States:")
print(f" text_encoder.training: {pipe.text_encoder.training}")
print(f" transformer.training: {pipe.transformer.training}")
print(f" vae.training: {pipe.vae.training}")
print()
# Generate video
print(f"Generating video: '{prompt}'")
print(
f"Parameters: {height}x{width}, {num_frames} frames, {num_inference_steps} steps, guidance={guidance_scale}"
)
print()
# Set random seed
generator = torch.Generator(device="cuda").manual_seed(seed)
result = pipe(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="pt",
return_dict=False,
)
video = result[0]
# Post-process video tensor: (B, T, C, H, W) -> (T, H, W, C) uint8
video = postprocess_hf_video_tensor(video, remove_batch_dim=True)
print("=" * 80)
print("Generation Complete!")
print("=" * 80)
print(f"Video shape: {video.shape}")
print(f"Video dtype: {video.dtype}")
print()
# Save output
print(f"Saving output to {output_path}...")
OutputHandler.save(output=MediaOutput(video=video), output_path=output_path, frame_rate=24.0)
print(f"✅ Saved to {output_path}")
print()
print("=" * 80)
print("WAN BASELINE TEST PASSED ✅")
print("=" * 80)
return video
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="HuggingFace Baseline - WAN Text-to-Video Generation"
)
# Model & Input
parser.add_argument(
"--model_path",
type=str,
default="/llm-models/Wan2.1-T2V-1.3B-Diffusers/",
help="Path to WAN model",
)
parser.add_argument(
"--prompt", type=str, default="A cute cat playing piano", help="Text prompt for generation"
)
parser.add_argument(
"--output_path", type=str, default="wan_baseline.gif", help="Output file path"
)
# Generation parameters
parser.add_argument("--height", type=int, default=480, help="Video height")
parser.add_argument("--width", type=int, default=832, help="Video width")
parser.add_argument("--num_frames", type=int, default=33, help="Number of frames to generate")
parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps")
parser.add_argument(
"--guidance_scale", type=float, default=7.0, help="Classifier-free guidance scale"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
args = parser.parse_args()
try:
test_wan_baseline(
args.model_path,
args.output_path,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
)
except Exception as e:
print(f"\n❌ ERROR: {e}")
import traceback
traceback.print_exc()
sys.exit(1)


@@ -0,0 +1,237 @@
"""Unified output handler for diffusion model outputs."""
import os
from typing import Optional
import torch
from PIL import Image
from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import MediaOutput
def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
"""Post-process video tensor from HuggingFace pipeline output to final format.
HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W) format,
which is different from VAE decoder output format.
Args:
video: Video tensor in (B, T, C, H, W) format from HuggingFace pipeline
remove_batch_dim: Whether to remove batch dimension. Default True for typical
single-batch video generation.
Returns:
Post-processed video tensor:
- If remove_batch_dim=True: (T, H, W, C) uint8 tensor
- If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor
Note:
Assumes video values are in [-1, 1] range (standard pipeline output).
"""
# Remove batch dimension first if requested
if remove_batch_dim:
video = video[0] # (B, T, C, H, W) -> (T, C, H, W)
video = video.permute(0, 2, 3, 1) # (T, C, H, W) -> (T, H, W, C)
else:
video = video.permute(0, 1, 3, 4, 2) # (B, T, C, H, W) -> (B, T, H, W, C)
# Normalize to [0, 1] range
video = (video / 2 + 0.5).clamp(0, 1)
# Convert to uint8
video = (video * 255).round().to(torch.uint8)
return video
def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor:
"""Post-process image tensor from HuggingFace pipeline output to final format.
HuggingFace pipelines with output_type="pt" return images in (B, C, H, W) format.
Args:
image: Image tensor in (B, C, H, W) or (C, H, W) format from HuggingFace pipeline
Returns:
Post-processed image tensor in (H, W, C) uint8 format
Note:
Assumes image values are in [-1, 1] range (standard pipeline output).
"""
# Remove batch dimension if present
if image.ndim == 4:
image = image[0] # (B, C, H, W) -> (C, H, W)
# Convert to (H, W, C) format
image = image.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
# Normalize to [0, 1] range
image = (image / 2 + 0.5).clamp(0, 1)
# Convert to uint8
image = (image * 255).round().to(torch.uint8)
return image
class OutputHandler:
"""Handle saving of generated outputs in various formats.
Supports MediaOutput from all models:
- Video models (WAN): MediaOutput(video=torch.Tensor)
- Image models: MediaOutput(image=torch.Tensor)
- Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
Supported output formats:
- .png: Save single image or middle frame
- .gif: Save video as animated GIF (no audio)
- .mp4: Save video with audio (requires diffusers export_utils)
"""
@staticmethod
def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0):
"""Save output based on content type and file extension.
Args:
output: MediaOutput containing model outputs (image/video/audio)
output_path: Path to save the output file
frame_rate: Frames per second for video output (default: 24.0)
"""
if not isinstance(output, MediaOutput):
raise ValueError(f"Expected output to be MediaOutput, got {type(output)}")
file_ext = os.path.splitext(output_path)[1].lower()
# Determine content type
if output.image is not None:
OutputHandler._save_image(output.image, output_path, file_ext)
elif output.video is not None:
OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate)
else:
raise ValueError("Unknown output format. MediaOutput has no image or video data.")
@staticmethod
def _save_image(image: torch.Tensor, output_path: str, file_ext: str):
"""Save single image output.
Args:
image: Image as torch tensor (H, W, C) uint8
output_path: Path to save the image
file_ext: File extension (.png, .jpg, etc.)
"""
if file_ext not in [".png", ".jpg", ".jpeg"]:
logger.warning(f"Image output requested with {file_ext}, defaulting to .png")
output_path = output_path.replace(file_ext, ".png")
# Convert torch.Tensor to PIL Image and save
image_np = image.cpu().numpy()
Image.fromarray(image_np).save(output_path)
logger.info(f"Saved image to {output_path}")
@staticmethod
def _save_video(
video: torch.Tensor,
audio: Optional[torch.Tensor],
output_path: str,
file_ext: str,
frame_rate: float,
):
"""Save video output with optional audio.
Args:
video: Video frames as torch tensor (T, H, W, C) with dtype uint8
audio: Optional audio as torch tensor
output_path: Path to save the video
file_ext: File extension (.mp4, .gif, .png)
frame_rate: Frames per second
"""
if file_ext == ".mp4":
OutputHandler._save_mp4(video, audio, output_path, frame_rate)
elif file_ext == ".gif":
OutputHandler._save_gif(video, output_path, frame_rate)
elif file_ext == ".png":
OutputHandler._save_middle_frame(video, output_path)
else:
logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png")
output_path = output_path.replace(file_ext, ".png")
OutputHandler._save_middle_frame(video, output_path)
@staticmethod
def _save_mp4(
video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
):
"""Save video with optional audio as MP4.
Args:
video: Video frames as torch tensor (T, H, W, C) uint8
audio: Optional audio as torch tensor (float32)
output_path: Output path for MP4
frame_rate: Frames per second
"""
try:
from diffusers.pipelines.ltx2.export_utils import encode_video
# Prepare audio if present
audio_prepared = audio.float() if audio is not None else None
# encode_video expects (T, H, W, C) uint8 video and float32 audio
encode_video(
video,
fps=frame_rate,
audio=audio_prepared,
audio_sample_rate=24000 if audio_prepared is not None else None,
output_path=output_path,
)
logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
except ImportError:
logger.warning(
"diffusers export_utils (encode_video) not available. "
"Falling back to saving middle frame as PNG."
)
png_path = output_path.replace(".mp4", ".png")
OutputHandler._save_middle_frame(video, png_path)
@staticmethod
def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float):
"""Save video as animated GIF.
Args:
video: Video frames as torch tensor (T, H, W, C) uint8
output_path: Output path for GIF
frame_rate: Frames per second
"""
# Convert torch.Tensor to numpy for PIL
video_np = video.cpu().numpy()
# Convert to list of PIL Images
frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
# Save as animated GIF
duration_ms = int(1000 / frame_rate)
frames[0].save(
output_path,
save_all=True,
append_images=frames[1:],
optimize=False,
duration=duration_ms,
loop=0,
)
logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
@staticmethod
def _save_middle_frame(video: torch.Tensor, output_path: str):
"""Save middle frame of video as PNG.
Args:
video: Video frames as torch tensor (T, H, W, C) uint8
output_path: Output path for PNG
"""
# Convert torch.Tensor to numpy for PIL
video_np = video.cpu().numpy()
# Extract middle frame
frame_idx = video_np.shape[0] // 2
Image.fromarray(video_np[frame_idx]).save(output_path)
logger.info(f"Saved frame {frame_idx} to {output_path}")


@@ -0,0 +1,322 @@
# Visual Generation API Examples
This directory contains example scripts that demonstrate how to use the TensorRT-LLM Visual Generation API endpoints for image and video generation.
## Overview
These examples show how to interact with the visual generation server using both the OpenAI Python SDK and standard HTTP requests. The API provides endpoints for:
- **Image Generation**: Text-to-image generation (T2I)
- **Video Generation**:
- Text-to-video generation (T2V) - generate videos from text prompts only
- Text+Image-to-video generation (TI2V) - generate videos from text + reference image
- Both synchronous and asynchronous modes supported
- Multipart/form-data support for file uploads
- **Video Management**: Retrieving and deleting generated videos
## Prerequisites
Before running these examples, ensure you have:
1. **Dependencies installed**: Install the required modules before running the examples:
```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install av
```
2. **Server Running**: The TensorRT-LLM visual generation server must be running
```bash
trtllm-serve <path to your model> --extra_visual_gen_options <path to config yaml>
```
For example:
```bash
trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml
# Or run the server in the background:
trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml > /tmp/serve.log 2>&1 &
# Check whether the server has started
tail -f /tmp/serve.log
```
## Examples
Currently supported and tested models:
1. WAN T2V/I2V for video generation (t2v, ti2v, delete_video)
### 1. Synchronous Image Generation (`sync_image_gen.py`)
Demonstrates synchronous text-to-image generation using the OpenAI SDK.
**Features:**
- Generates images from text prompts
- Supports configurable image size and quality
- Returns base64-encoded images or URLs
- Saves generated images to disk
**Usage:**
```bash
# Use default localhost server
python sync_image_gen.py
# Specify custom server URL
python sync_image_gen.py http://your-server:8000/v1
```
**API Endpoint:** `POST /v1/images/generations`
**Output:** Saves generated image to `output_generation.png` (or numbered files for multiple images)
---
### 2. Synchronous Video Generation with T2V and TI2V Modes (`sync_video_gen.py`)
Demonstrates synchronous video generation using direct HTTP requests. Waits for completion and returns the video file directly.
**Features:**
- **T2V Mode**: Generate videos from text prompts only
- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data)
- Waits for video generation to complete before returning
- Returns video file directly in response
- Command-line interface for easy testing
**Usage:**
```bash
# Text-to-Video (T2V) - No reference image
python sync_video_gen.py --mode t2v \
--prompt "A cute cat playing with a ball in the park" \
--duration 4.0 --fps 24 --size 256x256
# Text+Image-to-Video (TI2V) - With reference image
# Note: longer durations and larger sizes lead to much longer generation times
python sync_video_gen.py --mode ti2v \
--prompt "She turns around and smiles, then slowly walks out of the frame" \
--image ./media/woman_skyline_original_720p.jpeg \
--duration 4.0 --fps 24 --size 512x512
# Custom parameters
python sync_video_gen.py --mode t2v \
--prompt "A serene sunset over the ocean" \
--duration 5.0 --fps 30 --size 512x512 \
--output my_video.mp4
```
**Command-Line Arguments:**
- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
- `--prompt` - Text prompt for video generation (required)
- `--image` - Path to reference image (required for ti2v mode)
- `--base-url` - API server URL (default: http://localhost:8000/v1)
- `--model` - Model name (default: wan)
- `--duration` - Video duration in seconds (default: 4.0)
- `--fps` - Frames per second (default: 24)
- `--size` - Video resolution in WxH format (default: 256x256)
- `--output` - Output video file path (default: output_sync.mp4)
**API Endpoint:** `POST /v1/videos/generations`
**API Details:**
- T2V uses JSON `Content-Type: application/json`
- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
**Output:** Saves generated video to specified output file
---
### 3. Async Video Generation with T2V and TI2V Modes (`async_video_gen.py`)
**NEW**: Enhanced async video generation supporting both Text-to-Video (T2V) and Text+Image-to-Video (TI2V) modes.
**Features:**
- **T2V Mode**: Generate videos from text prompts only (JSON request)
- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data with file upload)
- Command-line interface for easy testing
- Automatic mode detection
- Comprehensive parameter control
**Usage:**
```bash
# Text-to-Video (T2V) - No reference image
python async_video_gen.py --mode t2v \
--prompt "A cool cat on a motorcycle in the night" \
--duration 4.0 --fps 24 --size 256x256
# Text+Image-to-Video (TI2V) - With reference image
python async_video_gen.py --mode ti2v \
--prompt "She turns around and smiles, then slowly walks out of the frame" \
--image ./media/woman_skyline_original_720p.jpeg \
--duration 4.0 --fps 24 --size 512x512
# Custom parameters
python async_video_gen.py --mode t2v \
--prompt "A serene sunset over the ocean" \
--duration 5.0 --fps 30 --size 512x512 \
--output my_video.mp4
```
**Command-Line Arguments:**
- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
- `--prompt` - Text prompt for video generation (required)
- `--image` - Path to reference image (required for ti2v mode)
- `--base-url` - API server URL (default: http://localhost:8000/v1)
- `--model` - Model name (default: wan)
- `--duration` - Video duration in seconds (default: 4.0)
- `--fps` - Frames per second (default: 24)
- `--size` - Video resolution in WxH format (default: 256x256)
- `--output` - Output video file path (default: output_async.mp4)
**API Details:**
- T2V uses JSON `Content-Type: application/json`
- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
**Output:** Saves generated video to specified output file
---
### 4. Video Deletion (`delete_video.py`)
Demonstrates the complete lifecycle of video generation and deletion.
**Features:**
- Creates a test video generation job
- Waits for completion
- Deletes the generated video
- Verifies deletion by attempting to retrieve the deleted video
- Tests error handling for non-existent videos
**Usage:**
```bash
# Use default localhost server
python delete_video.py
# Specify custom server URL
python delete_video.py http://your-server:8000/v1
```
**API Endpoints:**
- `POST /v1/videos` - Create video job
- `GET /v1/videos/{video_id}` - Check video status
- `DELETE /v1/videos/{video_id}` - Delete video
**Test Flow:**
1. Create video generation job
2. Wait for completion
3. Delete the video
4. Verify video returns `NotFoundError`
5. Test deletion of non-existent video
---
## API Configuration
All examples use the following default configuration:
- **Base URL**: `http://localhost:8000/v1`
- **API Key**: `"tensorrt_llm"` (authentication token)
- **Timeout**: 300 seconds for async operations
You can customize these by:
1. Passing the base URL as a command-line argument
2. Modifying the default parameters in each script's function
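In the SDK-based scripts, these defaults correspond to a client constructed roughly as follows (the `timeout` keyword is optional and shown only to reflect the 300-second figure above):

```python
import openai

# Defaults used throughout the example scripts; adjust base_url as needed.
client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",  # placeholder token expected by the local server
    timeout=300.0,           # generous timeout for long-running operations
)
```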
## Common Parameters
### Image Generation
- `model`: Model identifier (e.g., "wan")
- `prompt`: Text description
- `n`: Number of images to generate
- `size`: Image dimensions (e.g., "512x512", "1024x1024")
- `quality`: "standard" or "hd"
- `response_format`: "b64_json" or "url"
### Video Generation
- `model`: Model identifier (e.g., "wan")
- `prompt`: Text description
- `size`: Video resolution (e.g., "256x256", "512x512")
- `seconds`: Duration in seconds
- `fps`: Frames per second
- `input_reference`: Reference image file (for TI2V mode)
## Quick Reference - curl Examples
### Text-to-Video (JSON)
```bash
curl -X POST "http://localhost:8000/v1/videos" \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cool cat on a motorcycle",
"seconds": 4.0,
"fps": 24,
"size": "256x256"
}'
```
### Text+Image-to-Video (Multipart with File Upload)
```bash
curl -X POST "http://localhost:8000/v1/videos" \
-F "prompt=She turns around and smiles" \
-F "input_reference=@./media/woman_skyline_original_720p.jpeg" \
-F "seconds=4.0" \
-F "fps=24" \
-F "size=256x256" \
-F "guidance_scale=5.0"
```
### Check Video Status
```bash
curl -X GET "http://localhost:8000/v1/videos/{video_id}"
```
### Download Video
```bash
curl -X GET "http://localhost:8000/v1/videos/{video_id}/content" -o output.mp4
```
### Delete Video
```bash
curl -X DELETE "http://localhost:8000/v1/videos/{video_id}"
```
## API Endpoints Summary
| Endpoint | Method | Mode | Content-Type | Purpose |
|----------|--------|------|--------------|---------|
| `/v1/videos` | POST | Async | JSON or Multipart | Create video job (T2V/TI2V) |
| `/v1/videos/generations` | POST | Sync | JSON or Multipart | Generate video sync (T2V/TI2V) |
| `/v1/videos/{id}` | GET | - | - | Get video status/metadata |
| `/v1/videos/{id}/content` | GET | - | - | Download video file |
| `/v1/videos/{id}` | DELETE | - | - | Delete video |
| `/v1/videos` | GET | - | - | List all videos |
| `/v1/images/generations` | POST | - | JSON | Generate images (T2I) |
**Note:** Both `/v1/videos` (async) and `/v1/videos/generations` (sync) support:
- **JSON**: Standard text-to-video (T2V)
- **Multipart/Form-Data**: Text+image-to-video (TI2V) with file upload
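The same distinction expressed in Python with `requests`, mirroring the curl examples above (field names follow those examples):

```python
import requests

BASE = "http://localhost:8000/v1"

# T2V: plain JSON body
resp = requests.post(f"{BASE}/videos", json={
    "prompt": "A cool cat on a motorcycle",
    "seconds": 4.0,
    "fps": 24,
    "size": "256x256",
})
print(resp.status_code, resp.json())

# TI2V: multipart/form-data with a reference image attached
with open("./media/woman_skyline_original_720p.jpeg", "rb") as img:
    resp = requests.post(
        f"{BASE}/videos",
        data={"prompt": "She turns around and smiles",
              "seconds": "4.0", "fps": "24", "size": "256x256"},
        files={"input_reference": img},
    )
print(resp.status_code, resp.json())
```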
## Error Handling
All examples include comprehensive error handling:
- Connection errors (server not running)
- API errors (invalid parameters, model not found)
- Timeout errors (generation taking too long)
- Resource errors (video not found for deletion)
Errors are displayed with full stack traces for debugging.
## Output Files
Generated files are saved to the current working directory:
- `output_generation.png` - Synchronous image generation (`sync_image_gen.py`)
- `output_sync.mp4` - Synchronous video generation (`sync_video_gen.py`)
- `output_async.mp4` - Asynchronous video generation (`async_video_gen.py`)
- `output_multipart.mp4` - Multipart example output (`multipart_example.py`)
**Note:** You can customize output filenames using the `--output` parameter in all scripts.


@@ -0,0 +1,238 @@
#!/usr/bin/env python
"""Test script for asynchronous video generation endpoint.
Tests POST /v1/videos endpoint which returns immediately with a job ID.
The video is generated in the background and can be retrieved later.
Supports two modes:
- Text-to-Video (T2V): Generate video from text prompt only
- Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
Examples:
# Text-to-Video (T2V)
python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
# Text+Image-to-Video (TI2V)
python async_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
"""
import argparse
import sys
import time
from pathlib import Path
import openai
def test_async_video_generation(
base_url: str = "http://localhost:8000/v1",
model: str = "wan",
prompt: str = "A video of a cool cat on a motorcycle in the night",
input_reference: str = None,
duration: float = 4.0,
fps: int = 24,
size: str = "256x256",
output_file: str = "output_async.mp4",
):
"""Test asynchronous video generation with OpenAI SDK.
Args:
base_url: Base URL of the API server
model: Model name to use
prompt: Text prompt for generation
input_reference: Path to reference image (optional, for TI2V mode)
duration: Video duration in seconds
fps: Frames per second
size: Video resolution (WxH format)
output_file: Output video file path
"""
mode = "TI2V" if input_reference else "T2V"
print("=" * 80)
print(f"Testing Async Video Generation API - {mode} Mode")
print("=" * 80)
# Initialize client
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
print("\n1. Creating video generation job...")
print(f" Mode: {mode}")
print(f" Prompt: {prompt}")
if input_reference:
print(f" Input Reference: {input_reference}")
print(f" Duration: {duration}s")
print(f" FPS: {fps}")
print(f" Size: {size}")
try:
# Prepare request parameters
create_params = {
"model": model,
"prompt": prompt,
"size": size,
"seconds": duration,
"extra_body": {
"fps": fps,
},
}
# Add input reference if provided (TI2V mode)
if input_reference:
if not Path(input_reference).exists():
print(f"\n❌ Error: Input reference image not found: {input_reference}")
return False
create_params["input_reference"] = open(input_reference, "rb")
# Create video generation job
job = client.videos.create(**create_params)
print("Video generation started: \n", job.model_dump_json(indent=2))
video_id = job.id
print("\n✓ Job created successfully!")
print(f" Video ID: {video_id}")
print(f" Status: {job.status}")
# Poll for completion
print("\n2. Polling for completion...")
max_attempts = 300 # 5 minutes with 1s intervals
attempt = 0
while attempt < max_attempts:
attempt += 1
# Get job status using SDK's get method
job = client.videos.retrieve(video_id)
status = job.status
print(f" [{attempt:3d}] Status: {status}", end="\r")
if status == "completed":
print("\n\n✓ Video generation completed!")
print(f" Completion time: {job.completed_at}")
break
elif status == "failed":
print("\n\n❌ Video generation failed!")
print(f" Error: {job.error}")
return False
time.sleep(1)
else:
print(f"\n\n❌ Timeout waiting for completion (>{max_attempts}s)")
return False
# Download video
print("\n3. Downloading video...")
# For binary content, use the underlying HTTP client
content = client.videos.download_content(video_id, variant="video")
content.write_to_file(output_file)
print(f" ✓ Saved to: {output_file}")
print("\n" + "=" * 80)
print("✓ Async video generation test completed successfully!")
print("=" * 80)
return True
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test async video generation API with T2V and TI2V modes",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video (T2V)
python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
# Text+Image-to-Video (TI2V)
python async_video_gen.py --mode ti2v \\
--prompt "She turns around and smiles, then slowly walks out of the frame" \\
--image ./media/woman_skyline_original_720p.jpeg
# Custom parameters
python async_video_gen.py --mode t2v \\
--prompt "A serene sunset over the ocean" \\
--duration 5.0 --fps 30 --size 512x512 \\
--output my_video.mp4
""",
)
# Mode selection
parser.add_argument(
"--mode",
choices=["t2v", "ti2v"],
default="t2v",
help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
)
# Required parameters
parser.add_argument(
"--prompt",
type=str,
default="A video of a cool cat on a motorcycle in the night",
help="Text prompt for video generation",
)
# TI2V mode parameters
parser.add_argument(
"--image",
"--input-reference",
type=str,
default=None,
help="Path to reference image (required for ti2v mode)",
)
# Optional parameters
parser.add_argument(
"--base-url",
type=str,
default="http://localhost:8000/v1",
help="Base URL of the API server",
)
parser.add_argument("--model", type=str, default="wan", help="Model name to use")
parser.add_argument(
"--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
)
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
parser.add_argument(
"--size",
type=str,
default="256x256",
help="Video resolution in WxH format (e.g., 1280x720)",
)
parser.add_argument(
"--output", type=str, default="output_async.mp4", help="Output video file path"
)
args = parser.parse_args()
# Validate ti2v mode requirements
if args.mode == "ti2v" and not args.image:
parser.error("--image is required when using --mode ti2v")
# Display configuration
print("\n" + "=" * 80)
print("OpenAI SDK - Async Video Generation Test")
print("=" * 80)
print(f"Base URL: {args.base_url}")
print(f"Mode: {args.mode.upper()}")
print()
# Test async video generation
success = test_async_video_generation(
base_url=args.base_url,
model=args.model,
prompt=args.prompt,
input_reference=args.image,
duration=args.duration,
fps=args.fps,
size=args.size,
output_file=args.output,
)
sys.exit(0 if success else 1)


@@ -0,0 +1,8 @@
linear:
type: default
teacache:
enable_teacache: true
teacache_thresh: 0.2
parallel:
dit_cfg_size: 1
dit_ulysses_size: 1


@@ -0,0 +1,151 @@
#!/usr/bin/env python
"""Test script for DELETE /v1/videos/{video_id} endpoint.
Tests the video deletion functionality by:
1. Creating a video generation job
2. Waiting for completion
3. Deleting the video
4. Verifying the deletion
"""
import sys
import time
import openai
def test_delete_video(
base_url: str = "http://localhost:8000/v1",
model: str = "wan",
prompt: str = "A simple test video for deletion",
duration: float = 2.0,
fps: int = 8,
size: str = "256x256",
):
"""Test video deletion endpoint using OpenAI SDK."""
print("=" * 80)
print("Testing DELETE /v1/videos/{video_id} Endpoint")
print("=" * 80)
# Initialize OpenAI client
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
video_id = None
try:
# Step 1: Create a video generation job
print("\n1. Creating video generation job...")
print(f" Prompt: {prompt}")
print(f" Duration: {duration}s")
print(f" FPS: {fps}")
print(f" Size: {size}")
job = client.videos.create(
model=model,
prompt=prompt,
size=size,
seconds=duration,
extra_body={
"fps": fps,
},
)
video_id = job.id
print(f" ✓ Video job created with ID: {video_id}")
print(f" Status: {job.status}")
# Step 2: Wait for video completion
print("\n2. Waiting for video generation to complete...")
max_attempts = 60 # attempts with 1s intervals
attempt = 0
while attempt < max_attempts:
attempt += 1
# Get job status using SDK's retrieve method
job = client.videos.retrieve(video_id)
status = job.status
print(f" [{attempt:3d}] Status: {status}", end="\r")
if status == "completed":
print(" ✓ Video generation completed!")
break
elif status == "failed":
print(" ❌ Video generation failed!")
return False
time.sleep(1)
else:
print(" ⚠ Timeout waiting for video completion")
# Continue with deletion anyway
# Step 3: Delete the video
print(f"\n3. Deleting video {video_id}...")
delete_result = client.videos.delete(video_id)
print(f" Response: {delete_result.model_dump_json(indent=2)}")
if delete_result.deleted:
print(" ✓ Video deleted successfully!")
else:
print(" ❌ Video deletion returned False")
return False
# Step 4: Verify the video is gone
print("\n4. Verifying video deletion...")
try:
verify_job = client.videos.retrieve(video_id)
print(f" ⚠ Video still exists after deletion: {verify_job.status}")
return False
except openai.NotFoundError as e:
print(" ✓ Video correctly returns NotFoundError")
print(f" Error message: {e.message}")
except Exception as e:
print(f" ⚠ Unexpected error: {type(e).__name__}: {e}")
# Step 5: Test deleting non-existent video
print("\n5. Testing deletion of non-existent video...")
fake_id = "nonexistent_video_id"
try:
fake_delete_result = client.videos.delete(fake_id)
print(" ⚠ Deletion of non-existent video did not raise error")
print(f" Response: {fake_delete_result.model_dump_json(indent=2)}")
except openai.NotFoundError as e:
print(" ✓ Correctly raises NotFoundError for non-existent video")
print(f" Error message: {e.message}")
except Exception as e:
print(f" ⚠ Unexpected error: {type(e).__name__}: {e}")
print("\n" + "=" * 80)
print("✓ Video deletion test completed successfully!")
print("=" * 80)
return True
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# Parse command line arguments
base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
print("\n" + "=" * 80)
print("OpenAI SDK - Video Deletion Test")
print("=" * 80)
print(f"Base URL: {base_url}")
print()
# Test video deletion
success = test_delete_video(base_url=base_url)
# Exit with appropriate code
sys.exit(0 if success else 1)

(New binary image file, 174 KiB — preview not shown)


@@ -0,0 +1,91 @@
#!/usr/bin/env python
"""Test script for image generation endpoints.
Tests:
- POST /v1/images/generations - Generate images from text
- POST /v1/images/edits - Edit images with text prompts
"""
import base64
import sys
import openai
def test_image_generation(
base_url: str = "http://localhost:8000/v1",
model: str = "flux2",
prompt: str = "A lovely cat lying on a sofa",
n: int = 1,
size: str = "512x512",
quality: str = "standard",
response_format: str = "b64_json",
output_file: str = "output_generation.png",
):
"""Test image generation endpoint."""
print("=" * 80)
print("Testing Image Generation API (POST /v1/images/generations)")
print("=" * 80)
# Initialize client
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
print("\n1. Generating image...")
print(f" Prompt: {prompt}")
print(f" Size: {size}")
print(f" Quality: {quality}")
print(f" Number of images: {n}")
try:
# Use OpenAI SDK's images.generate() method
response = client.images.generate(
model=model,
prompt=prompt,
n=n,
size=size,
quality=quality,
response_format=response_format,
)
print("\n✓ Image generated successfully!")
print(f" Number of images: {len(response.data)}")
# Save images
for i, image in enumerate(response.data):
if response_format == "b64_json":
# Decode base64 image
image_data = base64.b64decode(image.b64_json)
output = f"{output_file.rsplit('.', 1)[0]}_{i}.png" if n > 1 else output_file
with open(output, "wb") as f:
f.write(image_data)
print(f" ✓ Saved image {i + 1} to: {output} ({len(image_data)} bytes)")
else:
print(f" Image {i + 1} URL: {image.url}")
print("\n" + "=" * 80)
print("✓ Image generation test completed successfully!")
print("=" * 80)
return True
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# Parse command line arguments
base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
print("\n" + "=" * 80)
print("OpenAI SDK - Image Generation Tests")
print("=" * 80)
print(f"Base URL: {base_url}")
print()
# Test image generation
test_image_generation(base_url=base_url)


@@ -0,0 +1,224 @@
#!/usr/bin/env python
"""Test script for synchronous video generation endpoint.
Tests POST /v1/videos/generations endpoint which waits for completion and returns video data.
The video is generated synchronously and the response contains the video file.
Supports two modes:
- Text-to-Video (T2V): Generate video from text prompt only
- Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
Examples:
# Text-to-Video (T2V)
python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
# Text+Image-to-Video (TI2V)
python sync_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
"""
import argparse
import sys
from pathlib import Path
import requests
def test_sync_video_generation(
base_url: str = "http://localhost:8000/v1",
model: str = "wan",
prompt: str = "A video of a cute cat playing with a ball in the park",
input_reference: str = None,
duration: float = 4.0,
fps: int = 24,
size: str = "256x256",
output_file: str = "output_sync.mp4",
):
"""Test synchronous video generation with direct HTTP requests.
Args:
base_url: Base URL of the API server
model: Model name to use
prompt: Text prompt for generation
input_reference: Path to reference image (optional, for TI2V mode)
duration: Video duration in seconds
fps: Frames per second
size: Video resolution (WxH format)
output_file: Output video file path
"""
mode = "TI2V" if input_reference else "T2V"
print("=" * 80)
print(f"Testing Sync Video Generation API - {mode} Mode")
print("=" * 80)
print("\n1. Generating video (waiting for completion)...")
print(f" Mode: {mode}")
print(f" Prompt: {prompt}")
if input_reference:
print(f" Input Reference: {input_reference}")
print(f" Duration: {duration}s")
print(f" FPS: {fps}")
print(f" Size: {size}")
try:
endpoint = f"{base_url}/videos/generations"
if input_reference:
# TI2V mode - Use multipart/form-data with file upload
if not Path(input_reference).exists():
print(f"\n❌ Error: Input reference image not found: {input_reference}")
return False
# Prepare form data (all values as strings for multipart)
form_data = {
"model": model,
"prompt": prompt,
"size": size,
"seconds": str(duration),
"fps": str(fps),
}
# Add the file
# Note: the overall request is sent as multipart/form-data;
# the file part itself carries the image's MIME type.
files = {
"input_reference": (
Path(input_reference).name,
open(input_reference, "rb"),
"image/jpeg",  # assumes a JPEG reference image, as in the examples
)
}
print("\n Uploading reference image and generating video...")
response_video = requests.post(endpoint, data=form_data, files=files)
else:
# T2V mode - Use JSON
response_video = requests.post(
endpoint,
json={
"model": model,
"prompt": prompt,
"size": size,
"seconds": duration,
"fps": fps,
},
)
print(f"\nStatus code: {response_video.status_code}")
if response_video.status_code == 200:
with open(output_file, "wb") as f:
f.write(response_video.content)
print(f"✓ Video saved to: {output_file}")
print("\n" + "=" * 80)
print("✓ Sync video generation test completed successfully!")
print("=" * 80)
return True
else:
print(f"\n❌ Error: Server returned status {response_video.status_code}")
print(f"Response: {response_video.text}")
return False
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test synchronous video generation API with T2V and TI2V modes",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video (T2V)
python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
# Text+Image-to-Video (TI2V)
python sync_video_gen.py --mode ti2v \\
--prompt "She turns around and smiles, then slowly walks out of the frame" \\
--image ./media/woman_skyline_original_720p.jpeg
# Custom parameters
python sync_video_gen.py --mode t2v \\
--prompt "A serene sunset over the ocean" \\
--duration 5.0 --fps 30 --size 512x512 \\
--output my_video.mp4
""",
)
# Mode selection
parser.add_argument(
"--mode",
choices=["t2v", "ti2v"],
default="t2v",
help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
)
# Required parameters
parser.add_argument(
"--prompt",
type=str,
default="A video of a cute cat playing with a ball in the park",
help="Text prompt for video generation",
)
# TI2V mode parameters
parser.add_argument(
"--image",
"--input-reference",
type=str,
default=None,
help="Path to reference image (required for ti2v mode)",
)
# Optional parameters
parser.add_argument(
"--base-url",
type=str,
default="http://localhost:8000/v1",
help="Base URL of the API server",
)
parser.add_argument("--model", type=str, default="wan", help="Model name to use")
parser.add_argument(
"--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
)
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
parser.add_argument(
"--size",
type=str,
default="256x256",
help="Video resolution in WxH format (e.g., 1280x720)",
)
parser.add_argument(
"--output", type=str, default="output_sync.mp4", help="Output video file path"
)
args = parser.parse_args()
# Validate ti2v mode requirements
if args.mode == "ti2v" and not args.image:
parser.error("--image is required when using --mode ti2v")
# Display configuration
print("\n" + "=" * 80)
print("Synchronous Video Generation Test")
print("=" * 80)
print(f"Base URL: {args.base_url}")
print(f"Mode: {args.mode.upper()}")
print()
# Test sync video generation
success = test_sync_video_generation(
base_url=args.base_url,
model=args.model,
prompt=args.prompt,
input_reference=args.image,
duration=args.duration,
fps=args.fps,
size=args.size,
output_file=args.output,
)
sys.exit(0 if success else 1)


@@ -0,0 +1,238 @@
#!/bin/bash
# Visual Generation Examples - Test different models and configurations
#
# This script runs a comprehensive suite of visual generation examples including:
# - WAN T2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
# - WAN I2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
#
# The script automatically detects GPU count and runs appropriate examples:
# - 1 GPU: Single-GPU examples only
# - 2 GPUs: + CFG parallelism, Ulysses parallelism
# - 4 GPUs: + CFG + Ulysses combined
# - 8 GPUs: + Large-scale high-resolution examples
#
# Usage:
# export MODEL_ROOT=/path/to/models # required
# # Optional: PROJECT_ROOT auto-detected when run from examples/visual_gen
# cd examples/visual_gen && ./visual_gen_examples.sh
#
# Or inline:
# MODEL_ROOT=/llm-models ./visual_gen_examples.sh
set -e # Exit on error
# Environment variables with defaults
# PROJECT_ROOT: auto-detect repo root when run from examples/visual_gen
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT=${PROJECT_ROOT:-"$(cd "${SCRIPT_DIR}/../.." && pwd)"}
MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
# Log configuration
export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
echo "============================================"
echo "Visual Generation Examples"
echo "============================================"
echo "PROJECT_ROOT: $PROJECT_ROOT"
echo "MODEL_ROOT: $MODEL_ROOT"
echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
echo "============================================"
echo ""
# Detect GPU count
if command -v nvidia-smi &> /dev/null; then
GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
echo "Detected $GPU_COUNT GPU(s)"
if [ "$GPU_COUNT" -lt 2 ]; then
echo "Note: Multi-GPU examples will be skipped"
SKIP_MULTI_GPU=1
elif [ "$GPU_COUNT" -ge 8 ]; then
echo "Note: Will run all examples including 8-GPU configurations"
elif [ "$GPU_COUNT" -ge 4 ]; then
echo "Note: Will run examples up to 4-GPU configurations"
else
echo "Note: Will run 2-GPU examples only"
fi
else
echo "WARNING: nvidia-smi not found. Assuming single GPU."
GPU_COUNT=1
SKIP_MULTI_GPU=1
fi
echo ""
#############################################
# WAN (Wan2.1) Text-to-Video Examples
#############################################
# Demonstrates:
# - Single GPU: Baseline and TeaCache
# - 2 GPUs: CFG only, Ulysses only
# - 4 GPUs: CFG + Ulysses combined
# - 8 GPUs: Large-scale parallelism
#############################################
echo "=== WAN Example 1: Baseline (no optimization) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
--prompt "A cute cat playing piano" \
--output_path wan_cat_piano.png
echo ""
echo "=== WAN Example 2: With TeaCache ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--output_path wan_cat_piano_teacache.png \
--enable_teacache
if [ -z "$SKIP_MULTI_GPU" ]; then
echo ""
echo "=== WAN Example 3: CFG Only (2 GPUs) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
--prompt "A cute cat playing piano" \
--output_path wan_cfg_2gpu.mp4 \
--attention_backend TRTLLM \
--cfg_size 2 \
--ulysses_size 1
else
echo ""
echo "=== WAN Example 3: Skipped (requires 2 GPUs) ==="
fi
if [ -z "$SKIP_MULTI_GPU" ]; then
echo ""
echo "=== WAN Example 4: Ulysses Only (2 GPUs) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
--prompt "A cute cat playing piano" \
--output_path wan_ulysses_2gpu.mp4 \
--attention_backend TRTLLM \
--cfg_size 1 \
--ulysses_size 2
else
echo ""
echo "=== WAN Example 4: Skipped (requires 2 GPUs) ==="
fi
if [ "$GPU_COUNT" -ge 4 ]; then
echo ""
echo "=== WAN Example 5: CFG + Ulysses (4 GPUs) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
--prompt "A cute cat playing piano" \
--output_path wan_cfg_ulysses_4gpu.mp4 \
--attention_backend TRTLLM \
--cfg_size 2 \
--ulysses_size 2
else
echo ""
echo "=== WAN Example 5: Skipped (requires 4 GPUs) ==="
fi
if [ "$GPU_COUNT" -ge 8 ]; then
echo ""
echo "=== WAN Example 6: Large-Scale (8 GPUs) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
--prompt "A cute cat playing piano" \
--output_path wan_cfg_ulysses_8gpu.mp4 \
--attention_backend TRTLLM \
--cfg_size 2 \
--ulysses_size 4
else
echo ""
echo "=== WAN Example 6: Skipped (requires 8 GPUs) ==="
fi
#############################################
# WAN 2.2 (Two-Stage) Text-to-Video Examples
#############################################
echo ""
echo "=== WAN 2.2 T2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
--height 720 \
--width 1280 \
--num_frames 81 \
--model_path ${MODEL_ROOT}/Wan2.2-T2V-A14B-Diffusers \
--prompt "A cute cat playing piano" \
--output_path wan22_t2v_cat_piano_optimized.gif \
--linear_type trtllm-fp8-blockwise \
--attention_backend TRTLLM \
--enable_teacache \
--teacache_thresh 0.2 \
--guidance_scale 3.0 \
--guidance_scale_2 2.5 \
--boundary_ratio 0.85
#############################################
# WAN 2.1 Image-to-Video Examples
#############################################
echo ""
echo "=== WAN 2.1 I2V Example: Single-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
--height 480 \
--width 832 \
--num_frames 33 \
--model_path ${MODEL_ROOT}/Wan2.1-I2V-14B-480P-Diffusers \
--image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
--prompt "It snows as the cat plays piano, lots of snow \
appearing all over the screen, snowflakes, blizzard, \
gradually more snow" \
--negative_prompt "blurry, low quality" \
--output_path wan21_i2v_cat_piano_optimized.gif \
--linear_type trtllm-fp8-per-tensor \
--attention_backend TRTLLM \
--enable_teacache \
--teacache_thresh 0.2 \
--guidance_scale 6.0
#############################################
# WAN 2.2 (Two-Stage) Image-to-Video Examples
#############################################
echo ""
echo "=== WAN 2.2 I2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
--height 480 \
--width 832 \
--num_frames 81 \
--model_path ${MODEL_ROOT}/Wan2.2-I2V-A14B-Diffusers \
--image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
--prompt "It snows as the cat plays piano, lots of snow \
appearing all over the screen, snowflakes, blizzard, \
gradually more snow" \
--negative_prompt "blurry, low quality" \
--output_path wan22_i2v_cat_piano_optimized.gif \
--linear_type trtllm-fp8-blockwise \
--attention_backend TRTLLM \
--enable_teacache \
--teacache_thresh 0.2 \
--guidance_scale 6.0 \
--guidance_scale_2 5.0 \
--boundary_ratio 0.85
echo ""
echo "============================================"
echo "All examples completed successfully!"
echo "============================================"


@@ -0,0 +1,226 @@
#!/usr/bin/env python3
"""WAN Image-to-Video generation using TensorRT-LLM Visual Generation."""
import argparse
import time
from output_handler import OutputHandler
from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
# Set logger level to ensure timing logs are printed
logger.set_level("info")
def parse_args():
parser = argparse.ArgumentParser(
description="TRTLLM VisualGen - Wan Image-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
)
# Model & Input
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to Wan I2V Diffusers model directory (2.1 or 2.2)",
)
parser.add_argument(
"--image_path",
type=str,
required=True,
help="Path to input image for I2V conditioning",
)
parser.add_argument(
"--last_image_path",
type=str,
default=None,
help="Optional path to last frame image for interpolation (Wan 2.1 only)",
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
parser.add_argument(
"--negative_prompt",
type=str,
default=None,
help="Negative prompt. Default is model-specific.",
)
parser.add_argument(
"--output_path",
type=str,
default="output.png",
help="Path to save the output image/video frame",
)
# Generation Params
parser.add_argument("--height", type=int, default=720, help="Video height")
parser.add_argument("--width", type=int, default=1280, help="Video width")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
parser.add_argument(
"--steps",
type=int,
default=None,
help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=None,
help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
)
parser.add_argument(
"--guidance_scale_2",
type=float,
default=None,
help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
)
parser.add_argument(
"--boundary_ratio",
type=float,
default=None,
help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# TeaCache Arguments
parser.add_argument(
"--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
)
parser.add_argument(
"--teacache_thresh",
type=float,
default=0.2,
help="TeaCache similarity threshold (rel_l1_thresh)",
)
# Quantization
parser.add_argument(
"--linear_type",
type=str,
default="default",
choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
help="Linear layer quantization type",
)
# Attention Backend
parser.add_argument(
"--attention_backend",
type=str,
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
"Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
)
# Parallelism
parser.add_argument(
"--cfg_size",
type=int,
default=1,
choices=[1, 2],
help="CFG parallel size (1 or 2). Set to 2 for CFG Parallelism.",
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="Ulysses (sequence) parallel size within each CFG group.",
)
return parser.parse_args()
def main():
args = parse_args()
# world_size = cfg_size * ulysses_size
# Example: cfg_size=2, ulysses_size=4 -> 8 GPUs
# GPU 0-3: CFG group 0 (positive prompt), internal Ulysses parallel
# GPU 4-7: CFG group 1 (negative prompt), internal Ulysses parallel
n_workers = args.cfg_size * args.ulysses_size
# Convert linear_type to quant_config
quant_config = None
if args.linear_type == "trtllm-fp8-per-tensor":
quant_config = {"quant_algo": "FP8", "dynamic": True}
elif args.linear_type == "trtllm-fp8-blockwise":
quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
elif args.linear_type == "svd-nvfp4":
quant_config = {"quant_algo": "NVFP4", "dynamic": True}
# 1. Setup Configuration
diffusion_config = {
"model_type": "wan2",
"attention": {
"backend": args.attention_backend,
},
"teacache": {
"enable_teacache": args.enable_teacache,
"teacache_thresh": args.teacache_thresh,
},
"parallel": {
"dit_cfg_size": args.cfg_size,
"dit_ulysses_size": args.ulysses_size,
},
}
# Add quant_config if specified
if quant_config is not None:
diffusion_config["quant_config"] = quant_config
# 2. Initialize VisualGen
logger.info(
f"Initializing VisualGen: world_size={n_workers} "
f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
)
visual_gen = VisualGen(
model_path=args.model_path,
n_workers=n_workers,
diffusion_config=diffusion_config,
)
try:
# 3. Run Inference
logger.info(f"Generating video for prompt: '{args.prompt}'")
logger.info(f"Negative prompt: '{args.negative_prompt}'")
logger.info(f"Input image: {args.image_path}")
if args.last_image_path:
logger.info(f"Last frame image: {args.last_image_path}")
logger.info(
f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
)
start_time = time.time()
# Build parameters with explicit I2V and Wan 2.2 support
output = visual_gen.generate(
inputs={
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
},
params=VisualGenParams(
height=args.height,
width=args.width,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
num_frames=args.num_frames,
input_reference=args.image_path,
last_image=args.last_image_path if args.last_image_path else None,
guidance_scale_2=args.guidance_scale_2,
boundary_ratio=args.boundary_ratio,
),
)
end_time = time.time()
logger.info(f"Generation completed in {end_time - start_time:.2f}s")
# 4. Save Output
OutputHandler.save(output, args.output_path, frame_rate=16.0)
finally:
# 5. Shutdown
visual_gen.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,228 @@
#!/usr/bin/env python3
"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""
import argparse
import time
from output_handler import OutputHandler
from tensorrt_llm import logger
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
# Set logger level to ensure timing logs are printed
logger.set_level("info")
def parse_args():
parser = argparse.ArgumentParser(
description="TRTLLM VisualGen - Wan Text-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
)
# Model & Input
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA)",
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
parser.add_argument(
"--negative_prompt",
type=str,
default=None,
help="Negative prompt. Default is model-specific.",
)
parser.add_argument(
"--output_path",
type=str,
default="output.png",
help="Path to save the output image/video frame",
)
# Generation Params
parser.add_argument("--height", type=int, default=720, help="Video height")
parser.add_argument("--width", type=int, default=1280, help="Video width")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
parser.add_argument(
"--steps",
type=int,
default=None,
help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=None,
help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
)
parser.add_argument(
"--guidance_scale_2",
type=float,
default=None,
help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
)
parser.add_argument(
"--boundary_ratio",
type=float,
default=None,
help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# TeaCache Arguments
parser.add_argument(
"--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
)
parser.add_argument(
"--teacache_thresh",
type=float,
default=0.2,
help="TeaCache similarity threshold (rel_l1_thresh)",
)
# Quantization
parser.add_argument(
"--linear_type",
type=str,
default="default",
choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
help="Linear layer quantization type",
)
# Attention Backend
parser.add_argument(
"--attention_backend",
type=str,
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
"Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
)
# Parallelism
parser.add_argument(
"--cfg_size",
type=int,
default=1,
choices=[1, 2],
help="CFG parallel size (1 or 2). "
"Distributes positive/negative prompts across GPUs. "
"Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="Ulysses sequence parallel size within each CFG group. "
"Distributes sequence across GPUs for longer sequences. "
"Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
"Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
"2 CFG groups × 2 Ulysses ranks = 4 GPUs total.",
)
return parser.parse_args()
def main():
args = parse_args()
# Total workers: cfg_size × ulysses_size
# See ParallelConfig in config.py for detailed parallelism strategy and examples
n_workers = args.cfg_size * args.ulysses_size
# Log Ulysses configuration (validation happens in setup_sequence_parallelism)
if args.ulysses_size > 1:
num_heads = 12 # WAN has 12 attention heads
logger.info(
f"Using Ulysses sequence parallelism: "
f"{num_heads} heads / {args.ulysses_size} ranks = "
f"{num_heads // args.ulysses_size} heads per GPU"
)
# Convert linear_type to quant_config
quant_config = None
if args.linear_type == "trtllm-fp8-per-tensor":
quant_config = {"quant_algo": "FP8", "dynamic": True}
elif args.linear_type == "trtllm-fp8-blockwise":
quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
elif args.linear_type == "svd-nvfp4":
quant_config = {"quant_algo": "NVFP4", "dynamic": True}
# 1. Setup Configuration
diffusion_config = {
"model_type": "wan2",
"revision": args.revision,
"attention": {
"backend": args.attention_backend,
},
"teacache": {
"enable_teacache": args.enable_teacache,
"teacache_thresh": args.teacache_thresh,
},
"parallel": {
"dit_cfg_size": args.cfg_size,
"dit_ulysses_size": args.ulysses_size,
},
}
# Add quant_config if specified
if quant_config is not None:
diffusion_config["quant_config"] = quant_config
# 2. Initialize VisualGen
logger.info(
f"Initializing VisualGen: world_size={n_workers} "
f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
)
visual_gen = VisualGen(
model_path=args.model_path,
n_workers=n_workers,
diffusion_config=diffusion_config,
)
try:
# 3. Run Inference
logger.info(f"Generating video for prompt: '{args.prompt}'")
logger.info(f"Negative prompt: '{args.negative_prompt}'")
logger.info(
f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
)
start_time = time.time()
output = visual_gen.generate(
inputs={
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
},
params=VisualGenParams(
height=args.height,
width=args.width,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
num_frames=args.num_frames,
guidance_scale_2=args.guidance_scale_2,
boundary_ratio=args.boundary_ratio,
),
)
end_time = time.time()
logger.info(f"Generation completed in {end_time - start_time:.2f}s")
# 4. Save Output
OutputHandler.save(output, args.output_path, frame_rate=16.0)
finally:
# 5. Shutdown
visual_gen.shutdown()
if __name__ == "__main__":
main()

View File

@ -83,3 +83,4 @@ llist
cuda-tile>=1.0.1
nvidia-cuda-tileiras>=13.1
etcd-sdk-python==0.0.7
python-multipart

View File

@ -4,10 +4,11 @@ from .communicator import Distributed, MPIDist, TorchDist
from .moe_alltoall import MoeAlltoAll
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams,
allgather, alltoall_helix, cp_allgather, reducescatter,
userbuffers_allreduce_finalize)
all_to_all_4d, allgather, alltoall_helix, cp_allgather,
reducescatter, userbuffers_allreduce_finalize)
__all__ = [
"all_to_all_4d",
"allgather",
"alltoall_helix",
"cp_allgather",

View File

@ -959,3 +959,126 @@ class MoEAllReduce(nn.Module):
nranks=self.mapping.tp_size,
eps=all_reduce_params.eps,
)
def all_to_all_4d(
input: torch.Tensor,
scatter_dim: int,
gather_dim: int,
process_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.Tensor:
"""
All-to-all for 4D tensors (batch, seq, heads, head_dim).
Redistributes a 4D tensor along two dimensions using all-to-all communication.
This is used for Ulysses-style sequence parallelism to transform between:
- Sequence sharding [B, S/P, H, D] -> Head sharding [B, S, H/P, D]
- Head sharding [B, S, H/P, D] -> Sequence sharding [B, S/P, H, D]
Args:
input: Input tensor with shape [batch, seq, heads, head_dim]
scatter_dim: Dimension to split and scatter (1 for seq, 2 for heads)
gather_dim: Dimension to gather (1 for seq, 2 for heads)
process_group: PyTorch distributed process group. If None, uses default process group.
Returns:
Redistributed tensor with same shape as input
Example:
# Transform from sequence sharding to head sharding
# Input: [B, S/P, H, D] (each rank has S/P sequence)
output = all_to_all_4d(input, scatter_dim=2, gather_dim=1, process_group=pg)
# Output: [B, S, H/P, D] (each rank has H/P heads)
# Transform back from head sharding to sequence sharding
output = all_to_all_4d(input, scatter_dim=1, gather_dim=2, process_group=pg)
"""
# Only support PyTorch distributed mode (not MPI mode)
if not mpi_disabled():
raise NotImplementedError(
"all_to_all_4d currently only supports PyTorch distributed mode. "
"MPI mode is not supported.")
# Get world size from process group
world_size = torch.distributed.get_world_size(group=process_group)
# If world_size is 1, no communication needed
if world_size == 1:
return input
# Validate dimensions
assert scatter_dim in [1, 2], "scatter_dim must be 1 (seq) or 2 (heads)"
assert gather_dim in [1, 2], "gather_dim must be 1 (seq) or 2 (heads)"
assert scatter_dim != gather_dim, "scatter_dim and gather_dim must be different"
batch, seq, heads, head_dim = input.shape
# Validate that the scatter dimension is divisible by world_size
scatter_size = input.shape[scatter_dim]
assert scatter_size % world_size == 0, \
f"Dimension {scatter_dim} size {scatter_size} must be divisible by world_size {world_size}"
# For all-to-all, we need to:
# 1. Split input along scatter_dim into world_size chunks
# 2. Send chunk i to rank i
# 3. Receive chunk from each rank and concatenate along gather_dim
# Reshape for all-to-all: move scatter_dim chunks to a new dimension
if scatter_dim == 1: # Scatter along seq dimension
# [B, S, H, D] -> [B, P, S/P, H, D] where P = world_size
input_reshaped = input.view(batch, world_size, seq // world_size, heads,
head_dim)
# Transpose to group by destination rank: [B, P, S/P, H, D] -> [P, B, S/P, H, D]
input_transposed = input_reshaped.permute(1, 0, 2, 3, 4).contiguous()
else: # scatter_dim == 2, scatter along heads dimension
# [B, S, H, D] -> [B, S, P, H/P, D] where P = world_size
input_reshaped = input.view(batch, seq, world_size, heads // world_size,
head_dim)
# Transpose to group by destination rank: [B, S, P, H/P, D] -> [P, B, S, H/P, D]
input_transposed = input_reshaped.permute(2, 0, 1, 3, 4).contiguous()
# Flatten to [P * ...] for all-to-all communication
# Shape: [P, B, ...] -> [P * B * ...]
input_flat = input_transposed.flatten()
output_flat = torch.empty_like(input_flat)
# Perform all-to-all communication using PyTorch distributed
# all_to_all_single splits input into world_size chunks and exchanges them
torch.distributed.all_to_all_single(output_flat,
input_flat,
group=process_group)
# Reshape output back to [P, B, ...] form
output_transposed = output_flat.view_as(input_transposed)
# Transpose back and reshape to final form
if gather_dim == 1: # Gather along seq dimension
# [P, B, S/P, H, D] -> [B, P, S/P, H, D]
output_reshaped = output_transposed.permute(1, 0, 2, 3, 4).contiguous()
# [B, P, S/P, H, D] -> [B, S, H, D] where S = P * (S/P)
# When scattering heads and gathering seq: seq needs to be multiplied, heads needs to be divided
if scatter_dim == 2:
# Scattered heads, so we have H/P heads and need to gather S/P -> S sequence
gathered_seq = seq * world_size
sharded_heads = heads // world_size
output = output_reshaped.view(batch, gathered_seq, sharded_heads,
head_dim)
else:
# Scattered seq (should be impossible if gather_dim == 1), keep as is
output = output_reshaped.view(batch, seq, heads, head_dim)
else: # gather_dim == 2, gather along heads dimension
# [P, B, S, H/P, D] -> [B, S, P, H/P, D]
output_reshaped = output_transposed.permute(1, 2, 0, 3, 4).contiguous()
# [B, S, P, H/P, D] -> [B, S, H, D] where H = P * (H/P)
# When scattering seq and gathering heads: heads needs to be multiplied, seq needs to be divided
if scatter_dim == 1:
# Scattered seq, so we have S/P seq and need to gather H/P -> H heads
gathered_heads = heads * world_size
sharded_seq = seq // world_size
output = output_reshaped.view(batch, sharded_seq, gathered_heads,
head_dim)
else:
# Scattered heads (should be impossible if gather_dim == 2), keep as is
output = output_reshaped.view(batch, seq, heads, head_dim)
return output
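
A minimal usage sketch of the Ulysses round trip (not part of the PR): it assumes a single-node launch via `torchrun --nproc_per_node=2`, CUDA with NCCL, and TRT-LLM running in its PyTorch-distributed (non-MPI) mode, which `all_to_all_4d` requires. Sizes are illustrative.

```python
# Sketch only: shapes follow the documented contract of all_to_all_4d.
import torch
import torch.distributed as dist

from tensorrt_llm._torch.distributed import all_to_all_4d

dist.init_process_group(backend="nccl")
world = dist.get_world_size()
torch.cuda.set_device(dist.get_rank())

B, S, H, D = 1, 8, 4, 16                                     # S and H divisible by world
q_local = torch.randn(B, S // world, H, D, device="cuda")    # sequence-sharded shard

# Sequence sharding -> head sharding: each rank now holds the full sequence, H/world heads
q_heads = all_to_all_4d(q_local, scatter_dim=2, gather_dim=1)
assert q_heads.shape == (B, S, H // world, D)

# ... attention would run here on the head-sharded tensor ...

# Head sharding -> sequence sharding: back to the original per-rank layout
q_back = all_to_all_4d(q_heads, scatter_dim=1, gather_dim=2)
assert q_back.shape == (B, S // world, H, D)

dist.destroy_process_group()
```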

View File

@ -556,6 +556,13 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
# GEMM ops require 2D matrices
original_shape = input.shape
if input.dim() > 2:
input = input.reshape(-1, input.shape[-1])
cur_input_scale = module.input_scale
if input.dtype != torch.float8_e4m3fn:
if module.input_scale is not None and not module.force_dynamic_quantization:
@ -591,6 +598,11 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
bias=None,
out_dtype=module.dtype or input.dtype,
)
# Reshape output back to original shape (with out_features as last dim)
if len(original_shape) > 2:
output = output.reshape(*original_shape[:-1], output.shape[-1])
if bias is not None:
output = output + bias
return output
@ -975,6 +987,12 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
# GEMM ops require 2D matrices
original_shape = input.shape
if input.dim() > 2:
input = input.reshape(-1, input.shape[-1])
if input.dtype == torch.float8_e4m3fn:
input = input.to(torch.bfloat16) * module.input_scale
assert input.dtype == torch.bfloat16
@ -1003,6 +1021,10 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
# Reshape output back to original shape (with out_features as last dim)
if len(original_shape) > 2:
output = output.reshape(*original_shape[:-1], output.shape[-1])
if bias is not None:
output = output + bias
return output
@ -1212,6 +1234,15 @@ class NVFP4LinearMethod(LinearMethodBase):
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden).
# GEMM requires 2D. Only plain tensors are supported for now; skip for
# tuple and Fp4QuantizedTensor.
original_shape = None
if not isinstance(input,
(tuple, Fp4QuantizedTensor)) and input.dim() > 2:
original_shape = input.shape
input = input.reshape(-1, input.shape[-1])
act_fp4, act_sf = self._input_prepare(module, input)
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
# Convert list to comma-separated string for torch.compile compatibility
@ -1229,6 +1260,9 @@ class NVFP4LinearMethod(LinearMethodBase):
if output.shape[-1] > module.out_features:
output = output[..., :module.out_features].contiguous()
if original_shape is not None:
output = output.reshape(*original_shape[:-1], output.shape[-1])
if bias is not None:
output = output + bias
return output
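
The three `apply()` changes above share one pattern: flatten any leading batch/sequence dimensions to 2D before the quantized GEMM, then restore them afterwards. A standalone sketch of that round trip, with a plain matmul standing in for the quantized kernels (names and shapes are illustrative, not TRT-LLM APIs):

```python
import torch

def linear_2d_only(x2d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Stand-in for a GEMM kernel that, like the quantized ops, only accepts 2D input."""
    assert x2d.dim() == 2
    return x2d @ weight.t()

def apply_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    original_shape = x.shape
    if x.dim() > 2:                       # e.g. [batch, seq, hidden]
        x = x.reshape(-1, x.shape[-1])    # -> [batch*seq, hidden]
    y = linear_2d_only(x, weight)         # -> [batch*seq, out_features]
    if len(original_shape) > 2:           # restore leading dims, keep out_features last
        y = y.reshape(*original_shape[:-1], y.shape[-1])
    return y

x = torch.randn(2, 128, 64)               # [batch, seq, hidden]
w = torch.randn(256, 64)                  # [out_features, hidden]
assert apply_linear(x, w).shape == (2, 128, 256)
```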

View File

@ -0,0 +1,45 @@
"""Visual generation module for diffusion models."""
from tensorrt_llm._torch.visual_gen.executor import (
DiffusionExecutor,
DiffusionRequest,
DiffusionResponse,
)
from tensorrt_llm._torch.visual_gen.output import MediaOutput
# Checkpoint loading
from .checkpoints import WeightLoader
from .config import (
AttentionConfig,
DiffusionArgs,
DiffusionModelConfig,
ParallelConfig,
PipelineComponent,
PipelineConfig,
TeaCacheConfig,
discover_pipeline_components,
)
from .models import AutoPipeline, BasePipeline, WanPipeline
from .pipeline_loader import PipelineLoader
__all__ = [
# Config classes
"DiffusionArgs",
"DiffusionModelConfig",
"ParallelConfig",
"PipelineComponent",
"TeaCacheConfig",
# Checkpoint loading
"WeightLoader",
# Model loading
"PipelineLoader",
# Execution
"DiffusionExecutor",
"DiffusionRequest",
"DiffusionResponse",
"MediaOutput",
# Pipelines
"AutoPipeline",
"BasePipeline",
"WanPipeline",
]

View File

@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visual Generation Attention Backend
This module provides attention backend infrastructure for visual generation (diffusion) models.
It reuses existing TRT-LLM attention backends (TrtllmAttention, VanillaAttention) with
simplified metadata that doesn't require KV caching.
"""
from .interface import AttentionTensorLayout
from .parallel import UlyssesAttention
from .trtllm import TrtllmAttention, TrtllmAttentionMetadata
from .utils import create_attention, get_visual_gen_attention_backend
from .vanilla import VanillaAttention
__all__ = [
"AttentionTensorLayout",
"get_visual_gen_attention_backend",
"create_attention",
"TrtllmAttention",
"TrtllmAttentionMetadata",
"UlyssesAttention",
"VanillaAttention",
]

View File

@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visual Generation Attention Backend Interface
Defines shared types and enums for attention backends.
"""
from enum import Enum
class AttentionTensorLayout(str, Enum):
"""
Tensor layout for attention backend input/output.
Backends declare their preferred layout so the attention module
can reshape tensors optimally before calling the backend.
"""
NHD = "NHD" # [B, S, H, D] - batch, seq, heads, dim
HND = "HND" # [B, H, S, D] - batch, heads, seq, dim

View File

@ -0,0 +1,162 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ulysses Sequence Parallelism Wrapper
Wraps any attention backend with sequence parallelism via all-to-all
communication. Not a standalone backend compose around a real backend
(VANILLA/TRTLLM).
Architecture:
Input: [B, S/P, H, D] (sequence sharded across P processes)
Step 1: All-to-All -> [B, S, H/P, D] (gather sequence, shard heads)
Step 2: Compute attention with wrapped backend (VANILLA or TRTLLM)
Step 3: All-to-All -> [B, S/P, H, D] (restore sequence sharding)
Output: [B, S/P, H, D] (sequence sharded)
"""
from typing import Optional
import torch
import torch.nn as nn
from tensorrt_llm._torch.distributed import all_to_all_4d
from .interface import AttentionTensorLayout
class UlyssesAttention(nn.Module):
"""
Ulysses Sequence Parallelism wrapper.
Wraps any attention backend with sequence parallelism via all-to-all.
Not a standalone backend; it composes around a real backend (VANILLA/TRTLLM).
"""
def __init__(
self,
inner_backend: nn.Module,
process_group: Optional[torch.distributed.ProcessGroup] = None,
):
super().__init__()
self.inner_backend = inner_backend
self.process_group = process_group
self._preferred_layout = AttentionTensorLayout.NHD
# Derive head info from inner backend
self.head_dim = inner_backend.head_dim
self.sharded_num_heads = inner_backend.num_heads
self.sharded_num_kv_heads = getattr(inner_backend, "num_kv_heads", self.sharded_num_heads)
# Get world size from process group
try:
self.world_size = torch.distributed.get_world_size(group=process_group)
except (RuntimeError, ValueError):
self.world_size = 1
# Full (unsharded) head counts for external interface
self.num_heads = self.sharded_num_heads * self.world_size
self.num_kv_heads = self.sharded_num_kv_heads * self.world_size
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Forward pass with Ulysses sequence parallelism.
Input/Output: [B, S/P, H, D] (sequence sharded)
Args:
q: Query tensor [B, S/P, H, D]
k: Key tensor [B, S/P, H, D]
v: Value tensor [B, S/P, H, D]
batch_size: Batch size
attention_mask: Optional attention mask
Returns:
Output tensor [B, S/P, H, D] (sequence sharded)
Note:
seq_len is computed from tensor shape after all-to-all, not passed as parameter.
"""
# Step 1: All-to-All to gather full sequence, shard heads
# [B, S/P, H, D] -> [B, S, H/P, D]
if self.world_size > 1:
q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group)
k = all_to_all_4d(k, scatter_dim=2, gather_dim=1, process_group=self.process_group)
v = all_to_all_4d(v, scatter_dim=2, gather_dim=1, process_group=self.process_group)
seq_len_full = q.shape[1]
inner_layout = self.inner_backend.preferred_layout
# Step 2: Call wrapped backend for attention
# Transpose only if inner backend expects HND layout
if inner_layout == AttentionTensorLayout.HND:
# VANILLA expects [B, H/P, S, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# NHD backends (TRTLLM) keep [B, S, H/P, D] as-is
inner_kwargs = dict(
q=q,
k=k,
v=v,
batch_size=batch_size,
seq_len=seq_len_full,
)
if attention_mask is not None:
inner_kwargs["attention_mask"] = attention_mask
output = self.inner_backend.forward(**inner_kwargs)
# Convert output back to [B, S, H/P, D] for the reverse all-to-all
if inner_layout == AttentionTensorLayout.HND:
# VANILLA returns [B, H/P, S, D] -> transpose to [B, S, H/P, D]
output = output.transpose(1, 2).contiguous()
else:
# TRTLLM returns [B, S, (H/P)*D] (3D) -> reshape to [B, S, H/P, D]
if output.dim() == 3:
output = output.view(
batch_size, seq_len_full, self.sharded_num_heads, self.head_dim
)
output = output.contiguous()
# Step 3: All-to-All to restore sequence sharding
# [B, S, H/P, D] -> [B, S/P, H, D]
if self.world_size > 1:
output = all_to_all_4d(
output,
scatter_dim=1,
gather_dim=2,
process_group=self.process_group,
)
return output
@property
def preferred_layout(self) -> AttentionTensorLayout:
"""Preferred tensor layout: [B, S, H, D]"""
return self._preferred_layout
@classmethod
def support_fused_qkv(cls) -> bool:
"""This backend does not support fused QKV."""
return False
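
A minimal single-process sketch (not from this PR) of the composition: wrap `VanillaAttention` in `UlyssesAttention`. With no initialized process group the wrapper falls back to world_size 1, so the all-to-alls are skipped and only the layout handling is exercised. It assumes this branch is importable and that the package lives at `tensorrt_llm._torch.visual_gen.attention_backend`, as the relative imports suggest; sizes are illustrative.

```python
import torch

from tensorrt_llm._torch.visual_gen.attention_backend import (UlyssesAttention,
                                                               VanillaAttention)

B, S, H, D = 1, 64, 12, 128
inner = VanillaAttention(num_heads=H, head_dim=D)            # per-rank (sharded) head count
attn = UlyssesAttention(inner_backend=inner, process_group=None)

q = torch.randn(B, S, H, D)   # [B, S/P, H, D] with P == 1
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)
out = attn(q, k, v, batch_size=B)
assert out.shape == (B, S, H, D)   # sequence-sharded layout is preserved
```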

View File

@ -0,0 +1,244 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Diffusion TRTLLM Attention Backend
Wraps TrtllmAttention with simplified metadata for visual generation (diffusion) models.
Handles the specifics of no-KV-cache operation and fused QKV requirements.
"""
from typing import Optional, Union
import torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from ...attention_backend.interface import AttentionRuntimeFeatures, PredefinedAttentionMask
from ...attention_backend.trtllm import TrtllmAttention as BaseTrtllmAttention
from ...attention_backend.trtllm import TrtllmAttentionMetadata as BaseTrtllmAttentionMetadata
from .interface import AttentionTensorLayout
class TrtllmAttentionMetadata:
"""
Simplified metadata adapter for diffusion models using TRTLLM backend.
Lazy initialization with auto-growing capacity:
- Metadata created only when capacity needs increase
- prepare() called only when seq_lens actually change
- Automatically reallocates when batch_size or seq_len exceeds current capacity
Args:
max_batch_size: Initial batch size hint. Will grow automatically if exceeded.
max_seq_len: Initial sequence length hint. Will grow automatically if exceeded.
device: Target device for tensors.
"""
def __init__(
self,
max_batch_size: int = 16,
max_seq_len: int = 4096,
device: Optional[torch.device] = None,
):
# These are initial hints, not hard limits - capacity grows as needed
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.device = device or torch.device("cuda")
# Lazily created BaseTrtllmAttentionMetadata
self._metadata: Optional[BaseTrtllmAttentionMetadata] = None
# Track allocated capacity
self._allocated_batch_size = 0
self._allocated_max_seq_len = 0
# Track prepared state
self._cached_seq_lens: Optional[torch.Tensor] = None
self._prepared = False
def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool:
"""Check if we need to create new metadata (capacity change)."""
return (
self._metadata is None
or batch_size > self._allocated_batch_size
or max_seq_len > self._allocated_max_seq_len
)
def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
"""Check if we need to call prepare() (seq_lens changed)."""
if not self._prepared:
return True
if self._cached_seq_lens is None:
return True
if self._cached_seq_lens.shape[0] != batch_size:
return True
return not torch.equal(self._cached_seq_lens[:batch_size], seq_lens)
def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
"""Create new metadata with given capacity."""
# Allocate with some headroom to avoid frequent reallocation
alloc_batch = max(batch_size, self._allocated_batch_size)
alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len)
self._metadata = BaseTrtllmAttentionMetadata(
max_num_requests=alloc_batch,
max_num_tokens=alloc_batch * alloc_seq_len,
max_num_sequences=alloc_batch,
kv_cache_manager=None, # No KV cache for diffusion
mapping=Mapping(),
runtime_features=AttentionRuntimeFeatures(),
)
self._allocated_batch_size = alloc_batch
self._allocated_max_seq_len = alloc_seq_len
self._prepared = False # Reset prepare state on new metadata
def prepare(
self,
batch_size: int,
seq_lens: Union[int, torch.Tensor],
) -> BaseTrtllmAttentionMetadata:
"""
Prepare metadata for a forward pass.
Lazy behavior:
- Creates metadata only when capacity needs increase
- Calls prepare() only when seq_lens actually change
"""
if isinstance(seq_lens, int):
seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32)
else:
seq_lens_tensor = seq_lens.to(dtype=torch.int32)
max_seq_len = seq_lens_tensor.max().item()
if self._needs_new_metadata(batch_size, max_seq_len):
self._create_metadata(batch_size, max_seq_len)
if self._needs_prepare(batch_size, seq_lens_tensor):
self._metadata.seq_lens = seq_lens_tensor
self._metadata.num_contexts = batch_size
self._metadata.max_seq_len = max_seq_len
self._metadata.request_ids = list(range(batch_size))
self._metadata.prepare()
# Cache for next comparison
if self._cached_seq_lens is None or self._cached_seq_lens.shape[0] < batch_size:
self._cached_seq_lens = seq_lens_tensor.clone()
else:
self._cached_seq_lens[:batch_size].copy_(seq_lens_tensor)
self._prepared = True
return self._metadata
class TrtllmAttention(BaseTrtllmAttention):
"""
TRTLLM Attention wrapper for diffusion models.
Handles:
- Fused QKV requirement for TRTLLM kernel
- Metadata creation and preparation
- No KV cache operation
"""
def __init__(
self,
layer_idx: int = 0,
num_heads: int = 8,
head_dim: int = 64,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
dtype: Optional[torch.dtype] = None,
max_batch_size: int = 16,
max_seq_len: int = 4096,
):
num_kv_heads = num_kv_heads or num_heads
super().__init__(
layer_idx=layer_idx,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
quant_config=quant_config,
dtype=dtype,
)
# TRTLLM expects flat [B*S, H*D] format
self._preferred_layout = AttentionTensorLayout.NHD
self.metadata = TrtllmAttentionMetadata(
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
seq_len: int,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
seq_len_kv: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
"""
Forward pass with automatic metadata handling.
For diffusion models, expects:
- Fused QKV: q contains [Q, K, V] concatenated, k and v are None
- OR separate Q, K, V which will be fused internally
Args:
q: Query tensor [num_tokens, hidden] or fused QKV [num_tokens, qkv_hidden]
k: Key tensor [num_tokens, kv_hidden] or None if fused
v: Value tensor [num_tokens, kv_hidden] or None if fused
batch_size: Batch size (required if not inferable)
seq_len: Sequence length for Q (required if not inferable)
attention_mask: Attention mask type
seq_len_kv: Sequence length for K/V (for cross-attention, defaults to seq_len)
Returns:
Output tensor [num_tokens, q_hidden]
"""
# Handle cross-attention where K/V have different sequence length than Q
kv_seq_len = seq_len_kv if seq_len_kv is not None else seq_len
# Separate Q, K, V provided - fuse them
q = q.view(batch_size * seq_len, -1)
k = k.view(batch_size * kv_seq_len, -1)
v = v.view(batch_size * kv_seq_len, -1)
qkv = torch.cat([q, k, v], dim=-1)
prepared_metadata = self.metadata.prepare(batch_size, seq_len)
output = super().forward(
q=qkv,
k=None,
v=None,
metadata=prepared_metadata,
attention_mask=attention_mask,
)
output = output.view(batch_size, seq_len, -1)
return output
@property
def preferred_layout(self) -> AttentionTensorLayout:
"""Return the preferred tensor layout for this backend."""
return self._preferred_layout
@classmethod
def support_fused_qkv(cls) -> bool:
return True

View File

@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visual Generation Attention Backend Utilities
Factory functions for creating attention backends for visual generation models.
Uses diffusion-specific wrappers (TrtllmAttention, VanillaAttention)
that handle metadata preparation internally for simplified usage.
"""
from typing import TYPE_CHECKING, Optional, Type, Union
import torch
from tensorrt_llm.models.modeling_utils import QuantConfig
# Lazy imports to avoid circular dependency
if TYPE_CHECKING:
from .trtllm import TrtllmAttention
from .vanilla import VanillaAttention
# Type alias for diffusion attention backends
DiffusionAttentionBackend = Union["TrtllmAttention", "VanillaAttention"]
def get_visual_gen_attention_backend(
backend_name: str,
) -> Type["DiffusionAttentionBackend"]:
"""
Get diffusion attention backend class by name.
Args:
backend_name: Backend identifier ("VANILLA", "TRTLLM")
Returns:
Diffusion attention backend class
Backend Selection Guide:
- "VANILLA": Full support for cross-attention (different Q/KV seq lengths)
Uses torch SDPA backend
- "TRTLLM": Optimized for self-attention (requires same Q/KV seq lengths)
Better performance but requires fused QKV
"""
# Lazy imports to avoid circular dependency
from .trtllm import TrtllmAttention
from .vanilla import VanillaAttention
backend_name = backend_name.upper()
if backend_name == "VANILLA":
return VanillaAttention
elif backend_name == "TRTLLM":
return TrtllmAttention
else:
# Default to VANILLA for maximum compatibility
return VanillaAttention
def create_attention(
backend: str,
layer_idx: int,
num_heads: int,
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
dtype: Optional[torch.dtype] = None,
max_batch_size: int = 16,
max_seq_len: int = 4096,
**kwargs,
) -> "DiffusionAttentionBackend":
"""
Factory function to create attention backend instance for visual generation.
Creates diffusion-specific attention backends that handle metadata preparation
internally, simplifying the forward() call.
Args:
backend: Backend identifier ("VANILLA", "TRTLLM")
layer_idx: Layer index in the model
num_heads: Number of attention heads
head_dim: Dimension per head
num_kv_heads: Number of KV heads (for GQA/MQA, defaults to num_heads)
quant_config: Optional quantization configuration
dtype: Data type for the attention
max_batch_size: Initial batch size for metadata pre-allocation. The backend
will automatically reallocate if larger batches are encountered.
max_seq_len: Initial sequence length for metadata pre-allocation. The backend
will automatically reallocate if longer sequences are encountered.
**kwargs: Additional backend-specific arguments
Returns:
Diffusion attention backend instance (TrtllmAttention or VanillaAttention)
"""
attn_cls = get_visual_gen_attention_backend(backend)
return attn_cls(
layer_idx=layer_idx,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
quant_config=quant_config,
dtype=dtype,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
**kwargs,
)
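
A hedged sketch of how the factory is typically called (illustrative values; same import-path assumption as above):

```python
from tensorrt_llm._torch.visual_gen.attention_backend import create_attention

attn = create_attention(
    backend="VANILLA",   # unknown names also fall back to VanillaAttention
    layer_idx=0,
    num_heads=12,
    head_dim=128,
)
print(type(attn).__name__, attn.preferred_layout)   # VanillaAttention, HND layout
```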

View File

@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Diffusion Vanilla Attention Backend
Simple attention implementation for visual generation (diffusion) models using
torch.nn.functional.scaled_dot_product_attention (SDPA).
Supports both self-attention and cross-attention (different Q/KV sequence lengths).
No KV cache - full recompute each diffusion step.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...attention_backend.interface import PredefinedAttentionMask
from .interface import AttentionTensorLayout
class VanillaAttention(nn.Module):
"""
Vanilla Attention for diffusion models using torch SDPA.
Uses torch.nn.functional.scaled_dot_product_attention which:
- Properly handles cross-attention (different Q/KV sequence lengths)
- Uses Flash Attention 2 when available (via SDPA backend selection)
- No KV cache needed for diffusion models
This is simpler than the LLM VanillaAttention which has complex
KV cache handling and uses flash_attn_varlen_func.
"""
def __init__(
self,
layer_idx: int = 0,
num_heads: int = 8,
head_dim: int = 64,
num_kv_heads: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
**kwargs,
):
super().__init__()
self.layer_idx = layer_idx
self.num_heads = num_heads
self.head_dim = head_dim
self.num_kv_heads = num_kv_heads or num_heads
self.dtype = dtype
self.scale = 1.0 / math.sqrt(head_dim)
# SDPA expects [B, H, S, D] format
self._preferred_layout = AttentionTensorLayout.HND
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
seq_len: int,
seq_len_kv: Optional[int] = None,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
**kwargs,
) -> torch.Tensor:
"""
Forward pass using torch SDPA.
Args:
q: Query tensor [num_tokens, num_heads * head_dim]
k: Key tensor [num_kv_tokens, num_kv_heads * head_dim]
v: Value tensor [num_kv_tokens, num_kv_heads * head_dim]
batch_size: Batch size
seq_len: Query sequence length
seq_len_kv: KV sequence length (for cross-attention)
attention_mask: Attention mask type (CAUSAL or FULL)
Returns:
Output tensor [num_tokens, num_heads * head_dim]
"""
is_causal = attention_mask == PredefinedAttentionMask.CAUSAL
# Validate tensor shapes - flexible for Ulysses head sharding
# Expected: [batch_size, num_heads, seq_len, head_dim]
# Note: num_heads may be sharded (num_heads // ulysses_size) when using Ulysses
assert (
q.dim() == 4
and q.shape[0] == batch_size
and q.shape[2] == seq_len
and q.shape[3] == self.head_dim
), (
f"Invalid q shape: expected [B={batch_size}, H, S={seq_len}, D={self.head_dim}], got {q.shape}"
)
assert k.dim() == 4 and k.shape[0] == batch_size and k.shape[3] == self.head_dim, (
f"Invalid k shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {k.shape}"
)
assert v.dim() == 4 and v.shape[0] == batch_size and v.shape[3] == self.head_dim, (
f"Invalid v shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {v.shape}"
)
# TODO: Maybe we need to enforce cuDNN backend here
return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=self.scale)
@property
def preferred_layout(self) -> AttentionTensorLayout:
"""Return the preferred tensor layout for this backend."""
return self._preferred_layout
@classmethod
def support_fused_qkv(cls) -> bool:
return False
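
The cross-attention claim comes down to torch SDPA accepting different Q and K/V sequence lengths; the output follows the query length. A plain-PyTorch sketch with HND tensors (illustrative shapes, e.g. text-encoder tokens as K/V):

```python
import torch
import torch.nn.functional as F

B, H, D = 2, 8, 64
q = torch.randn(B, H, 1024, D)   # HND layout: [B, H, S_q, D]
k = torch.randn(B, H, 77, D)     # K/V with a different (shorter) sequence length
v = torch.randn(B, H, 77, D)

out = F.scaled_dot_product_attention(q, k, v)   # -> [B, H, S_q, D]
assert out.shape == (B, H, 1024, D)
```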

View File

@ -0,0 +1,7 @@
"""Diffusion model checkpoint loading utilities."""
from .weight_loader import WeightLoader
__all__ = [
"WeightLoader",
]

View File

@ -0,0 +1,152 @@
"""Weight loader for diffusion models."""
import json
from pathlib import Path
from typing import Any, Dict, List, Union
import torch
import tqdm
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
class WeightLoader(BaseWeightLoader):
"""
Weight loader for diffusion models.
Loads weights from safetensors/bin files, similar to HfWeightLoader
but simpler (no parallel loading optimization for now).
Supports loading multiple components (e.g., transformer and transformer_2):
loader = WeightLoader(components=["transformer", "transformer_2"])
weights = loader.load_weights(ckpt_dir, mapping)
# Returns: {"transformer": {...}, "transformer_2": {...}}
"""
def __init__(self, components: Union[str, List[str]] = PipelineComponent.TRANSFORMER):
"""
Args:
components: Component(s) to load weights for. Can be:
- Single string: "transformer" (returns flat dict)
- List of strings: ["transformer", "transformer_2"] (returns nested dict)
"""
if isinstance(components, str):
self.components = [components]
self.single_component = True
else:
self.components = components
self.single_component = False
def load_weights(
self,
checkpoint_dir: str,
mapping: Mapping,
**kwargs,
) -> Dict[str, Any]:
"""
Load weights from checkpoint directory.
Args:
checkpoint_dir: Path to checkpoint (pipeline root or component dir)
mapping: Distributed mapping (for future TP/PP support)
Returns:
- If single component: Dict mapping weight names to tensors
- If multiple components: Dict mapping component names to weight dicts
Example: {"transformer": {...}, "transformer_2": {...}}
"""
checkpoint_path = Path(checkpoint_dir)
# Check if this is a pipeline (has model_index.json)
model_index = checkpoint_path / "model_index.json"
is_pipeline = model_index.exists()
# Load weights for each component
all_weights = {}
for component in self.components:
if is_pipeline:
# Pipeline format: load from component subdirectory
component_dir = checkpoint_path / component
if not component_dir.exists():
raise ValueError(f"Component '{component}' not found in {checkpoint_dir}")
weight_dir = component_dir
else:
# Standalone model (only valid for single component)
if len(self.components) > 1:
raise ValueError(
f"Multiple components specified but {checkpoint_dir} is not a pipeline "
"(no model_index.json found)"
)
weight_dir = checkpoint_path
# Find weight files
weight_files = self._find_weight_files(weight_dir)
if not weight_files:
raise ValueError(f"No weight files found in {weight_dir}")
# Load all weights with progress bar
component_weights = {}
desc = f"Loading {component}" if is_pipeline else "Loading checkpoint"
for wf in tqdm.tqdm(weight_files, desc=desc):
component_weights.update(self._load_file(wf))
all_weights[component] = component_weights
# Return flat dict for single component (backward compatibility)
if self.single_component:
return all_weights[self.components[0]]
# Return nested dict for multiple components
return all_weights
def _find_weight_files(self, weight_dir) -> List[str]:
"""Find safetensors or bin weight files.
Handles:
- Single safetensors file
- Sharded safetensors with index.json
- PyTorch bin/pth files
"""
weight_dir = Path(weight_dir)
# Check for sharded safetensors index
index_file = weight_dir / "diffusion_pytorch_model.safetensors.index.json"
if not index_file.exists():
index_file = weight_dir / "model.safetensors.index.json"
if index_file.exists():
# Sharded safetensors: read index to get all shard files
with open(index_file) as f:
index = json.load(f)
shard_files = set(index.get("weight_map", {}).values())
return sorted([str(weight_dir / f) for f in shard_files])
# Single safetensors file
files = list(weight_dir.glob("*.safetensors"))
if files:
# Filter out consolidated if multiple files exist
if len(files) > 1:
files = [f for f in files if "consolidated" not in f.name]
return sorted([str(f) for f in files])
# Fallback to bin
files = list(weight_dir.glob("*.bin"))
if files:
return sorted([str(f) for f in files])
# Fallback to pth
files = list(weight_dir.glob("*.pth"))
return sorted([str(f) for f in files])
def _load_file(self, filepath: str) -> Dict[str, Any]:
"""Load weights from a single file."""
logger.debug(f"Loading {filepath}")
if filepath.endswith(".safetensors"):
from safetensors.torch import load_file
return load_file(filepath)
else:
return torch.load(filepath, map_location="cpu", weights_only=True)

View File

@ -0,0 +1,565 @@
import json
import os
from enum import Enum
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Literal, Optional, Tuple
import torch
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import Field as PydanticField
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
# =============================================================================
# Pipeline component identifiers
# =============================================================================
class PipelineComponent(str, Enum):
"""Identifiers for pipeline components that can be loaded or skipped.
Inherits from str so values compare equal to plain strings,
e.g. ``PipelineComponent.VAE == "vae"`` is ``True``.
"""
TRANSFORMER = "transformer"
VAE = "vae"
TEXT_ENCODER = "text_encoder"
TOKENIZER = "tokenizer"
SCHEDULER = "scheduler"
IMAGE_ENCODER = "image_encoder"
IMAGE_PROCESSOR = "image_processor"
# =============================================================================
# Sub-configuration classes for DiffusionArgs
# =============================================================================
class AttentionConfig(BaseModel):
"""Configuration for Attention layers."""
backend: Literal["VANILLA", "TRTLLM"] = PydanticField(
"VANILLA", description="Attention backend: VANILLA (PyTorch SDPA), TRTLLM"
)
class ParallelConfig(BaseModel):
"""Configuration for distributed parallelism.
Currently Supported:
- dit_cfg_size: CFG (Classifier-Free Guidance) parallelism
- dit_ulysses_size: Ulysses sequence parallelism
Not Yet Supported:
- dit_tp_size: Tensor parallelism (not implemented)
- dit_ring_size: Ring attention (not implemented)
- dit_cp_size, dit_dp_size, dit_fsdp_size: Other parallelism types
Total world_size = dit_cfg_size × dit_ulysses_size
Parallelism Strategy:
- CFG Parallelism: Distributes positive/negative prompts across GPUs
- Ulysses Parallelism: Distributes sequence within each CFG group
Example Configurations:
1. cfg_size=1, ulysses_size=2 -> 2 GPUs (Ulysses only)
GPU 0-1: Single prompt, sequence parallelism across 2 GPUs
2. cfg_size=2, ulysses_size=1 -> 2 GPUs (CFG only)
GPU 0: Positive prompt
GPU 1: Negative prompt
3. cfg_size=2, ulysses_size=2 -> 4 GPUs (CFG + Ulysses)
GPU 0-1: CFG group 0 (positive), Ulysses parallel
GPU 2-3: CFG group 1 (negative), Ulysses parallel
4. cfg_size=2, ulysses_size=4 -> 8 GPUs (CFG + Ulysses)
GPU 0-3: CFG group 0 (positive), Ulysses parallel
GPU 4-7: CFG group 1 (negative), Ulysses parallel
"""
disable_parallel_vae: bool = False
parallel_vae_split_dim: Literal["width", "height"] = "width"
# DiT Parallelism
dit_dp_size: int = PydanticField(1, ge=1)
dit_tp_size: int = PydanticField(1, ge=1) # Not yet supported
dit_ulysses_size: int = PydanticField(1, ge=1) # Supported
dit_ring_size: int = PydanticField(1, ge=1) # Not yet supported
dit_cp_size: int = PydanticField(1, ge=1)
dit_cfg_size: int = PydanticField(1, ge=1) # Supported
dit_fsdp_size: int = PydanticField(1, ge=1)
# Refiner Parallelism (Optional)
refiner_dit_dp_size: int = 1
refiner_dit_tp_size: int = 1
refiner_dit_ulysses_size: int = 1
refiner_dit_ring_size: int = 1
refiner_dit_cp_size: int = 1
refiner_dit_cfg_size: int = 1
refiner_dit_fsdp_size: int = 1
t5_fsdp_size: int = 1
def to_mapping(self) -> Mapping:
"""Convert to TRT-LLM Mapping."""
world_size = self.dit_tp_size * self.dit_cp_size
return Mapping(
world_size=world_size,
tp_size=self.dit_tp_size,
pp_size=1,
cp_size=self.dit_cp_size,
)
@model_validator(mode="after")
def validate_parallel_sizes(self) -> "ParallelConfig":
"""Validate configuration against current environment."""
if torch.cuda.is_available():
world_size = int(os.environ.get("WORLD_SIZE", 1))
total_parallel = (
self.dit_tp_size
* self.dit_ulysses_size
* self.dit_ring_size
* self.dit_cp_size
* self.dit_dp_size
* self.dit_cfg_size
)
if total_parallel > world_size:
raise ValueError(
f"Total DiT parallel size ({total_parallel}) exceeds WORLD_SIZE ({world_size})"
)
return self
class TeaCacheConfig(BaseModel):
"""Configuration for TeaCache runtime optimization.
TeaCache speeds up diffusion by caching transformer outputs when timestep
embeddings change slowly. It monitors embedding distances and reuses cached
residuals when changes are below a threshold.
Attributes:
enable_teacache: Enable TeaCache optimization
teacache_thresh: Distance threshold for cache decisions (lower = more caching)
use_ret_steps: Use aggressive warmup mode (5 steps) vs minimal (1 step)
coefficients: Polynomial coefficients for rescaling embedding distances
Applied as: rescaled_distance = poly(raw_distance)
ret_steps: Number of warmup steps (always compute, initialized at runtime)
cutoff_steps: Step to stop caching (always compute after, initialized at runtime)
num_steps: Total inference steps (set at runtime)
_cnt: Internal step counter (reset per generation)
"""
enable_teacache: bool = False
teacache_thresh: float = PydanticField(0.2, gt=0.0)
use_ret_steps: bool = True
coefficients: List[float] = PydanticField(default_factory=lambda: [1.0, 0.0])
# Runtime state fields (initialized by TeaCacheBackend.refresh)
ret_steps: Optional[int] = None
cutoff_steps: Optional[int] = None
num_steps: Optional[int] = None
# State tracking (reset per generation)
_cnt: int = 0
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def validate_teacache(self) -> "TeaCacheConfig":
"""Validate TeaCache configuration."""
# Validate coefficients
if len(self.coefficients) == 0:
raise ValueError("TeaCache coefficients list cannot be empty")
# Validate ret_steps if set
if self.ret_steps is not None and self.ret_steps < 0:
raise ValueError(f"ret_steps must be non-negative, got {self.ret_steps}")
# Validate cutoff_steps vs num_steps if both set
if self.cutoff_steps is not None and self.num_steps is not None:
if self.cutoff_steps > self.num_steps:
raise ValueError(
f"cutoff_steps ({self.cutoff_steps}) cannot exceed num_steps ({self.num_steps})"
)
return self
class PipelineConfig(BaseModel):
"""General pipeline configuration."""
enable_torch_compile: bool = True
torch_compile_models: str = PipelineComponent.TRANSFORMER
torch_compile_mode: str = "default"
fuse_qkv: bool = True
# Offloading Config
enable_offloading: bool = False
offload_device: Literal["cpu", "cuda"] = "cpu"
offload_param_pin_memory: bool = True
# =============================================================================
# DiffusionArgs - User-facing configuration (CLI / YAML)
# =============================================================================
class DiffusionArgs(BaseModel):
"""User-facing configuration for diffusion model loading and inference.
This is the main config class used in CLI args and YAML config files.
PipelineLoader converts this to DiffusionModelConfig internally.
Example:
args = DiffusionArgs(
checkpoint_path="/path/to/model",
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
parallel=ParallelConfig(dit_tp_size=2),
)
loader = PipelineLoader()
pipeline = loader.load(args)
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# Required: Path to checkpoint or HuggingFace Hub model ID
checkpoint_path: str = PydanticField(
"",
description=(
"Local directory path or HuggingFace Hub model ID "
"(e.g., 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'). "
"Hub models are downloaded and cached automatically."
),
)
# HuggingFace Hub options
revision: Optional[str] = PydanticField(
None,
description="HuggingFace Hub revision (branch, tag, or commit SHA) to download.",
)
# Device/dtype options
device: str = "cuda"
dtype: str = "bfloat16"
# Component loading options (use PipelineComponent enum values or plain strings)
skip_components: List[PipelineComponent] = PydanticField(
default_factory=list,
description=(
"Components to skip loading. "
"Accepts PipelineComponent enum values or equivalent strings "
"(e.g., [PipelineComponent.TEXT_ENCODER, PipelineComponent.VAE])"
),
)
# Sub-configs (dict input for quant_config is coerced to QuantConfig in model_validator)
quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
# Set by model_validator when quant_config is provided as a dict (ModelOpt format)
dynamic_weight_quant: bool = False
force_dynamic_quantization: bool = False
@model_validator(mode="before")
@classmethod
def _parse_quant_config_dict(cls, data: Any) -> Any:
"""Parse user-facing DiffusionArgs.quant_config (dict or None) into QuantConfig and dynamic flags.
User input is ModelOpt-format dict (e.g. {"quant_algo": "FP8", "dynamic": True}).
We coerce it to QuantConfig + dynamic_weight_quant + force_dynamic_quantization so that
from_pretrained() can copy them into DiffusionModelConfig (internal) without parsing again.
"""
if not isinstance(data, dict):
return data
raw = data.get("quant_config")
if raw is None:
data = {**data, "quant_config": QuantConfig()}
return data
if not isinstance(raw, dict):
return data
qc, _, dwq, daq = DiffusionModelConfig.load_diffusion_quant_config(raw)
data = {
**data,
"quant_config": qc,
"dynamic_weight_quant": dwq,
"force_dynamic_quantization": daq,
}
return data
def to_mapping(self) -> Mapping:
"""Derive Mapping from ParallelConfig."""
return self.parallel.to_mapping()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return self.model_dump()
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DiffusionArgs":
"""Create from dictionary with automatic nested config parsing.
Pydantic automatically handles nested configs, but we keep this method
for backward compatibility and to filter unknown fields.
"""
# Get valid field names for DiffusionArgs
valid_fields = set(cls.model_fields.keys())
# Filter to only include valid fields (ignore unknown fields)
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
# Pydantic automatically converts nested dicts to their respective config classes
return cls(**filtered_dict)
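
For reference, a minimal sketch of how the `_parse_quant_config_dict` validator and `from_dict` behave together; the checkpoint path, parallel settings, and the extra field are placeholders chosen for illustration.

```python
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs

raw_cfg = {
    "checkpoint_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "quant_config": {"quant_algo": "FP8", "dynamic": True},  # ModelOpt-format dict
    "parallel": {"dit_tp_size": 2},   # nested dicts become sub-config objects
    "unknown_field": 123,             # silently dropped by from_dict()
}

args = DiffusionArgs.from_dict(raw_cfg)
assert args.quant_config.quant_algo is not None  # coerced to QuantConfig by the validator
assert args.dynamic_weight_quant                 # set from the "dynamic" flag
```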
# =============================================================================
# Utilities
# =============================================================================
def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]:
"""
Discover components from diffusers pipeline's model_index.json.
Returns dict mapping component name to config.json path.
"""
model_index_path = checkpoint_path / "model_index.json"
if not model_index_path.exists():
return {}
with open(model_index_path) as f:
model_index = json.load(f)
components = {}
for key, value in model_index.items():
if key.startswith("_") or value is None:
continue
config_path = checkpoint_path / key / "config.json"
if config_path.exists():
components[key] = config_path
return components
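
A hedged sketch of what discovery yields for a typical diffusers checkpoint layout; it assumes the function is importable from the same config module as `DiffusionArgs`, and the path is a placeholder.

```python
from pathlib import Path

from tensorrt_llm._torch.visual_gen.config import discover_pipeline_components

components = discover_pipeline_components(Path("/models/Wan2.1-T2V-1.3B-Diffusers"))
# Only subfolders that contain a config.json are returned, e.g.:
# {
#     "transformer":  Path(".../transformer/config.json"),
#     "text_encoder": Path(".../text_encoder/config.json"),
#     "vae":          Path(".../vae/config.json"),
# }
```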
# =============================================================================
# DiffusionModelConfig - Internal configuration (merged/parsed)
# =============================================================================
class DiffusionModelConfig(BaseModel):
"""Internal ModelConfig for diffusion models.
This is created by PipelineLoader from DiffusionArgs + checkpoint.
Contains merged/parsed config from:
- pretrained_config: From checkpoint/config.json
- quant_config: From checkpoint or user quant config
- Sub-configs: From DiffusionArgs (pipeline, attention, parallel, teacache)
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
pretrained_config: Optional[Any] = None
mapping: Mapping = PydanticField(default_factory=Mapping)
skip_create_weights_in_init: bool = False
force_dynamic_quantization: bool = False
allreduce_strategy: AllReduceStrategy = PydanticField(default=AllReduceStrategy.AUTO)
extra_attrs: Dict = PydanticField(default_factory=dict)
# Distributed process groups
ulysses_process_group: Optional[torch.distributed.ProcessGroup] = None
dynamic_weight_quant: bool = False
# Sub-configs from DiffusionArgs (merged during from_pretrained)
quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
# Per-layer quant (from load_diffusion_quant_config layer_quant_config; None until mixed-precision parsing exists)
quant_config_dict: Optional[Dict[str, QuantConfig]] = None
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
@property
def torch_dtype(self) -> "torch.dtype":
"""Get the torch dtype of the model (default: bfloat16)."""
return torch.bfloat16
def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
"""Get quantization config for a layer or global. Resembles LLM ModelConfig.get_quant_config."""
if name is None or self.quant_config_dict is None:
return self.quant_config
if name in self.quant_config_dict:
return self.quant_config_dict[name]
return self.quant_config
@staticmethod
def load_diffusion_quant_config(
quant_config_dict: dict,
) -> Tuple[QuantConfig, Optional[Dict], bool, bool]:
"""
Parse quantization config in ModelOpt format.
Returns: (quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant)
- quant_config: Global QuantConfig
- layer_quant_config: Per-layer config dict (None if not using mixed precision)
- dynamic_weight_quant: Whether to quantize weights at load time
- dynamic_activation_quant: Whether to quantize activations dynamically
"""
quant_algo_str = quant_config_dict.get("quant_algo")
quant_algo = None
if quant_algo_str:
algo_map = {
"FP8": QuantAlgo.FP8,
"FP8_BLOCK_SCALES": QuantAlgo.FP8_BLOCK_SCALES,
"NVFP4": QuantAlgo.NVFP4,
"W4A16_AWQ": QuantAlgo.W4A16_AWQ,
"W4A8_AWQ": QuantAlgo.W4A8_AWQ,
"W8A8_SQ_PER_CHANNEL": QuantAlgo.W8A8_SQ_PER_CHANNEL,
}
quant_algo = algo_map.get(quant_algo_str)
if quant_algo is None:
raise ValueError(f"Unknown quant_algo: {quant_algo_str}")
# Parse group_size and dynamic flags from config_groups
group_size = None
dynamic_weight_quant = False
dynamic_activation_quant = False
for group_config in quant_config_dict.get("config_groups", {}).values():
weights_config = group_config.get("weights", {})
activations_config = group_config.get("input_activations", {})
dynamic_weight_quant = weights_config.get("dynamic", False)
dynamic_activation_quant = activations_config.get("dynamic", False)
# Extract group_size from weights config (e.g., NVFP4: group_size=16)
if group_size is None:
group_size = weights_config.get("group_size")
break
# Set defaults based on quant_algo if group_size not specified
if group_size is None:
if quant_algo in (QuantAlgo.NVFP4,):
group_size = 16 # NVFP4 default
elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
group_size = 128 # FP8 blockwise default
# Auto-enable dynamic weight quantization if quant_algo is specified
# but no explicit config_groups setting is present.
# This allows simple configs like {"quant_algo": "FP8"} to work.
if quant_algo is not None and not quant_config_dict.get("config_groups"):
dynamic_weight_quant = quant_config_dict.get("dynamic", True)
quant_config = QuantConfig(
quant_algo=quant_algo,
group_size=group_size,
exclude_modules=quant_config_dict.get("ignore"),
)
# TODO: Per-layer config (None for now - future: parse mixed precision settings)
layer_quant_config = None
return quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant
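
A hedged sketch of the parsing above on a hand-written ModelOpt-style dict; the group and module names are illustrative only, and the import assumes `DiffusionModelConfig` lives in the same config module as `DiffusionArgs`.

```python
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig  # assumed module path

modelopt_cfg = {
    "quant_algo": "NVFP4",
    "config_groups": {
        "group_0": {
            "weights": {"dynamic": True, "group_size": 16},
            "input_activations": {"dynamic": False},
        }
    },
    "ignore": ["proj_out"],  # illustrative exclude list
}

qc, layer_qc, dyn_w, dyn_a = DiffusionModelConfig.load_diffusion_quant_config(modelopt_cfg)
# qc.quant_algo == QuantAlgo.NVFP4, qc.group_size == 16, qc.exclude_modules == ["proj_out"]
# layer_qc is None (per-layer parsing not implemented yet); dyn_w is True, dyn_a is False
```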
@classmethod
def from_pretrained(
cls,
checkpoint_dir: str,
args: Optional["DiffusionArgs"] = None,
**kwargs,
) -> "DiffusionModelConfig":
"""
Load config from pretrained checkpoint.
Called by PipelineLoader with DiffusionArgs:
config = DiffusionModelConfig.from_pretrained(
checkpoint_dir=args.checkpoint_path,
args=args,
)
Args:
checkpoint_dir: Path to checkpoint
args: DiffusionArgs containing user config (quant, pipeline, attention, parallel, teacache)
**kwargs: Additional config options (e.g., mapping)
"""
kwargs.pop("trust_remote_code", None)
# Extract sub-configs from args or use defaults
pipeline_cfg = args.pipeline if args else PipelineConfig()
attention_cfg = args.attention if args else AttentionConfig()
parallel_cfg = args.parallel if args else ParallelConfig()
teacache_cfg = args.teacache if args else TeaCacheConfig()
component = PipelineComponent.TRANSFORMER
checkpoint_path = Path(checkpoint_dir)
# Discover pipeline components
components = discover_pipeline_components(checkpoint_path)
# Determine config path
if components:
if component not in components:
raise ValueError(
f"Component '{component}' not found. Available: {list(components.keys())}"
)
config_path = components[component]
else:
config_path = checkpoint_path / "config.json"
if not config_path.exists():
raise ValueError(f"Config not found at {config_path}")
# Load pretrained_config from checkpoint
with open(config_path) as f:
config_dict = json.load(f)
pretrained_config = SimpleNamespace(**config_dict)
model_index_path = checkpoint_path / "model_index.json"
if model_index_path.exists():
with open(model_index_path) as f:
model_index = json.load(f)
if "boundary_ratio" in model_index and "transformer_2" in model_index:
transformer_2_spec = model_index.get("transformer_2")
if transformer_2_spec and transformer_2_spec[0] is not None:
pretrained_config.boundary_ratio = model_index["boundary_ratio"]
# Resolve quant config: use args if user set quant (QuantConfig from dict), else checkpoint
if args and args.quant_config.quant_algo is not None:
quant_config = args.quant_config
quant_config_dict = (
None # DiffusionArgs has no per-layer dict; only from checkpoint parse
)
dynamic_weight_quant = args.dynamic_weight_quant
dynamic_activation_quant = args.force_dynamic_quantization
else:
quant_config = QuantConfig()
quant_config_dict = None
dynamic_weight_quant = False
dynamic_activation_quant = False
quant_dict = getattr(pretrained_config, "quantization_config", None)
if isinstance(quant_dict, dict):
quant_config, quant_config_dict, dynamic_weight_quant, dynamic_activation_quant = (
cls.load_diffusion_quant_config(quant_dict)
)
return cls(
pretrained_config=pretrained_config,
quant_config=quant_config,
quant_config_dict=quant_config_dict,
dynamic_weight_quant=dynamic_weight_quant,
force_dynamic_quantization=dynamic_activation_quant,
# Sub-configs from DiffusionArgs
pipeline=pipeline_cfg,
attention=attention_cfg,
parallel=parallel_cfg,
teacache=teacache_cfg,
            # Delay weight creation until after apply_quant_config_exclude_modules() in __post_init__
skip_create_weights_in_init=True,
**kwargs,
)
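
Putting the loader-side pieces together, a minimal sketch of constructing the internal config and querying per-layer quantization; the checkpoint path and layer name are placeholders.

```python
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, DiffusionModelConfig

ckpt = "/models/Wan2.1-T2V-1.3B-Diffusers"  # placeholder path
config = DiffusionModelConfig.from_pretrained(
    checkpoint_dir=ckpt,
    args=DiffusionArgs(checkpoint_path=ckpt),
)
# Per-layer lookup falls back to the global quant config while quant_config_dict is None.
qc = config.get_quant_config("blocks.0.attn1.to_q")  # hypothetical layer name
```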

View File

@ -0,0 +1,246 @@
import os
import queue
import threading
import traceback
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.distributed as dist
import zmq
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.logger import logger
@dataclass
class DiffusionRequest:
"""Request for diffusion inference with explicit model-specific parameters."""
request_id: int
prompt: str
negative_prompt: Optional[str] = None
height: int = 720
width: int = 1280
num_inference_steps: int = 50
guidance_scale: float = 5.0
max_sequence_length: int = 512
seed: int = 42
# Video-specific parameters
num_frames: int = 81
frame_rate: float = 24.0
# Image-specific parameters
num_images_per_prompt: int = 1
# Advanced parameters
guidance_rescale: float = 0.0
output_type: str = "pt"
# Wan-specific parameters
image: Optional[Union[str, List[str]]] = None
guidance_scale_2: Optional[float] = None
boundary_ratio: Optional[float] = None
last_image: Optional[Union[str, List[str]]] = None
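
For reference, a minimal T2V request; the values are illustrative and everything not listed falls back to the dataclass defaults above.

```python
req = DiffusionRequest(
    request_id=0,
    prompt="A cute cat playing piano",
    height=480,
    width=832,
    num_frames=33,
    num_inference_steps=30,
)
```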
@dataclass
class DiffusionResponse:
"""Response with model-specific output.
Attributes:
request_id: Unique identifier for the request
output: Generated media as MediaOutput with model-specific fields populated
error_msg: Error message if generation failed
"""
request_id: int
output: Optional[MediaOutput] = None
error_msg: Optional[str] = None
class DiffusionExecutor:
"""Execution engine for diffusion models running in worker processes."""
def __init__(
self,
model_path: str,
request_queue_addr: str,
response_queue_addr: str,
device_id: int,
diffusion_config: Optional[dict] = None,
):
self.model_path = model_path
self.request_queue_addr = request_queue_addr
self.response_queue_addr = response_queue_addr
self.device_id = device_id
self.diffusion_config = diffusion_config
self.requests_ipc = None
self.rank = dist.get_rank()
self.response_queue = queue.Queue()
self.sender_thread = None
# Only rank 0 handles IPC
if self.rank == 0:
logger.info(f"Worker {device_id}: Connecting to request queue")
self.requests_ipc = ZeroMqQueue(
(request_queue_addr, None),
is_server=False,
socket_type=zmq.PULL,
use_hmac_encryption=False,
)
self.sender_thread = threading.Thread(target=self._sender_loop, daemon=True)
self.sender_thread.start()
self._load_pipeline()
def _sender_loop(self):
"""Background thread for sending responses."""
logger.info(f"Worker {self.device_id}: Connecting to response queue")
responses_ipc = ZeroMqQueue(
(self.response_queue_addr, None),
is_server=False,
socket_type=zmq.PUSH,
use_hmac_encryption=False,
)
while True:
try:
resp = self.response_queue.get()
if resp is None:
break
responses_ipc.put(resp)
except Exception as e:
logger.error(f"Worker {self.device_id}: Sender error: {e}")
if responses_ipc.socket:
responses_ipc.socket.setsockopt(zmq.LINGER, 0)
responses_ipc.close()
def _load_pipeline(self):
"""
Load pipeline using proper flow:
        DiffusionArgs -> PipelineLoader -> DiffusionModelConfig -> AutoPipeline -> BasePipeline
"""
logger.info(f"Worker {self.device_id}: Loading pipeline")
try:
# Convert diffusion_config dict to DiffusionArgs
config_dict = self.diffusion_config.copy()
config_dict["checkpoint_path"] = self.model_path
config_dict["device"] = f"cuda:{self.device_id}"
# Create DiffusionArgs from dict (handles nested configs)
args = DiffusionArgs.from_dict(config_dict)
# Use PipelineLoader for proper pipeline creation flow:
# PipelineLoader.load() internally:
# 1. Creates DiffusionModelConfig.from_pretrained()
# 2. Creates pipeline via AutoPipeline.from_config()
# 3. Loads weights with quantization support
# 4. Calls post_load_weights()
loader = PipelineLoader(args)
self.pipeline = loader.load()
except Exception as e:
logger.error(f"Worker {self.device_id}: Failed to load pipeline: {e}")
raise
logger.info(f"Worker {self.device_id}: Pipeline ready")
# Sync all workers
dist.barrier()
# Send READY signal
if self.rank == 0:
logger.info(f"Worker {self.device_id}: Sending READY")
self.response_queue.put(DiffusionResponse(request_id=-1, output="READY"))
def serve_forever(self):
"""Main execution loop."""
while True:
req = None
if self.rank == 0:
req = self.requests_ipc.get()
logger.info(f"Worker {self.device_id}: Request available")
# Broadcast to all ranks
obj_list = [req]
dist.broadcast_object_list(obj_list, src=0)
req = obj_list[0]
if req is None:
logger.info(f"Worker {self.device_id}: Shutdown signal received")
if self.rank == 0 and self.sender_thread:
self.response_queue.put(None)
self.sender_thread.join()
break
logger.info(f"Worker {self.device_id}: Processing request {req.request_id}")
self.process_request(req)
def process_request(self, req: DiffusionRequest):
"""Process a single request."""
try:
output = self.pipeline.infer(req)
if self.rank == 0:
self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
except Exception as e:
logger.error(f"Worker {self.device_id}: Error: {e}")
logger.error(traceback.format_exc())
if self.rank == 0:
self.response_queue.put(
DiffusionResponse(request_id=req.request_id, error_msg=str(e))
)
def run_diffusion_worker(
rank: int,
world_size: int,
master_addr: str,
master_port: int,
model_path: str,
request_queue_addr: str,
response_queue_addr: str,
diffusion_config: Optional[dict] = None,
):
"""Entry point for worker process."""
try:
# Setup distributed env — use PyTorch distributed, not MPI
os.environ["TLLM_DISABLE_MPI"] = "1"
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Calculate device_id before init_process_group
device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0
if torch.cuda.is_available():
torch.cuda.set_device(device_id)
dist.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="env://",
world_size=world_size,
rank=rank,
device_id=torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else None,
)
executor = DiffusionExecutor(
model_path=model_path,
request_queue_addr=request_queue_addr,
response_queue_addr=response_queue_addr,
device_id=device_id,
diffusion_config=diffusion_config,
)
executor.serve_forever()
dist.destroy_process_group()
except Exception as e:
logger.error(f"Worker failed: {e}")
traceback.print_exc()
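
A hedged sketch of launching workers with `torch.multiprocessing.spawn`; the addresses, port, model path, and config dict are placeholders, and a real client must create the matching server-side ZeroMQ endpoints (PUSH for requests, PULL for responses).

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    mp.spawn(
        run_diffusion_worker,
        args=(
            world_size,
            "127.0.0.1",                              # master_addr
            29500,                                    # master_port
            "/models/Wan2.1-T2V-1.3B-Diffusers",      # model_path (placeholder)
            "tcp://127.0.0.1:5555",                   # request_queue_addr (placeholder)
            "tcp://127.0.0.1:5556",                   # response_queue_addr (placeholder)
            {"teacache": {"enable_teacache": True}},  # diffusion_config dict
        ),
        nprocs=world_size,
    )
```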

View File

@ -0,0 +1,30 @@
"""
Visual generation model pipelines.
Each model subdirectory contains:
- pipeline_*.py: Main pipeline implementation inheriting from BasePipeline
- __init__.py: Exports the pipeline class
TeaCache extractors are registered inline in each pipeline's load() method
using register_extractor_from_config().
Pipelines are registered in pipeline_registry.py's PipelineRegistry._REGISTRY dict.
Example structure:
models/
my_model/
pipeline_my_model.py # Pipeline class with inline extractor registration
__init__.py # Exports: __all__ = ["MyModelPipeline"]
"""
from ..pipeline import BasePipeline
from ..pipeline_registry import AutoPipeline, register_pipeline
from .wan import WanImageToVideoPipeline, WanPipeline
__all__ = [
"AutoPipeline",
"BasePipeline",
"WanPipeline",
"WanImageToVideoPipeline",
"register_pipeline",
]
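
Following the registration pattern described in the module docstring, a hedged sketch of adding a new pipeline; the class name and body are hypothetical.

```python
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline


@register_pipeline("MyModelPipeline")
class MyModelPipeline(BasePipeline):
    """Hypothetical pipeline used only to illustrate registration."""

    def _init_transformer(self) -> None:
        # Create the model-specific transformer(s) here, as the Wan pipelines do.
        ...
```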

View File

@ -0,0 +1,5 @@
from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .transformer_wan import WanTransformer3DModel
__all__ = ["WanPipeline", "WanImageToVideoPipeline", "WanTransformer3DModel"]

View File

@ -0,0 +1,521 @@
import time
from typing import Optional
import torch
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from transformers import AutoTokenizer, UMT5EncoderModel
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
from tensorrt_llm.logger import logger
from .transformer_wan import WanTransformer3DModel
# Supported Wan T2V models:
# - Wan2.1-T2V-14B: Single-stage text-to-video (14B parameters)
# - Wan2.1-T2V-1.3B: Single-stage text-to-video (1.3B parameters)
# - Wan2.2-T2V-A14B: Two-stage text-to-video (14B, boundary_ratio for high/low-noise stages; supports 480P & 720P)
WAN_TEACACHE_COEFFICIENTS = {
"1.3B": {
"ret_steps": [
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
"standard": [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01],
},
"14B": {
"ret_steps": [
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
"standard": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
},
}
# Default negative prompt for Wan T2V models
WAN_DEFAULT_NEGATIVE_PROMPT = (
"Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
"still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
"incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
"fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
)
@register_pipeline("WanPipeline")
class WanPipeline(BasePipeline):
def __init__(self, model_config):
# Wan2.2 two-stage denoising parameters
self.transformer_2 = None
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
self.is_wan22 = self.boundary_ratio is not None
super().__init__(model_config)
@staticmethod
def _compute_wan_timestep_embedding(module, timestep, guidance=None):
"""Compute timestep embedding for WAN transformer.
WAN uses a condition_embedder with timesteps_proj and time_embedder layers.
Handles dtype casting to match the embedder's dtype.
Args:
module: WanTransformer3DModel instance
timestep: Timestep tensor (shape: [batch_size])
guidance: Unused for WAN (no guidance embedding)
Returns:
Timestep embedding tensor used by TeaCache for distance calculation
"""
ce = module.condition_embedder
t_freq = ce.timesteps_proj(timestep)
# Cast to embedder's dtype (avoid int8 quantized layers)
te_dtype = next(iter(ce.time_embedder.parameters())).dtype
if t_freq.dtype != te_dtype and te_dtype != torch.int8:
t_freq = t_freq.to(te_dtype)
return ce.time_embedder(t_freq)
@property
def dtype(self):
return self.model_config.torch_dtype
@property
def device(self):
return self.transformer.device
@property
def transformer_components(self) -> list:
"""Return list of transformer components this pipeline needs."""
if self.transformer_2 is not None:
return ["transformer", "transformer_2"]
return ["transformer"]
def _init_transformer(self) -> None:
logger.info("Creating WAN transformer with quantization support...")
self.transformer = WanTransformer3DModel(model_config=self.model_config)
# Wan2.2: create second transformer for two-stage denoising
if self.boundary_ratio is not None:
logger.info("Creating second transformer for Wan2.2 two-stage denoising...")
self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
def load_standard_components(
self,
checkpoint_dir: str,
device: torch.device,
skip_components: Optional[list] = None,
) -> None:
"""Load VAE, text encoder, tokenizer, and scheduler from checkpoint."""
skip_components = skip_components or []
if self.transformer_2 is not None and self.boundary_ratio is None:
raise RuntimeError(
"transformer_2 exists but boundary_ratio is not set. "
"This indicates an inconsistent pipeline configuration."
)
# Detect model version
if self.is_wan22:
logger.info("Detected Wan 2.2 T2V (two-stage denoising)")
else:
logger.info("Detected Wan 2.1 T2V (single-stage denoising)")
# Set default VAE scale factors (will be overridden if VAE is loaded)
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
if PipelineComponent.TOKENIZER not in skip_components:
logger.info("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.TOKENIZER,
)
if PipelineComponent.TEXT_ENCODER not in skip_components:
logger.info("Loading text encoder...")
self.text_encoder = UMT5EncoderModel.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.TEXT_ENCODER,
torch_dtype=self.model_config.torch_dtype,
).to(device)
if PipelineComponent.VAE not in skip_components:
logger.info("Loading VAE...")
self.vae = AutoencoderKLWan.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.VAE,
torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
).to(device)
self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
if PipelineComponent.SCHEDULER not in skip_components:
logger.info("Loading scheduler...")
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.SCHEDULER,
)
if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
self.scheduler.config,
shift=5.0,
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def load_weights(self, weights: dict) -> None:
# Store weights for later use (in case transformer_2 is created after this call)
self._weights_dict = weights
has_separate_weights = "transformer" in weights and "transformer_2" in weights
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
logger.info("Loading transformer weights...")
transformer_weights = weights.get("transformer", weights)
self.transformer.load_weights(transformer_weights)
logger.info("Transformer weights loaded successfully.")
# Wan2.2: Load weights for second transformer if it exists
if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
logger.info("Loading transformer_2 weights for Wan2.2...")
if not has_separate_weights:
raise ValueError(
"Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
f"but only found: {list(weights.keys())}. "
"Two-stage denoising requires distinct weights for high-noise and low-noise transformers."
)
transformer_2_weights = weights["transformer_2"]
self.transformer_2.load_weights(transformer_2_weights)
logger.info("Transformer_2 weights loaded successfully.")
# Cache the target dtype from model config (default: bfloat16)
self._target_dtype = self.model_config.torch_dtype
# Set model to eval mode
if self.transformer is not None:
self.transformer.eval()
if self.transformer_2 is not None:
self.transformer_2.eval()
def post_load_weights(self) -> None:
super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
if self.transformer is not None:
# Register TeaCache extractor for this model type
# Tells TeaCache how to compute timestep embeddings for Wan
register_extractor_from_config(
ExtractorConfig(
model_class_name="WanTransformer3DModel",
timestep_embed_fn=self._compute_wan_timestep_embedding,
return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
)
)
# Enable TeaCache optimization with WAN-specific coefficients
self._setup_teacache(self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS)
# Save transformer backend before it gets overwritten
self.transformer_cache_backend = self.cache_backend
# Wan2.2: Setup TeaCache for second transformer (low-noise stage)
if self.transformer_2 is not None:
if hasattr(self.transformer_2, "post_load_weights"):
self.transformer_2.post_load_weights()
# Enable TeaCache for low-noise stage with same coefficients
self._setup_teacache(self.transformer_2, coefficients=WAN_TEACACHE_COEFFICIENTS)
# Save transformer_2 backend
self.transformer_2_cache_backend = self.cache_backend
def infer(self, req):
"""Run inference with request parameters."""
return self.forward(
prompt=req.prompt,
negative_prompt=req.negative_prompt,
height=req.height,
width=req.width,
num_frames=req.num_frames,
num_inference_steps=req.num_inference_steps,
guidance_scale=req.guidance_scale,
guidance_scale_2=req.guidance_scale_2,
boundary_ratio=req.boundary_ratio,
seed=req.seed,
max_sequence_length=req.max_sequence_length,
)
@torch.no_grad()
def forward(
self,
prompt: str,
negative_prompt: Optional[str] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 81,
num_inference_steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
guidance_scale_2: Optional[float] = None,
boundary_ratio: Optional[float] = None,
seed: int = 42,
max_sequence_length: int = 226,
):
pipeline_start = time.time()
generator = torch.Generator(device=self.device).manual_seed(seed)
# Use user-provided boundary_ratio if given, otherwise fall back to model config
boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
# Validate that Wan 2.2 models have boundary_ratio set
if self.transformer_2 is not None and boundary_ratio is None:
raise ValueError(
"Wan 2.2 models require boundary_ratio to be set. "
"boundary_ratio was not found in model config. "
"Please pass boundary_ratio as a parameter."
)
# Set default negative prompt if not provided
if negative_prompt is None:
negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
# Set model-specific defaults based on Wan version
logger.info(
f"Running {'Wan 2.2' if self.is_wan22 else 'Wan 2.1'} T2V inference"
f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
)
if num_inference_steps is None:
num_inference_steps = 40 if self.is_wan22 else 50
if guidance_scale is None:
guidance_scale = 4.0 if self.is_wan22 else 5.0
if self.is_wan22 and guidance_scale_2 is None:
guidance_scale_2 = 3.0
# Validate two-stage denoising configuration
if guidance_scale_2 is not None and boundary_ratio is None:
logger.warning(
"guidance_scale_2 is specified but boundary_ratio is None. "
"guidance_scale_2 will be ignored."
"Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
)
guidance_scale_2 = None
# Encode Prompt
logger.info("Encoding prompts...")
encode_start = time.time()
prompt_embeds, neg_prompt_embeds = self._encode_prompt(
prompt, negative_prompt, max_sequence_length
)
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
# Prepare Latents
latents = self._prepare_latents(height, width, num_frames, generator)
logger.info(f"Latents shape: {latents.shape}")
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
# Wan2.2: Calculate boundary timestep for two-stage denoising
boundary_timestep = None
if boundary_ratio is not None and self.transformer_2 is not None:
boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
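            # e.g. boundary_ratio=0.875 with num_train_timesteps=1000 gives boundary_timestep=875;
            # steps with t >= 875 use the high-noise transformer, later steps use transformer_2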
logger.info(
f"Wan2.2 two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
)
# Denoising with two-stage support
# Track which model was used in last step (for logging model transitions)
last_model_used = [None]
def forward_fn(
latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
):
"""Forward function for Wan transformer with two-stage support.
extra_stream_latents and extra_tensors are unused for WAN (single stream, no additional embeddings).
"""
# Select model based on timestep (if two-stage denoising is enabled)
if boundary_timestep is not None and self.transformer_2 is not None:
# Extract scalar timestep for comparison
current_t = timestep if timestep.dim() == 0 else timestep[0]
if current_t >= boundary_timestep:
current_model = self.transformer
model_name = "transformer (high-noise)"
else:
current_model = self.transformer_2
model_name = "transformer_2 (low-noise)"
# Log when switching models
if last_model_used[0] != model_name:
if self.rank == 0:
logger.info(f"Switched to {model_name} at timestep {current_t:.1f}")
last_model_used[0] = model_name
else:
current_model = self.transformer
return current_model(
hidden_states=latents,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
)
# Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
latents = self.denoise(
latents=latents,
scheduler=self.scheduler,
prompt_embeds=prompt_embeds,
neg_prompt_embeds=neg_prompt_embeds,
guidance_scale=guidance_scale,
forward_fn=forward_fn,
guidance_scale_2=guidance_scale_2,
boundary_timestep=boundary_timestep,
)
# Log TeaCache statistics - show stats for each transformer separately
if self.rank == 0 and self.model_config.teacache.enable_teacache:
logger.info("=" * 80)
logger.info("TeaCache Statistics:")
# Stats for transformer (high-noise)
if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
stats = self.transformer_cache_backend.get_stats()
total_steps = stats.get("total_steps", 0)
cache_hits = stats.get("cached_steps", 0)
cache_misses = stats.get("compute_steps", 0)
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
logger.info(" Transformer (High-Noise):")
logger.info(f" Total steps: {total_steps}")
logger.info(f" Cache hits: {cache_hits}")
logger.info(f" Cache misses: {cache_misses}")
logger.info(f" Hit rate: {hit_rate:.1f}%")
# Stats for transformer_2 (low-noise)
if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
stats = self.transformer_2_cache_backend.get_stats()
total_steps = stats.get("total_steps", 0)
cache_hits = stats.get("cached_steps", 0)
cache_misses = stats.get("compute_steps", 0)
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
logger.info(" Transformer_2 (Low-Noise):")
logger.info(f" Total steps: {total_steps}")
logger.info(f" Cache hits: {cache_hits}")
logger.info(f" Cache misses: {cache_misses}")
logger.info(f" Hit rate: {hit_rate:.1f}%")
logger.info("=" * 80)
# Decode
logger.info("Decoding video...")
decode_start = time.time()
video = self.decode_latents(latents, self._decode_latents)
if self.rank == 0:
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
return MediaOutput(video=video)
def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
prompt = [prompt] if isinstance(prompt, str) else prompt
def get_embeds(texts):
text_inputs = self.tokenizer(
texts,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(self.device)
attention_mask = text_inputs.attention_mask.to(self.device)
embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
embeds = embeds.to(self.dtype)
# Zero-out padded tokens based on mask
seq_lens = attention_mask.gt(0).sum(dim=1).long()
cleaned_embeds = []
for u, v in zip(embeds, seq_lens):
real_content = u[:v]
pad_len = max_sequence_length - real_content.size(0)
if pad_len > 0:
padded = torch.cat(
[real_content, real_content.new_zeros(pad_len, real_content.size(1))]
)
else:
padded = real_content
cleaned_embeds.append(padded)
return torch.stack(cleaned_embeds, dim=0)
prompt_embeds = get_embeds(prompt)
if negative_prompt is None:
negative_prompt = ""
neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if len(neg_prompt) == 1 and len(prompt) > 1:
neg_prompt = neg_prompt * len(prompt)
neg_embeds = get_embeds(neg_prompt)
return prompt_embeds, neg_embeds
def _prepare_latents(self, height, width, num_frames, generator):
num_channels_latents = 16
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
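        # e.g. num_frames=33 with vae_scale_factor_temporal=4 gives (33 - 1) // 4 + 1 = 9 latent frames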
shape = (
1,
num_channels_latents,
num_latent_frames,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
return randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
def _decode_latents(self, latents):
latents = latents.to(self.vae.dtype)
# Denormalization
if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
if not hasattr(self, "_latents_mean"):
self._latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(self.device, self.vae.dtype)
)
self._latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, -1, 1, 1, 1)
.to(self.device, self.vae.dtype)
)
latents = (latents * self._latents_std) + self._latents_mean
else:
scaling_factor = self.vae.config.get("scaling_factor", 1.0)
latents = latents / scaling_factor
# VAE decode: returns (B, C, T, H, W)
video = self.vae.decode(latents, return_dict=False)[0]
# Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
video = postprocess_video_tensor(video, remove_batch_dim=True)
return video
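
A minimal end-to-end sketch of driving this pipeline directly in a single process, assuming rank defaults to 0 outside the executor; the checkpoint path is a placeholder.

```python
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs
from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader

args = DiffusionArgs(checkpoint_path="/models/Wan2.1-T2V-1.3B-Diffusers")  # placeholder path
pipeline = PipelineLoader(args).load()

out = pipeline.forward(
    prompt="A cute cat playing piano",
    height=480,
    width=832,
    num_frames=33,
    num_inference_steps=30,
    seed=0,
)
video = out.video  # (T, H, W, C) uint8 tensor, per postprocess_video_tensor
```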

View File

@ -0,0 +1,736 @@
import json
import os
import time
from typing import Optional, Tuple, Union
import PIL.Image
import torch
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
from tensorrt_llm.logger import logger
# Supported Wan I2V 14B models:
# - Wan2.1-I2V-14B-480P: Single-stage image-to-video
# - Wan2.1-I2V-14B-720P: Single-stage image-to-video
# - Wan2.2-I2V-14B: Two-stage image-to-video (no CLIP, boundary_ratio for two-stage denoising)
# Note: Wan2.2-I2V-5B (expand_timesteps mode) is NOT supported by this pipeline
# Import shared coefficients from T2V pipeline
from .pipeline_wan import WAN_TEACACHE_COEFFICIENTS
from .transformer_wan import WanTransformer3DModel
# Use same coefficients
WAN_I2V_TEACACHE_COEFFICIENTS = WAN_TEACACHE_COEFFICIENTS
# Default negative prompt for Wan I2V models
WAN_DEFAULT_NEGATIVE_PROMPT = (
"Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
"still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
"incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
"fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
)
def retrieve_latents(
encoder_output: torch.Tensor,
generator: Optional[torch.Generator] = None,
sample_mode: str = "argmax",
):
"""Extract latents from VAE encoder output.
For I2V, we use argmax mode to get deterministic encoding of the input image.
"""
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
@register_pipeline("WanImageToVideoPipeline")
class WanImageToVideoPipeline(BasePipeline):
def __init__(self, model_config):
# Wan2.2 14B two-stage denoising parameters
self.transformer_2 = None
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
self.is_wan22 = self.boundary_ratio is not None
super().__init__(model_config)
@staticmethod
def _compute_wan_timestep_embedding(module, timestep, guidance=None):
"""Compute timestep embedding for Wan I2V transformer."""
ce = module.condition_embedder
t_freq = ce.timesteps_proj(timestep)
# Cast to embedder's dtype (avoid int8 quantized layers)
te_dtype = next(iter(ce.time_embedder.parameters())).dtype
if t_freq.dtype != te_dtype and te_dtype != torch.int8:
t_freq = t_freq.to(te_dtype)
return ce.time_embedder(t_freq)
@property
def dtype(self):
return self.model_config.torch_dtype
@property
def device(self):
return self.transformer.device
@property
def transformer_components(self) -> list:
if self.transformer_2 is not None:
return ["transformer", "transformer_2"]
return ["transformer"]
def _init_transformer(self) -> None:
logger.info("Creating WAN I2V transformer with quantization support...")
self.transformer = WanTransformer3DModel(model_config=self.model_config)
# Wan2.2: Optionally create second transformer for two-stage denoising
if self.boundary_ratio is not None:
logger.info("Creating second transformer for Wan2.2 I2V two-stage denoising...")
self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
def load_standard_components(
self,
checkpoint_dir: str,
device: torch.device,
skip_components: Optional[list] = None,
) -> None:
"""Load VAE, text encoder, tokenizer, scheduler, and I2V-specific components from checkpoint."""
skip_components = skip_components or []
# Load boundary_ratio and transformer_2 info from model_index.json (pipeline-level config)
# Wan 2.2 has both transformer_2 and boundary_ratio, Wan 2.1 doesn't
model_index_path = os.path.join(checkpoint_dir, "model_index.json")
has_transformer_2 = False
if os.path.exists(model_index_path):
with open(model_index_path) as f:
model_index = json.load(f)
# Check for boundary_ratio in model_index
if "boundary_ratio" in model_index:
self.boundary_ratio = model_index["boundary_ratio"]
logger.info(f"Found boundary_ratio in model_index.json: {self.boundary_ratio}")
else:
logger.info("No boundary_ratio found in model_index.json")
# Check for transformer_2 component
transformer_2_spec = model_index.get("transformer_2", None)
has_transformer_2 = (
transformer_2_spec is not None and transformer_2_spec[0] is not None
)
logger.info(f"transformer_2 in model_index.json: {has_transformer_2}")
# Set default VAE scale factors (will be overridden if VAE is loaded)
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
if PipelineComponent.TOKENIZER not in skip_components:
logger.info("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.TOKENIZER,
)
if PipelineComponent.TEXT_ENCODER not in skip_components:
logger.info("Loading text encoder...")
self.text_encoder = UMT5EncoderModel.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.TEXT_ENCODER,
torch_dtype=self.model_config.torch_dtype,
).to(device)
if PipelineComponent.VAE not in skip_components:
logger.info("Loading VAE...")
self.vae = AutoencoderKLWan.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.VAE,
torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
).to(device)
self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
if PipelineComponent.SCHEDULER not in skip_components:
logger.info("Loading scheduler...")
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.SCHEDULER,
)
if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
self.scheduler.config,
shift=5.0,
)
if self.transformer_2 is not None and self.boundary_ratio is None:
raise RuntimeError(
"transformer_2 exists but boundary_ratio is not set. "
"This indicates an inconsistent pipeline configuration."
)
# Load image encoder and processor (only for Wan 2.1)
# Wan 2.2: Has both transformer_2 and boundary_ratio (two-stage denoising)
if self.is_wan22:
logger.info("Detected Wan 2.2 I2V (two-stage, no CLIP)")
else:
logger.info("Detected Wan 2.1 I2V (single-stage, uses CLIP)")
if PipelineComponent.IMAGE_ENCODER not in skip_components and not self.is_wan22:
logger.info("Loading CLIP image encoder for I2V conditioning (Wan 2.1 only)...")
self.image_encoder = CLIPVisionModel.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.IMAGE_ENCODER,
torch_dtype=torch.float32, # Keep CLIP in FP32 for stability
).to(device)
if PipelineComponent.IMAGE_PROCESSOR not in skip_components and not self.is_wan22:
logger.info("Loading CLIP image processor...")
self.image_processor = CLIPImageProcessor.from_pretrained(
checkpoint_dir,
subfolder=PipelineComponent.IMAGE_PROCESSOR,
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def load_weights(self, weights: dict) -> None:
# Store weights for later use
self._weights_dict = weights
# Check if weights dict has separate transformer/transformer_2 keys (Wan2.2)
has_separate_weights = "transformer" in weights and "transformer_2" in weights
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
logger.info("Loading transformer weights...")
transformer_weights = weights.get("transformer", weights)
self.transformer.load_weights(transformer_weights)
logger.info("Transformer weights loaded successfully.")
# Wan2.2: Load weights for second transformer if it exists
if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
logger.info("Loading transformer_2 weights for Wan2.2 I2V...")
if has_separate_weights:
transformer_2_weights = weights["transformer_2"]
logger.info("Using separate transformer_2 weights from checkpoint")
else:
# For Wan 2.2, transformer_2 weights must exist
raise ValueError(
"Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
f"but only found: {list(weights.keys())}"
)
self.transformer_2.load_weights(transformer_2_weights)
logger.info("Transformer_2 weights loaded successfully.")
# Cache the target dtype from model config (default: bfloat16)
self._target_dtype = self.model_config.torch_dtype
# Set model to eval mode
if self.transformer is not None:
self.transformer.eval()
if self.transformer_2 is not None:
self.transformer_2.eval()
if hasattr(self, "image_encoder") and self.image_encoder is not None:
self.image_encoder.eval()
def post_load_weights(self) -> None:
super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
if self.transformer is not None:
# Register TeaCache extractor for this model type
register_extractor_from_config(
ExtractorConfig(
model_class_name="WanTransformer3DModel",
timestep_embed_fn=self._compute_wan_timestep_embedding,
return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
)
)
# Enable TeaCache optimization with Wan I2V-specific coefficients
self._setup_teacache(self.transformer, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
# Save transformer backend before it gets overwritten
self.transformer_cache_backend = self.cache_backend
# Wan2.2: Setup TeaCache for second transformer (low-noise stage)
if self.transformer_2 is not None:
if hasattr(self.transformer_2, "post_load_weights"):
self.transformer_2.post_load_weights()
# Enable TeaCache for low-noise stage with same coefficients
self._setup_teacache(self.transformer_2, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
# Save transformer_2 backend
self.transformer_2_cache_backend = self.cache_backend
def infer(self, req):
"""Run inference with request parameters."""
# Extract image from request (can be path, PIL Image, or torch.Tensor)
if req.image is None:
raise ValueError("I2V pipeline requires 'image' parameter")
image = req.image[0] if isinstance(req.image, list) else req.image
last_image = req.last_image
if last_image is not None and isinstance(last_image, list):
last_image = last_image[0] if last_image else None
return self.forward(
image=image,
prompt=req.prompt,
negative_prompt=req.negative_prompt,
height=req.height,
width=req.width,
num_frames=req.num_frames,
num_inference_steps=req.num_inference_steps,
guidance_scale=req.guidance_scale,
guidance_scale_2=req.guidance_scale_2,
boundary_ratio=req.boundary_ratio,
seed=req.seed,
max_sequence_length=req.max_sequence_length,
last_image=last_image,
)
@torch.no_grad()
def forward(
self,
image: Union[PIL.Image.Image, torch.Tensor, str],
prompt: str,
negative_prompt: Optional[str] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
guidance_scale_2: Optional[float] = None,
boundary_ratio: Optional[float] = None,
seed: int = 42,
max_sequence_length: int = 512,
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
):
pipeline_start = time.time()
generator = torch.Generator(device=self.device).manual_seed(seed)
# Use user-provided boundary_ratio if given, otherwise fall back to model config
boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
if self.transformer_2 is not None and boundary_ratio is None:
raise ValueError(
"Wan 2.2 models require boundary_ratio to be set. "
"boundary_ratio was not found in model config. "
"Please pass boundary_ratio as a parameter."
)
# Set default negative prompt if not provided
if negative_prompt is None:
negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
# Set model-specific defaults based on Wan version
if num_inference_steps is None:
num_inference_steps = 40 if self.is_wan22 else 50
if guidance_scale is None:
guidance_scale = 4.0 if self.is_wan22 else 5.0
if self.is_wan22 and guidance_scale_2 is None:
guidance_scale_2 = 3.0 # Wan2.2 recommended default
# Validate two-stage denoising configuration
if guidance_scale_2 is not None and boundary_ratio is None:
logger.warning(
"guidance_scale_2 is specified but boundary_ratio is None. "
"guidance_scale_2 will be ignored."
"Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
)
guidance_scale_2 = None
# Validate and adjust frame count for VAE compatibility
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` must be divisible by {self.vae_scale_factor_temporal}. "
f"Rounding {num_frames} to nearest valid value."
)
num_frames = (
num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
)
num_frames = max(num_frames, 1)
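        # e.g. with vae_scale_factor_temporal=4, num_frames=80 is adjusted to (80 // 4) * 4 + 1 = 81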
# Validate and adjust resolution for transformer patchification
patch_size = (
self.transformer.config.patch_size
if self.transformer is not None
else self.transformer_2.config.patch_size
)
h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
calc_height = height // h_multiple_of * h_multiple_of
calc_width = width // w_multiple_of * w_multiple_of
if height != calc_height or width != calc_width:
logger.warning(
f"Height and width must be multiples of ({h_multiple_of}, {w_multiple_of}) for patchification. "
f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
)
height, width = calc_height, calc_width
# Encode Prompt
logger.info("Encoding prompts...")
encode_start = time.time()
prompt_embeds, neg_prompt_embeds = self._encode_prompt(
prompt, negative_prompt, max_sequence_length
)
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
# Encode Image (I2V-specific)
logger.info("Encoding input image...")
image_encode_start = time.time()
# Determine model version
model_version = "Wan 2.2" if self.is_wan22 else "Wan 2.1"
logger.info(
f"Running {model_version} I2V inference "
f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
)
if not self.is_wan22:
# Wan 2.1 I2V: Compute CLIP image embeddings
image_embeds = self._encode_image(image, last_image)
image_embeds = image_embeds.to(self.dtype)
else:
# Wan 2.2 I2V: No image embeddings needed
image_embeds = None
logger.info(f"Image encoding completed in {time.time() - image_encode_start:.2f}s")
# Prepare Latents with image conditioning (I2V-specific)
latents, condition_data = self._prepare_latents(
image, height, width, num_frames, generator, last_image
)
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
# Wan2.2: Calculate boundary timestep for two-stage denoising
boundary_timestep = None
if boundary_ratio is not None and self.transformer_2 is not None:
boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
logger.info(
f"Wan2.2 I2V two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
)
# Denoising with two-stage support
# Track which model was used in last step (for logging model transitions)
last_model_used = [None]
def forward_fn(
latents_input, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
):
"""Forward function for WAN I2V transformer with two-stage support.
Both Wan 2.1 and Wan 2.2 14B use concatenation approach: [latents, condition].
Difference: Wan 2.1 passes image_embeds, Wan 2.2 passes None.
"""
# Select model based on timestep (if two-stage denoising is enabled)
if boundary_timestep is not None and self.transformer_2 is not None:
# Extract scalar timestep for comparison
current_t = timestep if timestep.dim() == 0 else timestep[0]
if current_t >= boundary_timestep:
current_model = self.transformer
model_name = "transformer"
else:
current_model = self.transformer_2
model_name = "transformer_2"
# Log when switching models
if last_model_used[0] != model_name:
if self.rank == 0:
logger.info(
f"[TRTLLM] Switched to {model_name} at timestep {current_t:.1f}"
)
last_model_used[0] = model_name
else:
current_model = self.transformer
# Wan 2.1 & Wan 2.2 14B: concatenate latents and condition
# Handle CFG: duplicate condition if batch dimension doubled
if latents_input.shape[0] != condition_data.shape[0]:
condition_to_use = torch.cat([condition_data] * 2)
else:
condition_to_use = condition_data
latent_model_input = torch.cat([latents_input, condition_to_use], dim=1).to(self.dtype)
timestep_input = timestep.expand(latents_input.shape[0])
# Forward pass with I2V conditioning
# Wan 2.1: image_embeds is not None (CLIP embeddings)
# Wan 2.2 14B: image_embeds is None (no CLIP)
return current_model(
hidden_states=latent_model_input,
timestep=timestep_input,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_image=image_embeds,
)
# Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
latents = self.denoise(
latents=latents,
scheduler=self.scheduler,
prompt_embeds=prompt_embeds,
neg_prompt_embeds=neg_prompt_embeds,
guidance_scale=guidance_scale,
forward_fn=forward_fn,
guidance_scale_2=guidance_scale_2,
boundary_timestep=boundary_timestep,
)
# Log TeaCache statistics - show stats for each transformer separately
if self.rank == 0 and self.model_config.teacache.enable_teacache:
logger.info("=" * 80)
logger.info("TeaCache Statistics:")
# Stats for transformer (high-noise)
if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
stats = self.transformer_cache_backend.get_stats()
total_steps = stats.get("total_steps", 0)
cache_hits = stats.get("cached_steps", 0)
cache_misses = stats.get("compute_steps", 0)
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
logger.info(" Transformer (High-Noise):")
logger.info(f" Total steps: {total_steps}")
logger.info(f" Cache hits: {cache_hits}")
logger.info(f" Cache misses: {cache_misses}")
logger.info(f" Hit rate: {hit_rate:.1f}%")
# Stats for transformer_2 (low-noise)
if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
stats = self.transformer_2_cache_backend.get_stats()
total_steps = stats.get("total_steps", 0)
cache_hits = stats.get("cached_steps", 0)
cache_misses = stats.get("compute_steps", 0)
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
logger.info(" Transformer_2 (Low-Noise):")
logger.info(f" Total steps: {total_steps}")
logger.info(f" Cache hits: {cache_hits}")
logger.info(f" Cache misses: {cache_misses}")
logger.info(f" Hit rate: {hit_rate:.1f}%")
logger.info("=" * 80)
# Decode
logger.info("Decoding video...")
decode_start = time.time()
video = self.decode_latents(latents, self._decode_latents)
if self.rank == 0:
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
return MediaOutput(video=video)
def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
"""Encode text prompts to embeddings (same as T2V)."""
prompt = [prompt] if isinstance(prompt, str) else prompt
def get_embeds(texts):
text_inputs = self.tokenizer(
texts,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(self.device)
attention_mask = text_inputs.attention_mask.to(self.device)
embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
embeds = embeds.to(self.dtype)
# Zero-out padded tokens based on mask
seq_lens = attention_mask.gt(0).sum(dim=1).long()
cleaned_embeds = []
for u, v in zip(embeds, seq_lens):
real_content = u[:v]
pad_len = max_sequence_length - real_content.size(0)
if pad_len > 0:
padded = torch.cat(
[real_content, real_content.new_zeros(pad_len, real_content.size(1))]
)
else:
padded = real_content
cleaned_embeds.append(padded)
return torch.stack(cleaned_embeds, dim=0)
prompt_embeds = get_embeds(prompt)
if negative_prompt is None:
negative_prompt = ""
neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if len(neg_prompt) == 1 and len(prompt) > 1:
neg_prompt = neg_prompt * len(prompt)
neg_embeds = get_embeds(neg_prompt)
return prompt_embeds, neg_embeds
def _encode_image(
self,
image: Union[PIL.Image.Image, torch.Tensor, str],
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
) -> torch.Tensor:
"""Encode image(s) using CLIP image encoder (Wan 2.1 I2V only)."""
if isinstance(image, str):
image = PIL.Image.open(image).convert("RGB")
if isinstance(last_image, str):
last_image = PIL.Image.open(last_image).convert("RGB")
images_to_encode = [image] if last_image is None else [image, last_image]
image_inputs = self.image_processor(images=images_to_encode, return_tensors="pt").to(
self.device
)
image_embeds = self.image_encoder(**image_inputs, output_hidden_states=True)
return image_embeds.hidden_states[-2]
def _prepare_latents(
self,
image: Union[PIL.Image.Image, torch.Tensor, str],
height: int,
width: int,
num_frames: int,
generator: torch.Generator,
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prepare latents with image conditioning for I2V generation."""
num_channels_latents = 16
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
# Create random noise latents
shape = (1, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
# Load and preprocess image(s)
if isinstance(image, str):
image = PIL.Image.open(image).convert("RGB")
image = self.video_processor.preprocess(image, height=height, width=width).to(
self.device, dtype=torch.float32
)
if last_image is not None:
if isinstance(last_image, str):
last_image = PIL.Image.open(last_image).convert("RGB")
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
self.device, dtype=torch.float32
)
image = image.unsqueeze(2)
# Create video conditioning tensor (same for both Wan 2.1 and Wan 2.2 14B)
if last_image is None:
# First frame + zeros
video_condition = torch.cat(
[
image,
image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width),
],
dim=2,
)
else:
# First frame + zeros + last frame (interpolation)
last_image = last_image.unsqueeze(2)
video_condition = torch.cat(
[
image,
image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width),
last_image,
],
dim=2,
)
# Encode video condition through VAE
video_condition = video_condition.to(device=self.device, dtype=self.vae.dtype)
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.to(self.dtype)
# Normalize latents to match diffusion model's latent space
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
1, self.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
# Create mask in video frame space
# Reshaping is required to match the transformer's expected input format
mask_lat_size = torch.ones(1, 1, num_frames, latent_height, latent_width)
if last_image is None:
mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal
)
mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
1, -1, self.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
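# mask_lat_size is now [1, vae_scale_factor_temporal, num_latent_frames, latent_height, latent_width],
# so its trailing dims line up with latent_condition for the channel-dim concat below.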
mask_lat_size = mask_lat_size.to(self.device, dtype=self.dtype)
# Concatenate mask and condition along channel dimension
condition = torch.cat([mask_lat_size, latent_condition], dim=1)
return latents, condition
def _decode_latents(self, latents):
"""Decode latents to video (same as T2V)."""
latents = latents.to(self.vae.dtype)
# Denormalization
if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
if not hasattr(self, "_latents_mean"):
self._latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(self.device, self.vae.dtype)
)
self._latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, -1, 1, 1, 1)
.to(self.device, self.vae.dtype)
)
latents = (latents * self._latents_std) + self._latents_mean
else:
scaling_factor = self.vae.config.get("scaling_factor", 1.0)
latents = latents / scaling_factor
# VAE decode: returns (B, C, T, H, W)
video = self.vae.decode(latents, return_dict=False)[0]
# Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
video = postprocess_video_tensor(video, remove_batch_dim=True)
return video

View File

@ -0,0 +1,756 @@
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from tqdm import tqdm
from transformers.modeling_utils import get_parameter_device
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.mlp import MLP
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism
from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
# =========================================================================
# 1. Rotary Positional Embeddings
# =========================================================================
class WanRotaryPosEmbed(nn.Module):
def __init__(
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
self.patch_size = patch_size
# Split logic matches Hugging Face exactly
self.h_dim = 2 * (attention_head_dim // 6)
self.w_dim = 2 * (attention_head_dim // 6)
self.t_dim = attention_head_dim - self.h_dim - self.w_dim
freqs_cos, freqs_sin = [], []
# Order: Time, Height, Width
for dim in [self.t_dim, self.h_dim, self.w_dim]:
# High precision generation
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
t = torch.arange(max_seq_len, dtype=torch.float64)
freqs = torch.outer(t, freqs)
# Interleaved Pattern [c0, c0, c1, c1]
freqs_cos.append(freqs.cos().repeat_interleave(2, dim=-1).float())
freqs_sin.append(freqs.sin().repeat_interleave(2, dim=-1).float())
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Unpack latent video shape: [B, C, F, H, W]
b, c, f, h, w = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = f // p_t, h // p_h, w // p_w
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
# Broadcast frequencies to 3D grid: [Time, Height, Width]
f_cos_t = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
f_sin_t = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
f_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
f_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
f_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
f_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
# Concatenate and flatten for Attention [1, SeqLen, 1, Dim] (SHD format)
# The Attention module applies RoPE in [B, S, H, D] layout before reshaping to the backend's preferred layout
return (
torch.cat([f_cos_t, f_cos_h, f_cos_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
torch.cat([f_sin_t, f_sin_h, f_sin_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
)
# =========================================================================
# 2. Embeddings & Attention
# =========================================================================
class WanImageEmbedding(nn.Module):
"""Image embedding for I2V models (Wan 2.1/2.2)."""
def __init__(
self,
in_features: int,
out_features: int,
pos_embed_seq_len: int = None,
model_config: DiffusionModelConfig = None,
):
super().__init__()
dtype = model_config.torch_dtype if model_config else None
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
self.norm1 = LayerNorm(
hidden_size=in_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
)
# Match HF FeedForward structure: Linear(in, in) → GELU → Linear(in, out)
self.ff_in = Linear(
in_features,
in_features,
bias=True,
dtype=dtype,
mapping=model_config.mapping if model_config else None,
quant_config=model_config.quant_config if model_config else None,
skip_create_weights_in_init=model_config.skip_create_weights_in_init
if model_config
else False,
force_dynamic_quantization=model_config.force_dynamic_quantization
if model_config
else False,
)
self.ff_out = Linear(
in_features,
out_features,
bias=True,
dtype=dtype,
mapping=model_config.mapping if model_config else None,
quant_config=model_config.quant_config if model_config else None,
skip_create_weights_in_init=model_config.skip_create_weights_in_init
if model_config
else False,
force_dynamic_quantization=model_config.force_dynamic_quantization
if model_config
else False,
)
self.norm2 = LayerNorm(
hidden_size=out_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(
-1, 2 * seq_len, embed_dim
)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff_in(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.ff_out(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim,
time_freq_dim,
time_proj_dim,
text_embed_dim,
model_config: DiffusionModelConfig,
image_embed_dim: int = None,
pos_embed_seq_len: int = None,
):
super().__init__()
dtype = model_config.torch_dtype
quant_config = model_config.quant_config
skip_create_weights = model_config.skip_create_weights_in_init
force_dynamic_quant = model_config.force_dynamic_quantization
self.timesteps_proj = Timesteps(
num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = Linear(
dim,
time_proj_dim,
dtype=dtype,
mapping=model_config.mapping,
quant_config=quant_config,
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(
image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len, model_config=model_config
)
def forward(self, timestep, encoder_hidden_states, encoder_hidden_states_image=None):
timestep = self.timesteps_proj(timestep)
# Get time_embedder dtype
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
temb_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None and self.image_embedder is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image
class WanBlock(nn.Module):
def __init__(
self,
model_config: DiffusionModelConfig,
_layer_idx: int,
added_kv_proj_dim: int = None,
):
super().__init__()
config = model_config.pretrained_config
if hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
hidden_size = config.attention_head_dim * config.num_attention_heads
else:
hidden_size = 1536
# Wan 2.1 1.3B defaults
num_heads = getattr(config, "num_attention_heads", 12)
head_dim = getattr(config, "attention_head_dim", 128)
ffn_dim = getattr(config, "ffn_dim", 8960)
eps = getattr(config, "eps", 1e-6)
dtype = model_config.torch_dtype
quant_config = model_config.quant_config
skip_create_weights = model_config.skip_create_weights_in_init
force_dynamic_quant = model_config.force_dynamic_quantization
# Store for I2V reshaping logic
self.num_heads = num_heads
self.head_dim = head_dim
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
self.norm1 = LayerNorm(
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
)
# Self-attention with fused QKV
self.attn1 = Attention(
hidden_size=hidden_size,
num_attention_heads=num_heads,
head_dim=head_dim,
qkv_mode=QKVMode.FUSE_QKV,
qk_norm=True,
eps=eps,
config=model_config,
layer_idx=_layer_idx,
)
# Cross-attention with separate Q, K, V
self.attn2 = Attention(
hidden_size=hidden_size,
num_attention_heads=num_heads,
head_dim=head_dim,
qkv_mode=QKVMode.SEPARATE_QKV,
qk_norm=True,
eps=eps,
config=model_config,
layer_idx=_layer_idx,
)
self.norm2 = LayerNorm(
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=True, has_bias=True
)
self.norm3 = LayerNorm(
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
)
self.ffn = MLP(
hidden_size=hidden_size,
intermediate_size=ffn_dim,
bias=True,
activation=lambda x: F.gelu(x, approximate="tanh"),
dtype=dtype,
config=model_config,
layer_idx=_layer_idx,
reduce_output=False,
)
# I2V: Additional K/V projections for image embeddings
self.add_k_proj = self.add_v_proj = None
self.norm_added_k = None
if added_kv_proj_dim is not None:
self.add_k_proj = Linear(
added_kv_proj_dim,
hidden_size,
dtype=dtype,
mapping=model_config.mapping,
quant_config=quant_config,
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
)
self.add_v_proj = Linear(
added_kv_proj_dim,
hidden_size,
dtype=dtype,
mapping=model_config.mapping,
quant_config=quant_config,
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
)
self.norm_added_k = RMSNorm(
hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True
)
# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
self.scale_shift_table = nn.Parameter(
torch.empty(1, 6, hidden_size).normal_(std=hidden_size**-0.5)
)
def forward(
self,
x,
encoder_hidden_states,
temb,
freqs_cos,
freqs_sin,
):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.float() + temb.float()
).chunk(6, dim=1)
normed = self.norm1(x.float()) * (1 + scale_msa) + shift_msa
normed = normed.to(x.dtype)
# Prepare frequencies for Attention
freqs = (freqs_cos, freqs_sin) if freqs_cos is not None and freqs_sin is not None else None
# Self-attention with RoPE
x = (
x.float()
+ self.attn1(
normed,
freqs=freqs,
).float()
* gate_msa
).to(x.dtype)
norm_x = self.norm2(x.float()).to(x.dtype)
# I2V: Split encoder_hidden_states into image and text parts if needed
encoder_hidden_states_img = None
encoder_hidden_states_text = encoder_hidden_states
if self.add_k_proj is not None:
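# Assumes the pipeline pads text embeddings to 512 tokens; any leading tokens are the CLIP image context (I2V).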
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states_text = encoder_hidden_states[:, image_context_length:]
# Text cross-attention
attn2_output = self.attn2(norm_x, encoder_hidden_states=encoder_hidden_states_text)
# I2V: Additional image cross-attention if image embeddings are present
if encoder_hidden_states_img is not None:
batch_size, seq_len = norm_x.shape[:2]
query = self.attn2.get_qkv(norm_x, None)[0] # Q only
query, _ = self.attn2.apply_qk_norm(query, query)
key_img = self.add_k_proj(encoder_hidden_states_img)
value_img = self.add_v_proj(encoder_hidden_states_img)
key_img = self.norm_added_k(key_img)
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
key_img = key_img.view(
batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
)
value_img = value_img.view(
batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
)
attn_img_output = self.attn2._attn_impl(
query,
key_img,
value_img,
batch_size=batch_size,
seq_len=seq_len,
kv_seq_len=encoder_hidden_states_img.shape[1],
)
attn2_output = attn2_output + attn_img_output
x = x + attn2_output
# 3. Feed-forward
normed = self.norm3(x.float()) * (1 + c_scale_msa) + c_shift_msa
normed = normed.to(x.dtype)
x = (x.float() + self.ffn(normed).float() * c_gate_msa).to(x.dtype)
return x
class WanTransformer3DModel(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
model_config: DiffusionModelConfig,
):
super().__init__()
self.model_config = model_config
# Validate no tensor parallelism
if model_config.parallel.dit_tp_size > 1:
raise ValueError(
f"WAN does not support tensor parallelism. "
f"Got dit_tp_size={model_config.parallel.dit_tp_size}"
)
# Setup sequence parallelism (Ulysses)
num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 12)
self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = (
setup_sequence_parallelism(
model_config=model_config,
num_attention_heads=num_heads,
)
)
config = model_config.pretrained_config
dtype = model_config.torch_dtype
quant_config = model_config.quant_config
skip_create_weights = model_config.skip_create_weights_in_init
force_dynamic_quant = model_config.force_dynamic_quantization
if hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
hidden_size = config.attention_head_dim * config.num_attention_heads
else:
hidden_size = 1536 # Wan 1.3B default
num_layers = getattr(config, "num_layers", 30)
attention_head_dim = getattr(config, "attention_head_dim", 128)
in_channels = getattr(config, "in_channels", 16)
out_channels = getattr(config, "out_channels", 16)
text_dim = getattr(config, "text_dim", 4096)
freq_dim = getattr(config, "freq_dim", 256)
patch_size = getattr(config, "patch_size", [1, 2, 2])
image_embed_dim = getattr(config, "image_dim", None) # e.g., 1280 for I2V
added_kv_proj_dim = getattr(config, "added_kv_proj_dim", None)
pos_embed_seq_len = getattr(config, "pos_embed_seq_len", None)
# Calculate FFN Dim
ffn_dim = getattr(config, "ffn_dim", None)
if ffn_dim is None:
ffn_dim = (
13824
if hidden_size == 5120
else (8960 if hidden_size == 1536 else int(hidden_size * 4))
)
# Store config for unpatchify and pipeline compatibility
self.config = type(
"Config",
(),
{
"patch_size": patch_size,
"hidden_size": hidden_size,
"image_dim": image_embed_dim,
"in_channels": in_channels,
"out_channels": out_channels,
"num_layers": num_layers,
},
)()
self.patch_embedding = nn.Conv3d(
in_channels,
hidden_size,
kernel_size=patch_size,
stride=patch_size,
dtype=dtype, # use model's target dtype (bf16)
)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=freq_dim,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_dim,
model_config=model_config,
image_embed_dim=image_embed_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
self.blocks = nn.ModuleList(
[
WanBlock(
model_config=model_config,
_layer_idx=i,
added_kv_proj_dim=added_kv_proj_dim,
)
for i in range(num_layers)
]
)
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len=1024)
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
self.norm_out = LayerNorm(
hidden_size=hidden_size,
eps=1e-6,
dtype=torch.float32,
has_weights=False,
has_bias=False,
)
self.proj_out = Linear(
hidden_size,
out_channels * math.prod(patch_size),
dtype=dtype,
mapping=model_config.mapping,
quant_config=quant_config,
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
)
# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
self.scale_shift_table = nn.Parameter(
torch.empty(1, 2, hidden_size).normal_(std=hidden_size**-0.5)
)
self.__post_init__()
@property
def device(self):
return get_parameter_device(self)
def __post_init__(self):
self.apply_quant_config_exclude_modules()
for _, module in self.named_modules():
if callable(getattr(module, "create_weights", None)):
module.create_weights()
def apply_quant_config_exclude_modules(self):
quant_config = self.model_config.quant_config
if quant_config is None or quant_config.exclude_modules is None:
return
kv_cache_quant_algo = quant_config.kv_cache_quant_algo if quant_config else None
no_quant_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
for name, module in self.named_modules():
if isinstance(module, Linear):
is_excluded = quant_config.is_module_excluded_from_quantization(name)
if is_excluded and getattr(module, "quant_config", None) is not None:
module.quant_config = no_quant_config
def unpatchify(self, x, original_shape):
N, C, T, H, W = original_shape
pt, ph, pw = self.config.patch_size
gt, gh, gw = T // pt, H // ph, W // pw
# Use output channels instead of input channels for unpatchifying
out_channels = self.proj_out.out_features // (pt * ph * pw)
return (
x.view(N, gt, gh, gw, pt, ph, pw, out_channels)
.permute(0, 7, 1, 4, 2, 5, 3, 6)
.reshape(N, out_channels, T, H, W)
)
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_hidden_states_image=None,
**kwargs,
):
"""
Forward pass with optional Ulysses sequence parallelism.
With Ulysses enabled (ulysses_size > 1):
1. Shard input sequence across ranks: [B, S] -> [B, S/P]
2. Each block's attention does internal all-to-all for full sequence
3. Gather output sequence: [B, S/P] -> [B, S]
When TeaCache is enabled, TeaCacheHook intercepts and replaces this call.
"""
original_shape = hidden_states.shape
B, C, T, H, W = original_shape
pt, ph, pw = self.config.patch_size
# Generate WAN RoPE frequencies
freqs_cos, freqs_sin = self.rope(hidden_states)
# Patchify and flatten: [B, C, T, H, W] -> [B, S, hidden_size]
x = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2)
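# Illustrative shapes (assuming Wan 2.1 1.3B with an 8x-spatial / 4x-temporal VAE): a 480x832x33 request
# gives latents [1, 16, 9, 60, 104]; with patch_size (1, 2, 2) that is S = 9 * 30 * 52 = 14040 tokens of width hidden_size.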
# Shard sequence for Ulysses parallelism: [B, S] -> [B, S/P]
if self.use_ulysses:
seq_len = x.shape[1]
if seq_len % self.ulysses_size != 0:
raise ValueError(
f"Sequence length ({seq_len}) is not divisible by ulysses_size ({self.ulysses_size}). "
f"Adjust video dimensions or use a different ulysses_size."
)
chunk_size = seq_len // self.ulysses_size
x = x[:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size, :]
# Shard RoPE frequencies to match sequence sharding
# RoPE freqs shape: [B, S, ...], so shard along dim 1 (sequence dimension)
if freqs_cos is not None and freqs_sin is not None:
freqs_cos = freqs_cos[
:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
]
freqs_sin = freqs_sin[
:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
]
# Time and text/image embeddings
temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image = (
self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
)
temb_proj = temb_proj.view(-1, 6, self.config.hidden_size)
# I2V: Concatenate image and text embeddings if image embeddings are provided
if encoder_hidden_states_image is not None:
# Handle CFG: duplicate image embeddings if batch dimension is doubled
if encoder_hidden_states_image.shape[0] != encoder_hidden_states.shape[0]:
batch_multiplier = (
encoder_hidden_states.shape[0] // encoder_hidden_states_image.shape[0]
)
encoder_hidden_states_image = encoder_hidden_states_image.repeat(
batch_multiplier, 1, 1
)
encoder_hidden_states = torch.cat(
[encoder_hidden_states_image, encoder_hidden_states], dim=1
)
# Transformer blocks (attention handles all-to-all internally for Ulysses)
for block in self.blocks:
x = block(
x,
encoder_hidden_states,
temb_proj,
freqs_cos,
freqs_sin,
)
# Gather sequence from all ranks: [B, S/P] -> [B, S]
if self.use_ulysses:
# Ensure tensor is contiguous before all_gather
x = x.contiguous()
x_list = [torch.zeros_like(x) for _ in range(self.ulysses_size)]
torch.distributed.all_gather(x_list, x, group=self.ulysses_pg)
x = torch.cat(x_list, dim=1)
# Output projection and unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
x = self.norm_out(x) * (1 + scale) + shift
x = x.to(hidden_states.dtype)
return self.unpatchify(self.proj_out(x), original_shape)
def load_weights(self, weights: dict) -> None:
# Remap checkpoint keys to match model structure
remapped_weights = {}
for key, value in weights.items():
# Remap transformer block FFN keys
if ".ffn.net.0.proj." in key:
new_key = key.replace(".ffn.net.0.proj.", ".ffn.up_proj.")
remapped_weights[new_key] = value
elif ".ffn.net.2." in key:
new_key = key.replace(".ffn.net.2.", ".ffn.down_proj.")
remapped_weights[new_key] = value
# Remap image embedder FF keys
elif ".image_embedder.ff.net.0.proj." in key:
new_key = key.replace(".image_embedder.ff.net.0.proj.", ".image_embedder.ff_in.")
remapped_weights[new_key] = value
elif ".image_embedder.ff.net.2." in key:
new_key = key.replace(".image_embedder.ff.net.2.", ".image_embedder.ff_out.")
remapped_weights[new_key] = value
# Remap I2V attention keys
elif ".attn2.add_k_proj." in key:
new_key = key.replace(".attn2.add_k_proj.", ".add_k_proj.")
remapped_weights[new_key] = value
elif ".attn2.add_v_proj." in key:
new_key = key.replace(".attn2.add_v_proj.", ".add_v_proj.")
remapped_weights[new_key] = value
elif ".attn2.norm_added_k." in key:
new_key = key.replace(".attn2.norm_added_k.", ".norm_added_k.")
remapped_weights[new_key] = value
else:
remapped_weights[key] = value
weights = remapped_weights
# Handle root-level parameters (filter_weights doesn't work for empty prefix)
for param_name, param in self._parameters.items():
if param is not None and param_name in weights:
param.data.copy_(weights[param_name].to(self.model_config.torch_dtype))
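# Map the checkpoint's separate to_q/to_k/to_v tensors onto the model's fused qkv_proj during loading.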
params_map = {
"qkv_proj": ["to_q", "to_k", "to_v"],
}
loader = DynamicLinearWeightLoader(self.model_config, params_map=params_map)
for name, module in tqdm(self.named_modules(), desc="Loading weights"):
if len(module._parameters) == 0:
continue
if isinstance(module, Linear):
weight_dicts = loader.get_linear_weights(module, name, weights)
if weight_dicts:
loader.load_linear_weights(module, name, weight_dicts)
elif "add_k_proj" in name or "add_v_proj" in name:
logger.info(f"[Weight Loading] No weights found for I2V module: {name}")
else:
module_weights = loader.filter_weights(name, weights)
for param_name, param in module._parameters.items():
if param is not None and param_name in module_weights:
param.data.copy_(
module_weights[param_name].to(self.model_config.torch_dtype)
)
def post_load_weights(self) -> None:
"""Call post_load_weights on all Linear modules and convert embedders to target dtype."""
# Convert condition_embedder components to target dtype
target_dtype = self.model_config.torch_dtype
if hasattr(self.condition_embedder, "time_embedder"):
self.condition_embedder.time_embedder.to(target_dtype)
if hasattr(self.condition_embedder, "text_embedder"):
self.condition_embedder.text_embedder.to(target_dtype)
# Call post_load_weights on all Linear modules
for _, module in self.named_modules():
if isinstance(module, Linear):
module.post_load_weights()

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visual Generation Modules
This module provides modular neural network components for visual generation models.
"""
from .attention import Attention, QKVMode
__all__ = [
"Attention",
"QKVMode",
]

View File

@ -0,0 +1,284 @@
from enum import Enum
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
from ...modules.linear import Linear, WeightMode, WeightsLoadingConfig
from ...modules.rms_norm import RMSNorm
from ..attention_backend.interface import AttentionTensorLayout
from ..attention_backend.utils import create_attention
if TYPE_CHECKING:
from ..config import DiffusionModelConfig
class QKVMode(str, Enum):
FUSE_QKV = "fuse_qkv"
FUSE_KV = "fuse_kv"
SEPARATE_QKV = "separate"
# TODO: torch compile
def apply_rotary_emb(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
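# Expected shapes: x [B, S, H, D]; freqs_cos/freqs_sin broadcastable to it (e.g. [1, S, 1, D]) with the
# interleaved [c0, c0, c1, c1] layout, so the 0::2 and 1::2 slices below select the same per-pair angle.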
freqs_cos = freqs_cos.to(x.dtype)
freqs_sin = freqs_sin.to(x.dtype)
x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) # [B, S, H, D/2]
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
class Attention(nn.Module):
"""Attention module for visual generation models."""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: Optional[int] = None,
head_dim: Optional[int] = None,
qkv_mode: QKVMode = QKVMode.FUSE_QKV,
qk_norm: bool = True,
eps: float = 1e-6, # TODO: remove this, we should add this to the config
config: Optional["DiffusionModelConfig"] = None,
layer_idx: Optional[int] = None,
):
super().__init__()
# Fall back to a default config (runtime import; the module-level import is TYPE_CHECKING-only)
if config is None:
from ..config import DiffusionModelConfig
config = DiffusionModelConfig()
self.dtype = config.torch_dtype
self.quant_config = config.quant_config
self.skip_create_weights_in_init = config.skip_create_weights_in_init
self.force_dynamic_quantization = config.force_dynamic_quantization
self.mapping = getattr(config, "mapping", None)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads or num_attention_heads
self.head_dim = head_dim or (hidden_size // num_attention_heads)
self.qkv_mode = QKVMode(qkv_mode) if isinstance(qkv_mode, str) else qkv_mode
# Select compute backend (orthogonal to parallelism)
ulysses_size = config.parallel.dit_ulysses_size
base_backend = config.attention.backend
if self.qkv_mode == QKVMode.SEPARATE_QKV:
backend_name = "VANILLA" # Cross-attention requires VANILLA
else:
backend_name = base_backend
self.attn_backend = backend_name
self.qk_norm = qk_norm
self.layer_idx = layer_idx if layer_idx is not None else 0
self.eps = eps
self.q_dim = self.num_attention_heads * self.head_dim
self.kv_dim = self.num_key_value_heads * self.head_dim
self._init_qkv_proj()
if self.qk_norm:
self.norm_q = RMSNorm(
hidden_size=self.q_dim, eps=self.eps, dtype=self.dtype, has_weights=True
)
self.norm_k = RMSNorm(
hidden_size=self.kv_dim, eps=self.eps, dtype=self.dtype, has_weights=True
)
# TODO: Use weight mapper to create just a Linear module
self.to_out = nn.ModuleList(
[
Linear(
self.q_dim,
self.hidden_size,
dtype=self.dtype,
mapping=self.mapping,
quant_config=self.quant_config,
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
)
]
)
# Compute head counts for the backend
# Ulysses shards heads across workers; inner backend sees sharded count
if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
backend_num_heads = self.num_attention_heads // ulysses_size
backend_num_kv_heads = self.num_key_value_heads // ulysses_size
else:
backend_num_heads = self.num_attention_heads
backend_num_kv_heads = self.num_key_value_heads
# Create compute backend
self.attn = create_attention(
backend=backend_name,
layer_idx=self.layer_idx,
num_heads=backend_num_heads,
head_dim=self.head_dim,
num_kv_heads=backend_num_kv_heads,
quant_config=self.quant_config,
dtype=self.dtype,
)
# Wrap with parallelism strategy (orthogonal to backend choice)
if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
from ..attention_backend.parallel import UlyssesAttention
process_group = getattr(config, "ulysses_process_group", None)
self.attn = UlyssesAttention(
inner_backend=self.attn,
process_group=process_group,
)
def _init_qkv_proj(self) -> None:
if self.qkv_mode == QKVMode.FUSE_QKV:
qkv_out_dim = self.q_dim + 2 * self.kv_dim
self.qkv_proj = Linear(
self.hidden_size,
qkv_out_dim,
dtype=self.dtype,
mapping=self.mapping,
quant_config=self.quant_config,
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR
),
fused_weight_shard_indices_mapping={
"q": (0, self.q_dim),
"k": (self.q_dim, self.kv_dim),
"v": (self.q_dim + self.kv_dim, self.kv_dim),
},
)
else:
self.to_q = Linear(
self.hidden_size,
self.q_dim,
dtype=self.dtype,
mapping=self.mapping,
quant_config=self.quant_config,
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
)
self.to_k = Linear(
self.hidden_size,
self.kv_dim,
dtype=self.dtype,
mapping=self.mapping,
quant_config=self.quant_config,
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
)
self.to_v = Linear(
self.hidden_size,
self.kv_dim,
dtype=self.dtype,
mapping=self.mapping,
quant_config=self.quant_config,
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
)
def get_qkv(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.qkv_mode == QKVMode.FUSE_QKV:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
else:
kv_source = (
encoder_hidden_states if encoder_hidden_states is not None else hidden_states
)
q = self.to_q(hidden_states)
k = self.to_k(kv_source)
v = self.to_v(kv_source)
return q, k, v
def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.qk_norm:
q = self.norm_q(q)
k = self.norm_k(k)
return q, k
def _attn_impl(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: Optional[int] = None,
seq_len: Optional[int] = None,
kv_seq_len: Optional[int] = None,
) -> torch.Tensor:
"""
Call attention backend with appropriate tensor layout.
Two layout paths:
1. HND backends (VANILLA): [B, S, H*D] -> [B, H, S, D]
2. NHD backends (TRTLLM, UlyssesAttention): [B, S, H*D] -> [B, S, H, D]
"""
backend_layout = getattr(self.attn, "preferred_layout", AttentionTensorLayout.NHD)
batch_size = batch_size or q.shape[0]
seq_len = seq_len or q.shape[1]
kv_seq_len = kv_seq_len or k.shape[1]
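# e.g. Wan 2.1 1.3B self-attention (12 heads, head_dim 128): HND gives q [B, 12, S, 128], NHD gives q [B, S, 12, 128].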
# Reshape inputs: [B, S, H*D] -> backend's preferred 4D layout
if backend_layout == AttentionTensorLayout.HND:
q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
else:
q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim)
k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
# Call backend
out = self.attn.forward(
q=q,
k=k,
v=v,
batch_size=batch_size,
seq_len=seq_len,
seq_len_kv=kv_seq_len if kv_seq_len != seq_len else None,
)
# Flatten back to [B, S, H*D]
if backend_layout == AttentionTensorLayout.HND:
return out.transpose(1, 2).flatten(2)
else:
return out.flatten(2)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
assert hidden_states.ndim == 3, "hidden_states must be a 3D tensor"
batch_size, seq_len = hidden_states.shape[:2]
kv_seq_len = (
encoder_hidden_states.shape[1] if encoder_hidden_states is not None else seq_len
)
q, k, v = self.get_qkv(hidden_states, encoder_hidden_states)
q, k = self.apply_qk_norm(q, k)
# Apply RoPE if provided (model handles RoPE, not attention backend)
if freqs is not None:
freqs_cos, freqs_sin = freqs
q = q.view(batch_size, seq_len, self.num_attention_heads, self.head_dim) # [B, S, H, D]
k = k.view(batch_size, kv_seq_len, self.num_key_value_heads, self.head_dim)
q = apply_rotary_emb(q, freqs_cos, freqs_sin)
k = apply_rotary_emb(k, freqs_cos, freqs_sin)
q = q.flatten(2)
k = k.flatten(2)
out = self._attn_impl(q, k, v, batch_size, seq_len, kv_seq_len)
out = self.to_out[0](out)
return out

View File

@ -0,0 +1,29 @@
"""Output dataclass for visual generation models."""
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class MediaOutput:
"""Unified output for all visual generation models.
Different models populate different fields:
- FLUX2: image only
- WAN: video only
- LTX2: video + audio
Attributes:
image: Generated image as torch tensor with shape (height, width, channels) and dtype uint8.
Populated by FLUX2 for text-to-image generation.
video: Generated video frames as torch tensor with shape (num_frames, height, width, channels) and dtype uint8.
Populated by WAN and LTX2 for text-to-video generation.
audio: Generated audio as torch tensor with dtype float32.
Populated by LTX2 for text-to-video-with-audio generation.
"""
image: Optional[torch.Tensor] = None
video: Optional[torch.Tensor] = None
audio: Optional[torch.Tensor] = None
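# Illustrative usage: a WAN pipeline returns MediaOutput(video=frames) with frames of shape
# (num_frames, height, width, channels) and dtype uint8, leaving image and audio as None.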

View File

@ -0,0 +1,100 @@
"""Utilities for distributed parallelism setup in diffusion models."""
from typing import Optional, Tuple
import torch.distributed as dist
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
def setup_sequence_parallelism(
model_config: DiffusionModelConfig,
num_attention_heads: int,
) -> Tuple[bool, int, Optional[dist.ProcessGroup], int]:
"""
Setup sequence parallelism (currently Ulysses only) with CFG support.
Creates nested process groups where each CFG group has its own Ulysses group.
Example with cfg_size=2, ulysses_size=2, world_size=4:
GPU 0-1: CFG group 0, Ulysses group 0
GPU 2-3: CFG group 1, Ulysses group 1
Args:
model_config: Model configuration containing parallel settings
num_attention_heads: Number of attention heads in the model
Returns:
Tuple of (use_parallelism, parallelism_size, parallelism_pg, parallelism_rank):
- use_parallelism: Whether sequence parallelism is enabled
- parallelism_size: The sequence parallelism degree
- parallelism_pg: The process group for this rank (or None)
- parallelism_rank: This rank's position within its parallelism group
Raises:
RuntimeError: If torch.distributed is not initialized
ValueError: If configuration is invalid (incompatible sizes, head count not divisible, etc.)
NotImplementedError: If Ring attention is requested (not yet implemented)
Side Effects:
- Sets model_config.ulysses_process_group to the created process group
Note:
Both num_attention_heads and sequence length must be divisible by ulysses_size.
Head count is validated here; sequence length is validated at runtime during forward pass.
"""
ulysses_size = model_config.parallel.dit_ulysses_size
ring_size = model_config.parallel.dit_ring_size
cfg_size = model_config.parallel.dit_cfg_size
# Check for ring attention (not yet implemented)
if ring_size > 1:
raise NotImplementedError("Ring attention parallelism is not yet implemented")
# Early exit if not using sequence parallelism
if ulysses_size <= 1:
model_config.ulysses_process_group = None
return False, 1, None, 0
# Validate distributed initialization
if not dist.is_initialized():
raise RuntimeError(
"torch.distributed.init_process_group() must be called before "
"setting up sequence parallelism"
)
rank = dist.get_rank()
world_size = dist.get_world_size()
# Validate total parallelism capacity
total_parallel = cfg_size * ulysses_size
if total_parallel > world_size:
raise ValueError(
f"cfg_size ({cfg_size}) * ulysses_size ({ulysses_size}) = "
f"{total_parallel} exceeds world_size ({world_size})"
)
# Validate head count divisibility
if num_attention_heads % ulysses_size != 0:
raise ValueError(
f"num_attention_heads ({num_attention_heads}) must be divisible by "
f"ulysses_size ({ulysses_size})"
)
# Create nested process groups
# Each CFG group has its own Ulysses group
ulysses_pg = None
ulysses_rank = 0
for cfg_id in range(cfg_size):
ulysses_ranks = list(range(cfg_id * ulysses_size, (cfg_id + 1) * ulysses_size))
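# e.g. cfg_size=2, ulysses_size=2: cfg_id 0 -> ranks [0, 1], cfg_id 1 -> ranks [2, 3];
# rank 3 then gets ulysses_rank = 1 within group [2, 3].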
pg = dist.new_group(ulysses_ranks, use_local_synchronization=True)
# Store if this rank belongs to this group
if rank in ulysses_ranks:
ulysses_pg = pg
ulysses_rank = rank - cfg_id * ulysses_size
# Store in config for Attention modules
model_config.ulysses_process_group = ulysses_pg
return True, ulysses_size, ulysses_pg, ulysses_rank

View File

@ -0,0 +1,544 @@
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from .teacache import TeaCacheBackend
if TYPE_CHECKING:
from .config import DiffusionModelConfig
class BasePipeline(nn.Module):
"""
Base class for diffusion pipelines.
"""
def __init__(self, model_config: "DiffusionModelConfig"):
super().__init__()
self.model_config = model_config
self.config = model_config.pretrained_config
self.mapping: Mapping = getattr(model_config, "mapping", None) or Mapping()
# Components
self.transformer: Optional[nn.Module] = None
self.vae: Optional[nn.Module] = None
self.text_encoder: Optional[nn.Module] = None
self.tokenizer: Optional[Any] = None
self.scheduler: Optional[Any] = None
# Initialize transformer
self._init_transformer()
@property
def rank(self):
return dist.get_rank() if dist.is_initialized() else 0
@property
def world_size(self):
return dist.get_world_size() if dist.is_initialized() else 1
@property
def dtype(self):
if hasattr(self, "transformer"):
return next(self.transformer.parameters()).dtype
return torch.float32
@property
def device(self):
return self.transformer.device
def infer(self, req: Any):
raise NotImplementedError
def _init_transformer(self) -> None:
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def load_standard_components(
self,
checkpoint_dir: str,
device: torch.device,
skip_components: Optional[list] = None,
) -> None:
raise NotImplementedError
def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
self.transformer.load_weights(weights)
def post_load_weights(self) -> None:
if self.transformer is not None and hasattr(self.transformer, "post_load_weights"):
self.transformer.post_load_weights()
def _setup_teacache(self, model, coefficients: Optional[Dict] = None):
"""Setup TeaCache optimization for the transformer model.
TeaCache caches transformer block outputs when timestep embeddings change slowly,
reducing computation during the denoising loop.
Args:
model: The transformer model to optimize
coefficients: Optional dict of model-specific polynomial coefficients for cache decisions
Format: {model_size: {"ret_steps": [...], "standard": [...]}}
"""
self.cache_backend = None
# Get teacache config from model_config (always present now)
teacache_cfg = self.model_config.teacache
if not teacache_cfg.enable_teacache:
return
# Apply model-specific polynomial coefficients
# Coefficients are used to rescale embedding distances for cache decisions
if coefficients:
checkpoint_path = (
getattr(self.model_config.pretrained_config, "_name_or_path", "") or ""
)
for model_size, coeff_data in coefficients.items():
# Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev")
if model_size.lower() in checkpoint_path.lower():
if isinstance(coeff_data, dict):
# Select coefficient set based on warmup mode
mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard"
if mode in coeff_data:
teacache_cfg.coefficients = coeff_data[mode]
logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)")
else:
# Single coefficient list (no mode distinction)
teacache_cfg.coefficients = coeff_data
logger.info(f"TeaCache: Using {model_size} coefficients")
break
# Initialize and enable TeaCache backend
logger.info("TeaCache: Initializing...")
self.cache_backend = TeaCacheBackend(teacache_cfg)
self.cache_backend.enable(model)
def decode_latents(
self,
latents: torch.Tensor,
decode_fn: Callable[[torch.Tensor], Any],
extra_latents: Optional[Dict[str, Tuple[torch.Tensor, Callable]]] = None,
):
"""Execute VAE decoding. Only rank 0 performs decoding.
Args:
latents: Primary latents to decode (e.g., video)
decode_fn: Decoder function for primary latents
extra_latents: Optional dict of additional latents to decode.
Format: {name: (latents_tensor, decode_fn)}
Example: {"audio": (audio_latents, audio_decode_fn)}
Returns:
Single result if no extra_latents, tuple of results if extra_latents provided.
Non-rank-0 processes return None placeholders.
"""
if self.rank == 0:
primary_result = decode_fn(latents)
if extra_latents:
extra_results = []
for name, (extra_latent, extra_decode_fn) in extra_latents.items():
extra_results.append(extra_decode_fn(extra_latent))
return (primary_result,) + tuple(extra_results)
return primary_result
# Return None placeholders for non-rank-0 processes
n_results = 1 + (len(extra_latents) if extra_latents else 0)
return (None,) * n_results if n_results > 1 else None
@staticmethod
def _rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""Rescale noise to fix overexposure (https://huggingface.co/papers/2305.08891)."""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
def _setup_cfg_config(
self, guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors=None
):
"""Setup CFG parallel configuration.
Args:
guidance_scale: CFG guidance scale
prompt_embeds: Positive prompt embeddings
neg_prompt_embeds: Negative prompt embeddings (None if already concatenated)
extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel.
Format: {name: (positive_tensor, negative_tensor)}
Example: {"audio_embeds": (pos_audio, neg_audio),
"attention_mask": (pos_mask, neg_mask)}
Returns:
Dict with CFG configuration including split tensors
"""
# Access parallel config directly (always present now)
cfg_size = self.model_config.parallel.dit_cfg_size
ulysses_size = self.model_config.parallel.dit_ulysses_size
cfg_group = self.rank // ulysses_size
is_split_embeds = neg_prompt_embeds is not None
do_cfg_parallel = cfg_size >= 2 and guidance_scale > 1.0
local_extras = {}
if do_cfg_parallel:
if self.rank == 0:
logger.info(f"CFG Parallel: cfg_size={cfg_size}, ulysses_size={ulysses_size}")
# Split main embeddings
if is_split_embeds:
pos_embeds, neg_embeds = prompt_embeds, neg_prompt_embeds
else:
neg_embeds, pos_embeds = prompt_embeds.chunk(2)
local_embeds = pos_embeds if cfg_group == 0 else neg_embeds
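# e.g. cfg_size=2, ulysses_size=2: ranks 0-1 (cfg_group 0) run the positive branch, ranks 2-3 the negative one.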
# Split extra tensors if provided
if extra_cfg_tensors:
for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
if pos_tensor is not None and neg_tensor is not None:
local_extras[name] = pos_tensor if cfg_group == 0 else neg_tensor
elif pos_tensor is not None:
# Only positive provided, use it for both
local_extras[name] = pos_tensor
else:
local_embeds = None
if is_split_embeds and guidance_scale > 1.0:
prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds])
# For standard CFG, concatenate extra tensors
if extra_cfg_tensors:
for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
if pos_tensor is not None and neg_tensor is not None and guidance_scale > 1.0:
local_extras[name] = torch.cat([neg_tensor, pos_tensor], dim=0)
elif pos_tensor is not None:
local_extras[name] = pos_tensor
return {
"enabled": do_cfg_parallel,
"cfg_size": cfg_size,
"ulysses_size": ulysses_size,
"cfg_group": cfg_group,
"local_embeds": local_embeds,
"prompt_embeds": prompt_embeds,
"local_extras": local_extras,
}
def _denoise_step_cfg_parallel(
self,
latents,
extra_stream_latents,
timestep,
local_embeds,
forward_fn,
guidance_scale,
guidance_rescale,
ulysses_size,
local_extras,
):
"""Execute single denoising step with CFG parallel."""
t_start = time.time()
result = forward_fn(latents, extra_stream_latents, timestep, local_embeds, local_extras)
# Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
noise_pred_local, extra_noise_locals = result
else:
noise_pred_local = result
extra_noise_locals = {}
t_transformer = time.time() - t_start
c_start = time.time()
# All-gather primary noise
gather_list = [torch.empty_like(noise_pred_local) for _ in range(self.world_size)]
dist.all_gather(gather_list, noise_pred_local)
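# After the transformer's internal Ulysses gather every rank holds a full-sequence prediction:
# index 0 is the conditional result (CFG group 0), index ulysses_size the unconditional one (CFG group 1).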
noise_cond = gather_list[0]
noise_uncond = gather_list[ulysses_size]
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
# All-gather extra stream noises
extra_noise_preds = {}
for name, noise_local in extra_noise_locals.items():
gather_list_extra = [torch.empty_like(noise_local) for _ in range(self.world_size)]
dist.all_gather(gather_list_extra, noise_local)
noise_cond_extra = gather_list_extra[0]
noise_uncond_extra = gather_list_extra[ulysses_size]
extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
noise_cond_extra - noise_uncond_extra
)
if guidance_rescale > 0.0:
extra_noise_preds[name] = self._rescale_noise_cfg(
extra_noise_preds[name], noise_cond_extra, guidance_rescale
)
if guidance_rescale > 0.0:
noise_pred = self._rescale_noise_cfg(noise_pred, noise_cond, guidance_rescale)
t_cfg = time.time() - c_start
return noise_pred, extra_noise_preds, t_transformer, t_cfg
def _denoise_step_standard(
self,
latents,
extra_stream_latents,
timestep,
prompt_embeds,
forward_fn,
guidance_scale,
guidance_rescale,
local_extras,
):
"""Execute single denoising step without CFG parallel."""
if guidance_scale > 1.0:
latent_input = torch.cat([latents] * 2)
# Duplicate extra stream latents for CFG
extra_stream_input = {
name: torch.cat([stream_latents] * 2)
for name, stream_latents in extra_stream_latents.items()
}
else:
latent_input = latents
extra_stream_input = extra_stream_latents
timestep_expanded = timestep.expand(latent_input.shape[0])
t_start = time.time()
result = forward_fn(
latent_input, extra_stream_input, timestep_expanded, prompt_embeds, local_extras
)
# Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
noise_pred, extra_noise_preds = result
else:
noise_pred = result
extra_noise_preds = {}
t_transformer = time.time() - t_start
c_start = time.time()
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Apply CFG to extra streams
for name, noise_extra in extra_noise_preds.items():
noise_uncond_extra, noise_text_extra = noise_extra.chunk(2)
extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
noise_text_extra - noise_uncond_extra
)
if guidance_rescale > 0.0:
extra_noise_preds[name] = self._rescale_noise_cfg(
extra_noise_preds[name], noise_text_extra, guidance_rescale
)
if guidance_rescale > 0.0:
noise_pred = self._rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
t_cfg = time.time() - c_start
else:
t_cfg = 0.0
return noise_pred, extra_noise_preds, t_transformer, t_cfg
def _scheduler_step(
self,
latents,
extra_stream_latents,
noise_pred,
extra_noise_preds,
timestep,
scheduler,
extra_stream_schedulers,
):
"""Execute scheduler step for all streams."""
t_start = time.time()
latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
# Step schedulers for extra streams
for name, noise_extra in extra_noise_preds.items():
if name in extra_stream_schedulers:
extra_stream_latents[name] = extra_stream_schedulers[name].step(
noise_extra, timestep, extra_stream_latents[name], return_dict=False
)[0]
t_sched = time.time() - t_start
return latents, extra_stream_latents, t_sched
def denoise(
self,
latents: torch.Tensor,
scheduler: Any,
prompt_embeds: torch.Tensor,
guidance_scale: float,
forward_fn: Callable,
timesteps: Optional[torch.Tensor] = None,
neg_prompt_embeds: Optional[torch.Tensor] = None,
guidance_rescale: float = 0.0,
extra_cfg_tensors: Optional[Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
extra_streams: Optional[Dict[str, Tuple[torch.Tensor, Any]]] = None,
guidance_scale_2: Optional[float] = None,
boundary_timestep: Optional[float] = None,
):
"""Execute denoising loop with optional CFG parallel and TeaCache support.
Args:
latents: Initial noise latents (primary stream, e.g., video)
scheduler: Diffusion scheduler for primary stream
prompt_embeds: Text embeddings (positive)
guidance_scale: CFG strength (1.0 = no guidance)
forward_fn: Transformer forward function
Signature: forward_fn(latents, extra_stream_latents, timestep,
encoder_hidden_states, extra_tensors_dict)
Returns: (primary_noise, extra_stream_noises_dict) or just primary_noise
timesteps: Optional custom timesteps (defaults to scheduler.timesteps)
neg_prompt_embeds: Optional negative text embeddings for CFG
guidance_rescale: CFG rescale factor to prevent overexposure
extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel
Format: {name: (positive_tensor, negative_tensor)}
Example: {"audio_embeds": (pos_audio, neg_audio)}
extra_streams: Optional dict of additional streams to denoise in parallel
Format: {name: (stream_latents, stream_scheduler)}
Example: {"audio": (audio_latents, audio_scheduler)}
guidance_scale_2: Optional guidance scale for two-stage denoising.
When provided with boundary_timestep, switches from guidance_scale
to guidance_scale_2 when timestep < boundary_timestep.
boundary_timestep: Optional timestep boundary for two-stage denoising.
Switches guidance scale when crossing this threshold.
Returns:
Single latents if no extra_streams
Tuple (primary_latents, extra_streams_dict) if extra_streams provided
"""
if timesteps is None:
timesteps = scheduler.timesteps
total_steps = len(timesteps)
has_extra_streams = extra_streams is not None and len(extra_streams) > 0
# Reset TeaCache state for new generation
# Sets warmup/cutoff steps based on total_steps
if (
hasattr(self, "cache_backend")
and self.cache_backend
and self.cache_backend.is_enabled()
):
self.cache_backend.refresh(total_steps)
if self.rank == 0:
if has_extra_streams:
stream_names = ", ".join(["primary"] + list(extra_streams.keys()))
logger.info(
f"Denoising [{stream_names}]: {total_steps} steps, guidance={guidance_scale}"
)
else:
logger.info(f"Denoising: {total_steps} steps, guidance={guidance_scale}")
cfg_config = self._setup_cfg_config(
guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors
)
do_cfg_parallel = cfg_config["enabled"]
prompt_embeds = cfg_config["prompt_embeds"]
local_extras = cfg_config["local_extras"]
# Extract extra stream latents and schedulers
extra_stream_latents = {}
extra_stream_schedulers = {}
if extra_streams:
for name, (stream_latents, stream_scheduler) in extra_streams.items():
extra_stream_latents[name] = stream_latents
extra_stream_schedulers[name] = stream_scheduler
start_time = time.time()
for i, t in enumerate(timesteps):
step_start = time.time()
# Two-stage denoising: switch guidance scale at boundary
current_guidance_scale = guidance_scale
if guidance_scale_2 is not None and boundary_timestep is not None:
t_scalar = t.item() if t.dim() == 0 else t[0].item()
if t_scalar < boundary_timestep:
current_guidance_scale = guidance_scale_2
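# Illustration: two-expert checkpoints such as Wan 2.2 (transformer_2 covers low-noise steps) switch to
# guidance_scale_2 once t falls below boundary_timestep.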
# Denoise
if do_cfg_parallel:
timestep = t.expand(latents.shape[0])
noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_cfg_parallel(
latents,
extra_stream_latents,
timestep,
cfg_config["local_embeds"],
forward_fn,
current_guidance_scale,
guidance_rescale,
cfg_config["ulysses_size"],
local_extras,
)
else:
noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_standard(
latents,
extra_stream_latents,
t,
prompt_embeds,
forward_fn,
current_guidance_scale,
guidance_rescale,
local_extras,
)
# Scheduler step for all streams
latents, extra_stream_latents, t_sched = self._scheduler_step(
latents,
extra_stream_latents,
noise_pred,
extra_noise_preds,
t,
scheduler,
extra_stream_schedulers,
)
# Logging
if self.rank == 0:
step_time = time.time() - step_start
avg_time = (time.time() - start_time) / (i + 1)
eta = avg_time * (total_steps - i - 1)
logger.info(
f"Step {i + 1}/{total_steps} | {step_time:.2f}s "
f"(trans={t_trans:.2f}s cfg={t_cfg:.3f}s sched={t_sched:.3f}s) | "
f"Avg={avg_time:.2f}s/step ETA={eta:.1f}s"
)
if self.rank == 0:
total_time = time.time() - start_time
logger.info("=" * 80)
logger.info(f"Denoising done: {total_time:.2f}s ({total_time / total_steps:.2f}s/step)")
# Log TeaCache performance statistics
# Shows how many transformer steps were skipped (cache hits) vs computed
if (
hasattr(self, "cache_backend")
and self.cache_backend
and self.cache_backend.is_enabled()
):
stats = self.cache_backend.get_stats()
if stats:
logger.info(
f"TeaCache: {stats['hit_rate']:.1%} hit rate ({stats['cached']}/{stats['total']} steps)"
)
return (latents, extra_stream_latents) if has_extra_streams else latents
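As an aside, the two-stage guidance switch above is easy to misread, so here is a minimal sketch of just that branch; the timestep and guidance values are made up for illustration rather than taken from any real scheduler.

```python
# Minimal sketch of the two-stage guidance switch used in the denoising loop above.
# Values are illustrative only.
import torch

timesteps = torch.tensor([1000.0, 800.0, 600.0, 400.0, 200.0])
guidance_scale, guidance_scale_2, boundary_timestep = 5.0, 3.0, 500.0

for i, t in enumerate(timesteps):
    t_scalar = t.item() if t.dim() == 0 else t[0].item()
    current_guidance_scale = guidance_scale
    # Below the boundary timestep, the loop switches to the second guidance scale.
    if guidance_scale_2 is not None and boundary_timestep is not None and t_scalar < boundary_timestep:
        current_guidance_scale = guidance_scale_2
    print(f"step {i}: t={t_scalar:.0f} -> guidance={current_guidance_scale}")
```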

View File

@ -0,0 +1,228 @@
"""
Model loader for diffusion pipelines.
Flow:
1. Load config via DiffusionModelConfig.from_pretrained()
2. Create pipeline via AutoPipeline.from_config() with MetaInit
3. Load weights with on-the-fly quantization if dynamic_weight_quant=True
4. Call pipeline.post_load_weights()
Dynamic Quantization:
- If quant_config specifies FP8/NVFP4 and dynamic_weight_quant=True:
- Model Linear layers are created with FP8/NVFP4 buffers
- BF16 checkpoint weights are quantized on-the-fly during loading
- Quantized weights are copied into model buffers
"""
import os
from typing import TYPE_CHECKING, Optional
import torch
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
from tensorrt_llm.llmapi.utils import download_hf_model
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from .checkpoints import WeightLoader
from .config import DiffusionArgs, DiffusionModelConfig, PipelineComponent
from .models import AutoPipeline
if TYPE_CHECKING:
from .models import BasePipeline
class PipelineLoader:
"""
Loader for diffusion pipelines.
Supports dynamic quantization: when quant_config specifies FP8/NVFP4,
model is built with quantized buffers and BF16 weights are quantized
on-the-fly during loading.
Example:
args = DiffusionArgs(
checkpoint_path="/path/to/model",
linear=LinearConfig(type="trtllm-fp8-blockwise"),
parallel=ParallelConfig(dit_tp_size=2),
)
pipeline = PipelineLoader(args).load()
"""
def __init__(
self,
args: Optional[DiffusionArgs] = None,
*,
mapping: Optional[Mapping] = None,
device: str = "cuda",
):
"""
Initialize model loader.
Args:
args: DiffusionArgs containing all configuration (preferred)
mapping: Tensor parallel mapping (fallback if args is None)
device: Device to load model on (fallback if args is None)
"""
self.args = args
if args is not None:
self.mapping = args.to_mapping()
self.device = torch.device(args.device)
else:
self.mapping = mapping or Mapping()
self.device = torch.device(device)
def _resolve_checkpoint_dir(self, checkpoint_dir: str) -> str:
"""Resolve checkpoint_dir to a local directory path.
If checkpoint_dir is an existing local path, returns it unchanged.
Otherwise, attempts to download from HuggingFace Hub using the
file-lock-protected ``download_hf_model`` utility (safe for
concurrent multi-process access).
Args:
checkpoint_dir: Local path or HuggingFace Hub model ID.
Returns:
Path to local directory containing the model.
Raises:
ValueError: If the path cannot be resolved (invalid repo ID,
authentication failure, offline with no cache, etc.)
"""
if os.path.exists(checkpoint_dir):
return checkpoint_dir
revision = self.args.revision if self.args else None
logger.info(
f"'{checkpoint_dir}' not found locally; "
f"attempting HuggingFace Hub download (revision={revision})"
)
try:
local_dir = download_hf_model(checkpoint_dir, revision=revision)
except Exception as e:
raise ValueError(
f"Could not resolve '{checkpoint_dir}' as a local path or "
f"HuggingFace Hub model ID: {e}"
) from e
return str(local_dir)
def load(
self,
checkpoint_dir: Optional[str] = None,
) -> "BasePipeline":
"""
Load a diffusion pipeline with optional dynamic quantization.
Flow:
1. Resolve checkpoint_dir (local path or HuggingFace Hub model ID)
2. Load config via DiffusionModelConfig.from_pretrained()
3. Create pipeline via AutoPipeline.from_config() with MetaInit
4. Load transformer weights via pipeline.load_weights()
5. Load auxiliary components (VAE, text_encoder) via diffusers
6. Call pipeline.post_load_weights()
Args:
checkpoint_dir: Local path or HF Hub model ID (uses args.checkpoint_path if not provided)
Returns:
Loaded pipeline (WanPipeline, FluxPipeline, etc.) - type auto-detected
"""
# Resolve checkpoint_dir
checkpoint_dir = checkpoint_dir or (self.args.checkpoint_path if self.args else None)
if not checkpoint_dir:
raise ValueError("checkpoint_dir must be provided or set in DiffusionArgs")
checkpoint_dir = self._resolve_checkpoint_dir(str(checkpoint_dir))
# Get loading options from args
skip_components = self.args.skip_components if self.args else []
# =====================================================================
# STEP 1: Load Config (includes quant config parsing)
# Merge pretrained checkpoint config with user-provided DiffusionArgs
# =====================================================================
logger.info(f"Loading config from {checkpoint_dir}")
config = DiffusionModelConfig.from_pretrained(
checkpoint_dir,
args=self.args,
mapping=self.mapping,
)
# Log quantization settings
if config.quant_config and config.quant_config.quant_algo:
logger.info(f"Quantization: {config.quant_config.quant_algo.name}")
logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}")
# =====================================================================
# STEP 2: Create Pipeline with MetaInit
# Pipeline type is auto-detected from model_index.json
# - Meta tensors (no GPU memory until materialization)
# - If quant_config specifies FP8, Linear layers have FP8 weight buffers
# =====================================================================
logger.info("Creating pipeline with MetaInitMode")
with MetaInitMode():
pipeline = AutoPipeline.from_config(config, checkpoint_dir)
# Convert meta tensors to CUDA tensors
self._materialize_meta_tensors(pipeline)
pipeline.to(self.device)
# =====================================================================
# STEP 3: Load Transformer Weights
# If dynamic_weight_quant=True:
# - BF16 checkpoint weights are loaded
# - Quantized on-the-fly to FP8/NVFP4 by DynamicLinearWeightLoader
# - Copied into model's quantized buffers
# =====================================================================
if pipeline.transformer is None:
raise ValueError("Pipeline has no transformer component")
transformer_components = getattr(pipeline, "transformer_components", ["transformer"])
logger.info(f"Transformer components: {transformer_components}")
transformer_path = os.path.join(checkpoint_dir, PipelineComponent.TRANSFORMER)
if not os.path.exists(transformer_path):
raise FileNotFoundError(
f"Transformer path does not exist: {transformer_path}. "
f"Checkpoint directory must contain a 'transformer' subdirectory."
)
weight_loader = WeightLoader(components=transformer_components)
# TODO: accelerate the cpu loading w/ multiprocessing
weights = weight_loader.load_weights(checkpoint_dir, self.mapping)
# Load weights into pipeline
pipeline.load_weights(weights)
# =====================================================================
# STEP 4: Load Standard Components (VAE, TextEncoder via diffusers)
# These are NOT quantized - loaded as-is from checkpoint
# =====================================================================
pipeline.load_standard_components(checkpoint_dir, self.device, skip_components)
# =====================================================================
# STEP 5: Post-load Hooks (TeaCache setup, etc.)
# =====================================================================
if hasattr(pipeline, "post_load_weights"):
pipeline.post_load_weights()
logger.info(f"Pipeline loaded: {pipeline.__class__.__name__}")
return pipeline
def _materialize_meta_tensors(self, module: torch.nn.Module) -> None:
"""
Convert meta tensors to CUDA tensors.
Meta tensors are placeholders that don't allocate GPU memory.
After model structure is defined, we materialize them to real tensors.
"""
memo = {}
def init_meta_tensor(t: torch.Tensor) -> torch.Tensor:
if t.device != torch.device("meta"):
return t
if t not in memo:
memo[t] = torch.empty_like(t, device="cuda")
return memo[t]
module._apply(init_meta_tensor)
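For reference, a minimal end-to-end sketch of using this loader, mirroring the class docstring above. The `PipelineLoader` import path and the checkpoint path are assumptions, not taken from this diff.

```python
# Minimal sketch mirroring the PipelineLoader docstring example above.
# The loader import path below is an assumption about the package layout.
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs
from tensorrt_llm._torch.visual_gen.loader import PipelineLoader  # path assumed

args = DiffusionArgs(
    checkpoint_path="/path/to/Wan2.1-T2V-1.3B-Diffusers",  # local dir or HF Hub ID
)
# load() resolves the checkpoint, builds the pipeline under MetaInitMode,
# loads the (optionally dynamically quantized) transformer weights, then the
# standard components (VAE, text encoder), and finally runs post_load_weights().
pipeline = PipelineLoader(args).load()
```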

View File

@ -0,0 +1,94 @@
"""Pipeline registry for unified config flow.
Follows: DiffusionArgs -> PipelineLoader -> DiffusionModelConfig -> AutoPipeline -> BasePipeline
All pipelines (Wan, Flux2, LTX2) register via @register_pipeline decorator.
"""
import json
import os
from typing import TYPE_CHECKING, Dict, Type
from tensorrt_llm.logger import logger
if TYPE_CHECKING:
from .config import DiffusionModelConfig
from .pipeline import BasePipeline
# Global registry: pipeline_name -> pipeline_class
PIPELINE_REGISTRY: Dict[str, Type["BasePipeline"]] = {}
def register_pipeline(name: str):
"""Register a pipeline class for AutoPipeline.
Usage:
@register_pipeline("WanPipeline")
class WanPipeline(BasePipeline):
...
"""
def decorator(cls: Type["BasePipeline"]) -> Type["BasePipeline"]:
PIPELINE_REGISTRY[name] = cls
logger.debug(f"Registered pipeline: {name} -> {cls.__name__}")
return cls
return decorator
class AutoPipeline:
"""Factory for creating pipelines from config."""
@staticmethod
def from_config(
config: "DiffusionModelConfig",
checkpoint_dir: str,
) -> "BasePipeline":
"""
Create pipeline instance from DiffusionModelConfig.
"""
# Detect pipeline type from model_index.json
pipeline_type = AutoPipeline._detect_from_checkpoint(checkpoint_dir)
if pipeline_type not in PIPELINE_REGISTRY:
raise ValueError(
f"Unknown pipeline: '{pipeline_type}'. "
f"Available: {list(PIPELINE_REGISTRY.keys())}\n"
f"Checkpoint: {checkpoint_dir}"
)
pipeline_class = PIPELINE_REGISTRY[pipeline_type]
logger.info(f"AutoPipeline: Creating {pipeline_class.__name__} from {checkpoint_dir}")
# Instantiate pipeline with DiffusionModelConfig
return pipeline_class(config)
@staticmethod
def _detect_from_checkpoint(checkpoint_dir: str) -> str:
"""Detect pipeline type."""
index_path = os.path.join(checkpoint_dir, "model_index.json")
if os.path.exists(index_path):
with open(index_path) as f:
index = json.load(f)
class_name = index.get("_class_name", "")
if class_name in PIPELINE_REGISTRY:
return class_name
if "ImageToVideo" in class_name or "I2V" in class_name:
if "Wan" in class_name:
return "WanImageToVideoPipeline"
# Generic Wan (T2V)
if "Wan" in class_name:
return "WanPipeline"
if "Flux" in class_name:
return "FluxPipeline"
if "LTX" in class_name or "Ltx" in class_name:
return "LTX2Pipeline"
raise ValueError(
f"Cannot detect pipeline type for {checkpoint_dir}\n"
f"Expected model_index.json with '_class_name' field at: {index_path}"
)
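A small sketch of how the detection heuristic resolves a checkpoint; the import path is assumed from the `from .models import AutoPipeline` usage elsewhere in this change, and the checkpoint is fabricated in a temporary directory. `_detect_from_checkpoint` is called directly only for illustration.

```python
# Sketch of pipeline-type detection from model_index.json.
import json
import os
import tempfile

from tensorrt_llm._torch.visual_gen.models import AutoPipeline  # path assumed

with tempfile.TemporaryDirectory() as ckpt_dir:
    # A diffusers-style checkpoint advertises its pipeline class in model_index.json.
    with open(os.path.join(ckpt_dir, "model_index.json"), "w") as f:
        json.dump({"_class_name": "WanPipeline", "_diffusers_version": "0.31.0"}, f)

    # Registered names resolve directly; otherwise the Wan/Flux/LTX heuristics apply.
    print(AutoPipeline._detect_from_checkpoint(ckpt_dir))  # -> "WanPipeline"
```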

View File

@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Quantization support for diffusion models.
"""
from .loader import DynamicLinearWeightLoader
from .ops import quantize_fp8_blockwise, quantize_fp8_per_tensor
__all__ = [
"DynamicLinearWeightLoader",
"quantize_fp8_per_tensor",
"quantize_fp8_blockwise",
]

View File

@ -0,0 +1,197 @@
"""
Dynamic weight quantization loader for Linear modules.
Wraps Linear.load_weights() to perform dynamic quantization before loading to device.
"""
from typing import Dict, List, Optional
import torch
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
from tensorrt_llm._torch.visual_gen.quantization.ops import (
quantize_fp8_blockwise,
quantize_fp8_per_tensor,
)
from tensorrt_llm.quantization.mode import QuantAlgo
class DynamicLinearWeightLoader:
"""
Dynamic weight quantization loader for Linear modules.
Wraps Linear.load_weights() to perform dynamic (load-time) quantization
from BF16/FP16 to FP8 before loading weights to device.
Example:
params_map = {'qkv_proj': ['to_q', 'to_k', 'to_v']}
loader = DynamicLinearWeightLoader(model_config, params_map=params_map)
for name, module in model.named_modules():
if isinstance(module, Linear):
weight_dicts = loader.get_linear_weights(module, name, weights)
loader.load_linear_weights(module, name, weight_dicts)
"""
def __init__(
self,
model_config: DiffusionModelConfig,
params_map: Optional[Dict[str, List[str]]] = None,
):
self.model_config = model_config
self.quant_config = model_config.quant_config
self.quant_config_dict = model_config.quant_config_dict
self.dynamic_weight_quant = model_config.dynamic_weight_quant
self.params_map = params_map or {}
# =========================================================================
# Weight gathering methods
# =========================================================================
def get_linear_weights(
self,
module: Linear,
full_name: str,
weights: Dict[str, torch.Tensor],
) -> List[Dict[str, torch.Tensor]]:
"""Get weights for a Linear module, auto-detecting fused weights."""
weights_config = getattr(module, "weights_loading_config", None)
if weights_config is not None:
weight_mode = getattr(weights_config, "weight_mode", None)
if weight_mode == WeightMode.FUSED_QKV_LINEAR:
fused_names = self._get_fused_names(full_name)
return self._get_fused_weights(full_name, weights, fused_names)
return self._get_vanilla_weights(full_name, weights)
def filter_weights(
self, prefix: str, weights: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Filter weights by prefix and strip the prefix.
Example:
prefix = 'blocks.0.attn1.to_q'
weights = {'blocks.0.attn1.to_q.weight': ..., 'blocks.0.attn1.to_q.bias': ...}
Returns: {'weight': ..., 'bias': ...}
"""
result = {}
prefix_dot = prefix + "."
for k, v in weights.items():
if k.startswith(prefix_dot):
result[k[len(prefix_dot) :]] = v
return result
def _get_fused_names(self, full_name: str) -> List[str]:
"""Get checkpoint names for a fused module from params_map."""
for suffix, names in self.params_map.items():
if full_name.endswith(suffix):
return names
raise ValueError(
f"No params_map entry for fused module '{full_name}'. "
f"Add mapping like {{'qkv_proj': ['to_q', 'to_k', 'to_v']}} to params_map."
)
def _get_fused_weights(
self,
full_name: str,
weights: Dict[str, torch.Tensor],
fused_names: List[str],
) -> List[Dict[str, torch.Tensor]]:
"""Get weights for a fused module from checkpoint."""
parent_path = ".".join(full_name.split(".")[:-1])
module_weights = []
for ckpt_name in fused_names:
ckpt_path = f"{parent_path}.{ckpt_name}" if parent_path else ckpt_name
filtered = self.filter_weights(ckpt_path, weights)
module_weights.append(filtered)
return module_weights
def _get_vanilla_weights(
self,
full_name: str,
weights: Dict[str, torch.Tensor],
) -> List[Dict[str, torch.Tensor]]:
"""Get weights for a standard (non-fused) Linear module."""
fw = self.filter_weights(full_name, weights)
return [fw] if fw else []
# =========================================================================
# Quantization methods
# =========================================================================
def _get_quant_algo_for_layer(self, name: str) -> Optional[QuantAlgo]:
"""Get quantization algorithm for a specific layer."""
if self.quant_config_dict is not None:
layer_config = self.quant_config_dict.get(name)
if layer_config is not None:
return layer_config.quant_algo
if self.quant_config is not None:
return self.quant_config.quant_algo
return None
def _should_dynamic_quantize(
self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
) -> bool:
"""Decide if weight should be dynamically quantized at load time."""
if not self.dynamic_weight_quant or quant_algo is None:
return False
# Check if module is excluded
if self.quant_config is not None:
if self.quant_config.is_module_excluded_from_quantization(name):
return False
weight = weight_dict.get("weight")
if weight is None:
return False
# For FP8 algorithms: quantize if weight is high precision
if quant_algo in (QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES):
if weight.dtype == torch.float8_e4m3fn and "weight_scale" in weight_dict:
return False # Already quantized
return weight.dtype in (torch.bfloat16, torch.float16, torch.float32)
return False
def _maybe_dynamic_quantize(
self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
) -> Dict[str, torch.Tensor]:
"""Conditionally quantize weight at load time on GPU."""
if not self._should_dynamic_quantize(weight_dict, quant_algo, name):
return weight_dict
weight = weight_dict["weight"]
# Move to GPU only if needed
if weight.device.type != "cuda":
weight = weight.cuda()
if quant_algo == QuantAlgo.FP8:
qweight, scale = quantize_fp8_per_tensor(weight)
elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
block_size = self.quant_config.group_size if self.quant_config else 128
qweight, scale = quantize_fp8_blockwise(weight, block_size=block_size)
else:
return weight_dict
return {**weight_dict, "weight": qweight, "weight_scale": scale}
def load_linear_weights(
self, module: Linear, name: str, weight_dicts: List[Dict[str, torch.Tensor]]
) -> None:
"""Load weights into Linear module with optional quantization."""
module_quant_config = getattr(module, "quant_config", None)
if module_quant_config is not None:
quant_algo = module_quant_config.quant_algo
else:
quant_algo = self._get_quant_algo_for_layer(name)
quantized_weight_dicts = [
self._maybe_dynamic_quantize(wd, quant_algo, name) for wd in weight_dicts
]
module.load_weights(quantized_weight_dicts)
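To make the naming conventions concrete, here is a standalone sketch of the prefix stripping that `filter_weights` performs and of the `params_map` entry a fused `qkv_proj` module would need. Dummy tensors only; no loader instance and no quantization.

```python
# Standalone illustration of the naming conventions DynamicLinearWeightLoader relies on.
import torch

weights = {
    "blocks.0.attn1.to_q.weight": torch.zeros(8, 8),
    "blocks.0.attn1.to_q.bias": torch.zeros(8),
    "blocks.0.attn1.to_k.weight": torch.zeros(8, 8),
}

# filter_weights("blocks.0.attn1.to_q", weights) strips the prefix:
prefix = "blocks.0.attn1.to_q."
print({k[len(prefix):]: v.shape for k, v in weights.items() if k.startswith(prefix)})
# -> {'weight': torch.Size([8, 8]), 'bias': torch.Size([8])}

# For a fused module "blocks.0.attn1.qkv_proj", params_map tells the loader which
# checkpoint names to gather for that module:
params_map = {"qkv_proj": ["to_q", "to_k", "to_v"]}
```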

View File

@ -0,0 +1,98 @@
"""
Quantization operations for diffusion models.
Provides on-the-fly quantization functions for dynamic (load-time) quantization.
"""
from typing import Tuple
import torch
# FP8 E4M3 max value
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def quantize_fp8_per_tensor(weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize weight to FP8 E4M3 with per-tensor scale.
Uses torch.ops.tensorrt_llm.quantize_e4m3_per_tensor CUDA kernel.
Args:
weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)
Returns:
Tuple of:
- qweight: Quantized weight (FP8 E4M3), same shape as input
- weight_scale: Dequantization scale (FP32), shape (1, 1)
"""
qweight, scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(weight)
# Ensure scale is float32 and has shape (1, 1) for consistency
return qweight, scale.to(torch.float32)
def quantize_fp8_blockwise(
weight: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize weight to FP8 E4M3 with 128x128 blockwise scales.
This function converts BF16/FP16/FP32 weights to FP8 E4M3 format using
per-block scale factors. The weight is divided into blocks of size
(block_size, block_size) and each block has its own scale.
Args:
weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)
block_size: Block size for blockwise quantization (default: 128)
Returns:
Tuple of:
- qweight: Quantized weight (FP8 E4M3), shape (out_features, in_features)
- block_scales: Block-wise dequantization scales (FP32),
shape (num_blocks_out, num_blocks_in)
Note:
- If dimensions are not divisible by block_size, the last block may be smaller
- block_scales are dequantization scales (multiply to get back original scale)
- This uses 128x128 block scaling compatible with Linear module's FP8_BLOCK_SCALES
"""
out_features, in_features = weight.shape
weight_fp32 = weight.float()
# Calculate number of blocks
num_blocks_out = (out_features + block_size - 1) // block_size
num_blocks_in = (in_features + block_size - 1) // block_size
# Initialize outputs
qweight = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
block_scales = torch.empty(
(num_blocks_out, num_blocks_in), dtype=torch.float32, device=weight.device
)
# Quantize each block
for i in range(num_blocks_out):
row_start = i * block_size
row_end = min((i + 1) * block_size, out_features)
for j in range(num_blocks_in):
col_start = j * block_size
col_end = min((j + 1) * block_size, in_features)
# Extract block
block = weight_fp32[row_start:row_end, col_start:col_end]
# Compute block scale
max_val = block.abs().max()
scale = (
max_val / FP8_E4M3_MAX if max_val > 0 else torch.tensor(1.0, device=weight.device)
)
# Quantize block
inv_scale = scale.reciprocal() if scale > 0 else torch.tensor(1.0, device=weight.device)
qblock = (block * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
# Store results
qweight[row_start:row_end, col_start:col_end] = qblock
block_scales[i, j] = scale.to(torch.float32)
return qweight, block_scales
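A quick sanity check of the blockwise quantizer: dequantize by broadcasting the block scales back to the weight shape and compare against the original. This uses only the function above plus plain torch ops; the error level mentioned in the comment is a rough expectation, not a guarantee.

```python
# Roundtrip check for quantize_fp8_blockwise.
import torch

from tensorrt_llm._torch.visual_gen.quantization.ops import quantize_fp8_blockwise

torch.manual_seed(0)
w = torch.randn(256, 256, dtype=torch.bfloat16)  # divisible by the 128 block size

qweight, block_scales = quantize_fp8_blockwise(w, block_size=128)

# Expand the (2, 2) block scales to (256, 256) and multiply to dequantize.
scales_full = block_scales.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
w_hat = qweight.float() * scales_full

rel_err = (w_hat - w.float()).abs().mean() / w.float().abs().mean()
print(f"mean relative error: {rel_err.item():.4f}")  # typically a few percent for FP8
```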

View File

@ -0,0 +1,409 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from tensorrt_llm.logger import logger
# =============================================================================
# Core Data Structures
# =============================================================================
@dataclass
class CacheContext:
"""Context returned by model extractors for TeaCache.
Attributes:
modulated_input: Timestep embedding used for cache distance calculation
hidden_states: Input hidden states for the transformer
encoder_hidden_states: Text/prompt embeddings
run_transformer_blocks: Callable that executes the transformer forward pass
postprocess: Callable that formats the output to the expected return type
"""
modulated_input: torch.Tensor
hidden_states: torch.Tensor
encoder_hidden_states: Any = None
run_transformer_blocks: Callable = None
postprocess: Callable = None
# =============================================================================
# Extractor Registry
# =============================================================================
_EXTRACTORS = {}
def register_extractor(model_name, extractor_fn):
"""Register an extractor function for a model class."""
_EXTRACTORS[model_name] = extractor_fn
def get_extractor(model_type):
"""Get the registered extractor for a model type."""
if model_type not in _EXTRACTORS:
raise ValueError(
f"TeaCache: Unknown model '{model_type}'. Available: {list(_EXTRACTORS.keys())}"
)
return _EXTRACTORS[model_type]
# =============================================================================
# Config-Based Extractor System
# =============================================================================
@dataclass
class ExtractorConfig:
"""Configuration for model-specific TeaCache extractors.
Only the timestep embedding logic is model-specific; all other logic is handled generically.
Attributes:
model_class_name: Model class name (e.g., "LTX2VideoTransformer3DModel")
timestep_embed_fn: Callable(module, timestep, guidance=None) -> Tensor
timestep_param_name: Parameter name for timestep in forward() (default: "timestep")
guidance_param_name: Parameter name for guidance if used (default: None)
forward_params: List of parameter names (None = auto-introspect from forward signature)
return_dict_default: Default value for return_dict parameter (default: True)
output_model_class: Output class name for return type (default: "Transformer2DModelOutput")
"""
model_class_name: str
timestep_embed_fn: Callable
timestep_param_name: str = "timestep"
guidance_param_name: Optional[str] = None
forward_params: Optional[List[str]] = None
return_dict_default: bool = True
output_model_class: str = "Transformer2DModelOutput"
class GenericExtractor:
"""Handles common TeaCache logic for all diffusion models.
Extracts forward() arguments, creates run_blocks and postprocess callbacks,
and delegates only timestep embedding computation to model-specific logic.
"""
def __init__(self, config: ExtractorConfig):
self.config = config
def _extract_forward_args(self, module: torch.nn.Module, *args, **kwargs) -> Dict:
"""Extract and normalize forward() arguments from *args and **kwargs."""
# Get parameter names (auto-introspect or use config)
if self.config.forward_params is not None:
param_names = self.config.forward_params
else:
# Auto-introspect forward signature
try:
sig = inspect.signature(module._original_forward)
param_names = [p for p in sig.parameters if p not in ("self", "args", "kwargs")]
except Exception as e:
logger.warning(f"Could not introspect forward signature: {e}")
param_names = []
# Map positional args to parameter names
extracted = {param_names[i]: arg for i, arg in enumerate(args) if i < len(param_names)}
# Merge kwargs (kwargs take precedence)
extracted.update(kwargs)
return extracted
def _compute_timestep_embedding(self, module: torch.nn.Module, params: Dict) -> torch.Tensor:
"""Compute timestep embedding using configured callable."""
timestep = params.get(self.config.timestep_param_name)
if timestep is None:
raise ValueError(f"Missing required parameter: {self.config.timestep_param_name}")
# Flatten timestep if needed (common pattern)
timestep_flat = timestep.flatten() if timestep.ndim == 2 else timestep
guidance = (
params.get(self.config.guidance_param_name) if self.config.guidance_param_name else None
)
# Call configured timestep embedding function
try:
return self.config.timestep_embed_fn(module, timestep_flat, guidance)
except Exception as e:
logger.error(f"Timestep embedder failed: {e}")
# Last resort: use timestep as-is
logger.warning("Using timestep fallback")
return timestep_flat.unsqueeze(-1) if timestep_flat.ndim == 1 else timestep_flat
def __call__(self, module: torch.nn.Module, *args, **kwargs) -> CacheContext:
"""Main extractor logic - called by TeaCacheHook.
Extracts forward arguments, computes timestep embedding, and creates callbacks
for running the transformer and post-processing the output.
"""
# Extract forward arguments from positional and keyword args
params = self._extract_forward_args(module, *args, **kwargs)
# Compute timestep embedding (used for cache distance calculation)
t_emb = self._compute_timestep_embedding(module, params)
return_dict = params.get("return_dict", self.config.return_dict_default)
def run_blocks():
"""Execute the full transformer forward pass with original parameters."""
ret = module._original_forward(**params)
# Normalize output to tuple format
if return_dict and not isinstance(ret, tuple):
sample = ret.sample if hasattr(ret, "sample") else ret
return (sample,) if not isinstance(sample, tuple) else sample
return ret if isinstance(ret, tuple) else (ret,)
def postprocess(output):
"""Convert cached/computed output back to expected return format."""
if return_dict:
if isinstance(output, tuple):
return output
return Transformer2DModelOutput(sample=output)
# For return_dict=False, unwrap single-element tuple to raw tensor
if isinstance(output, tuple) and len(output) == 1:
return output[0]
# Return raw tensor as-is (TeaCacheHook always passes tensors to postprocess)
return output
return CacheContext(
modulated_input=t_emb,
hidden_states=params.get("hidden_states"),
encoder_hidden_states=params.get("encoder_hidden_states"),
run_transformer_blocks=run_blocks,
postprocess=postprocess,
)
def register_extractor_from_config(config: ExtractorConfig):
"""Register a TeaCache extractor for a model. Call this in pipeline's load() method.
Example:
register_extractor_from_config(ExtractorConfig(
model_class_name="LTX2VideoTransformer3DModel",
timestep_embed_fn=self._compute_ltx2_timestep_embedding,
))
"""
extractor = GenericExtractor(config)
register_extractor(config.model_class_name, extractor)
logger.debug(f"Registered TeaCache extractor for {config.model_class_name}")
# =============================================================================
# TeaCache Runtime (caching hook and lifecycle management)
# =============================================================================
class TeaCacheHook:
"""Caches transformer blocks when timestep embeddings change slowly.
The hook monitors the relative change in timestep embeddings between steps.
When the change is small (below threshold), it reuses the cached residual
from the previous step instead of running the full transformer.
Separate cache states are maintained for conditional and unconditional branches
when using Classifier-Free Guidance (CFG).
"""
def __init__(self, config):
self.config = config
# Polynomial function to rescale embedding distances
self.rescale_func = np.poly1d(config.coefficients)
self.extractor_fn = None
# Separate cache state for conditional (pos) and unconditional (neg) branches
self.state_pos = self._new_state()
self.state_neg = self._new_state()
self.stats = {"total": 0, "cached": 0}
def _new_state(self):
return {"cnt": 0, "acc_dist": 0.0, "prev_input": None, "prev_residual": None}
def initialize(self, module):
self.extractor_fn = get_extractor(module.__class__.__name__)
def reset_state(self):
self.state_pos = self._new_state()
self.state_neg = self._new_state()
self.stats = {"total": 0, "cached": 0}
def get_stats(self):
total = max(self.stats["total"], 1)
cached = self.stats["cached"]
return {
"hit_rate": cached / total,
"total": total,
"cached": cached,
# Backward compatibility
"total_steps": total,
"cached_steps": cached,
"compute_steps": total - cached,
}
def __call__(self, module, *args, **kwargs):
"""Main hook called during transformer forward pass.
Decides whether to run the full transformer or reuse cached residual
based on timestep embedding distance.
"""
# Extract context (timestep embedding, hidden states, callbacks)
ctx = self.extractor_fn(module, *args, **kwargs)
# Select cache state (for CFG: separate tracking for conditional/unconditional)
cache_branch = getattr(module, "_cache_branch", None)
state = self.state_neg if cache_branch == "uncond" else self.state_pos
# Decide: compute transformer or use cache?
should_compute = self._should_compute(state, ctx.modulated_input)
self.stats["total"] += 1
if not should_compute and state["prev_residual"] is not None:
# Cache hit: Add cached residual to skip transformer computation
logger.debug(f"TeaCache: SKIP step {state['cnt']}")
# For I2V: output might have fewer channels than input
# Apply residual only to the latent channels
if ctx.hidden_states.shape[1] != state["prev_residual"].shape[1]:
# Extract latent channels (match output channels)
num_output_channels = state["prev_residual"].shape[1]
latent_channels = ctx.hidden_states[:, :num_output_channels]
output = latent_channels + state["prev_residual"]
else:
output = ctx.hidden_states + state["prev_residual"]
self.stats["cached"] += 1
else:
# Cache miss: Run full transformer and cache the residual
outputs = ctx.run_transformer_blocks()
output = outputs[0] if isinstance(outputs, tuple) else outputs
# Store residual (output - input) for next potential cache hit
# For I2V: output may have fewer channels than input
# Compute residual only on the latent channels
if ctx.hidden_states.shape[1] != output.shape[1]:
# Extract latent channels (match output channels)
num_output_channels = output.shape[1]
latent_channels = ctx.hidden_states[:, :num_output_channels]
state["prev_residual"] = (output - latent_channels).detach()
else:
original = ctx.hidden_states.clone()
state["prev_residual"] = (output - original).detach()
# Update state for next iteration
state["prev_input"] = ctx.modulated_input.detach()
state["cnt"] += 1
return ctx.postprocess(output)
def _should_compute(self, state, modulated_inp):
"""Decide whether to compute transformer or use cached result.
Returns True to compute, False to use cache.
"""
# Warmup: Always compute first few steps to build stable cache
if self.config.ret_steps and state["cnt"] < self.config.ret_steps:
state["acc_dist"] = 0.0
return True
# Cooldown: Always compute last few steps for quality
if self.config.cutoff_steps and state["cnt"] >= self.config.cutoff_steps:
return True
# First step: no previous input to compare
if state["prev_input"] is None:
return True
# Compute relative change in timestep embedding
curr, prev = modulated_inp, state["prev_input"]
# For CFG (batch_size > 1), only compare conditional branch
# Both branches move similarly, so one comparison is sufficient
if modulated_inp.shape[0] > 1:
curr, prev = modulated_inp.chunk(2)[1], prev.chunk(2)[1]
# Calculate relative L1 distance (normalized by magnitude)
rel_dist = ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).cpu().item()
# Apply polynomial rescaling to adjust sensitivity
# Accumulate distance (capped at 2x threshold to prevent overflow)
rescaled = float(self.rescale_func(rel_dist))
state["acc_dist"] = min(
state["acc_dist"] + abs(rescaled), self.config.teacache_thresh * 2.0
)
logger.debug(
f"TeaCache: step {state['cnt']} | dist {rel_dist:.2e} | acc {state['acc_dist']:.4f}"
)
# Cache decision based on accumulated distance
if state["acc_dist"] < self.config.teacache_thresh:
# Below threshold: use cache, apply decay to distance
state["acc_dist"] *= 0.95
return False
else:
# Above threshold: compute, reset accumulated distance
state["acc_dist"] = 0.0
return True
class TeaCacheBackend:
"""Manages TeaCache lifecycle."""
def __init__(self, config):
self.config = config
self.hook = None
def enable(self, module):
if self.hook is None:
logger.info(f"TeaCache: Enabling for {module.__class__.__name__}")
self.hook = TeaCacheHook(self.config)
self.hook.initialize(module)
module._original_forward = module.forward
module.forward = lambda *args, **kwargs: self.hook(module, *args, **kwargs)
def disable(self, module):
if self.hook and hasattr(module, "_original_forward"):
module.forward = module._original_forward
self.hook = None
def refresh(self, num_inference_steps):
"""Reset TeaCache state for a new generation.
Sets warmup/cutoff steps based on total inference steps:
- Warmup steps: Always compute to build stable cache
- Cutoff steps: Always compute for quality at the end
- Middle steps: Use caching based on distance threshold
Args:
num_inference_steps: Total number of denoising steps
"""
if not self.hook:
return
# Reset cache state (clears previous residuals and counters)
self.hook.reset_state()
# Configure warmup and cutoff based on mode
if self.config.use_ret_steps:
# Aggressive warmup: 5 steps to stabilize cache
self.config.ret_steps = 5
self.config.cutoff_steps = num_inference_steps # No cutoff (cache until end)
else:
# Minimal warmup: 1 step
self.config.ret_steps = 1
self.config.cutoff_steps = num_inference_steps - 2 # Compute last 2 steps
self.config.num_steps = num_inference_steps
logger.info(
f"TeaCache: {num_inference_steps} steps | "
f"warmup: {self.config.ret_steps}, cutoff: {self.config.cutoff_steps}, "
f"thresh: {self.config.teacache_thresh}"
)
def is_enabled(self):
return self.hook is not None
def get_stats(self):
return self.hook.get_stats() if self.hook else {}
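To make the skip decision tangible, here is a worked example of the distance math from `_should_compute`, using an identity rescale polynomial and an illustrative threshold; real coefficients are calibrated per model and are not part of this file.

```python
# Worked example of the TeaCache skip decision in _should_compute above.
# Coefficients and threshold are illustrative, not calibrated values.
import numpy as np
import torch

coefficients = [1.0, 0.0]          # poly1d([1, 0])(x) == x, i.e. identity rescale
teacache_thresh = 0.08
rescale_func = np.poly1d(coefficients)

prev = torch.randn(1, 512)
curr = prev + 0.01 * torch.randn(1, 512)   # timestep embedding barely moved

rel_dist = ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()
acc_dist = min(abs(float(rescale_func(rel_dist))), teacache_thresh * 2.0)

if acc_dist < teacache_thresh:
    print(f"dist={rel_dist:.4f} acc={acc_dist:.4f} -> reuse cached residual (skip transformer)")
else:
    print(f"dist={rel_dist:.4f} acc={acc_dist:.4f} -> run transformer, reset accumulator")
```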

View File

@ -0,0 +1,39 @@
"""Utility functions for visual generation pipelines."""
import torch
@torch.compile
def postprocess_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
"""Post-process video tensor from VAE decoder output to final format.
This is a more efficient implementation than using VideoProcessor for single-batch cases,
as it avoids loop overhead and processes the entire batch with vectorized operations.
Args:
video: Video tensor in (B, C, T, H, W) format from VAE decoder
remove_batch_dim: Whether to remove batch dimension. Default True for typical
single-batch video generation.
Returns:
Post-processed video tensor:
- If remove_batch_dim=True: (T, H, W, C) uint8 tensor
- If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor
Note:
Assumes video values are in [-1, 1] range (standard VAE decoder output).
"""
# Convert to (B, T, H, W, C) format
video = video.permute(0, 2, 3, 4, 1) # (B, C, T, H, W) -> (B, T, H, W, C)
# Normalize to [0, 1] range
video = (video / 2 + 0.5).clamp(0, 1)
# Convert to uint8
video = (video * 255).round().to(torch.uint8)
# Remove batch dimension if requested
if remove_batch_dim:
video = video[0] # (B, T, H, W, C) -> (T, H, W, C)
return video
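A usage sketch with a dummy tensor standing in for VAE decoder output; the module import path is an assumption, and the first call compiles the function because of `@torch.compile`.

```python
# Sketch: run the helper on a dummy VAE-decoder-shaped tensor.
import torch

from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor  # path assumed

video = torch.rand(1, 3, 17, 480, 832) * 2 - 1    # (B, C, T, H, W) in [-1, 1]
frames = postprocess_video_tensor(video)          # -> (T, H, W, C) uint8
print(frames.shape, frames.dtype)                 # torch.Size([17, 480, 832, 3]) torch.uint8
```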

View File

@ -19,11 +19,13 @@ from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.commands.utils import (get_is_diffusion_model,
get_visual_gen_model_type)
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, KvCacheConfig,
SchedulerConfig)
SchedulerConfig, VisualGen)
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole,
extract_disagg_cluster_config,
@ -217,7 +219,7 @@ def launch_server(
f"{backend} is not a known backend, check help for available options.",
param_hint="backend")
server = OpenAIServer(llm=llm,
server = OpenAIServer(generator=llm,
model=model,
tool_parser=tool_parser,
server_role=server_role,
@ -361,7 +363,7 @@ def launch_mm_encoder_server(
encoder_args.pop("build_config")
mm_encoder = MultimodalEncoder(**encoder_args)
server = OpenAIServer(llm=mm_encoder,
server = OpenAIServer(generator=mm_encoder,
model=model,
server_role=ServerRole.MM_ENCODER,
metadata_server_cfg=metadata_server_cfg,
@ -369,6 +371,45 @@ def launch_mm_encoder_server(
asyncio.run(server(host, port))
def launch_visual_gen_server(
host: str,
port: int,
visual_gen_config: dict,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
):
"""Launch a VISUAL_GEN model server for image/video generation.
Args:
host: Server hostname.
port: Server port.
visual_gen_config: Arguments for VISUAL_GEN model initialization.
metadata_server_cfg: Optional metadata server configuration.
"""
model = visual_gen_config["model"]
logger.info(f"Initializing VisualGen ({model})")
n_workers = 1
parallel_config = visual_gen_config.get("parallel", {})
if parallel_config:
n_workers = parallel_config.get(
"dit_cfg_size", 1) * parallel_config.get("dit_ulysses_size", 1)
logger.info(f"World size: {n_workers}")
logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}")
logger.info(
f"Ulysses size: {parallel_config.get('dit_ulysses_size', 1)}")
visual_gen_model = VisualGen(model_path=model,
n_workers=n_workers,
diffusion_config=visual_gen_config)
server = OpenAIServer(generator=visual_gen_model,
model=model,
server_role=ServerRole.VISUAL_GEN,
metadata_server_cfg=metadata_server_cfg,
tool_parser=None)
asyncio.run(server(host, port))
class ChoiceWithAlias(click.Choice):
def __init__(self,
@ -600,6 +641,12 @@ class ChoiceWithAlias(click.Choice):
default=False,
help="Run gRPC server instead of OpenAI HTTP server. "
"gRPC server accepts pre-tokenized requests and returns raw token IDs.")
@click.option("--extra_visual_gen_options",
type=str,
default=None,
help=help_info_with_stability_tag(
"Path to a YAML file with extra VISUAL_GEN model options.",
"prototype"))
def serve(
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
@ -616,8 +663,8 @@ def serve(
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
custom_module_dirs: list[Path], chat_template: Optional[str],
grpc: bool):
"""Running an OpenAI API compatible server (or gRPC server with --grpc flag)
grpc: bool, extra_visual_gen_options: Optional[str]):
"""Running an OpenAI API compatible server
MODEL: model name | HF checkpoint path | TensorRT engine path
"""
@ -630,93 +677,120 @@ def serve(
logger.error(
f"Failed to import custom module from {custom_module_dir}: {e}")
raise e
llm_args, _ = get_llm_args(
model=model,
tokenizer=tokenizer,
custom_tokenizer=custom_tokenizer,
backend=backend,
max_beam_width=max_beam_width,
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
context_parallel_size=context_parallel_size,
moe_expert_parallel_size=moe_expert_parallel_size,
moe_cluster_parallel_size=moe_cluster_parallel_size,
gpus_per_node=gpus_per_node,
free_gpu_memory_fraction=free_gpu_memory_fraction,
num_postprocess_workers=num_postprocess_workers,
trust_remote_code=trust_remote_code,
revision=revision,
reasoning_parser=reasoning_parser,
fail_fast_on_attention_window_too_large=
fail_fast_on_attention_window_too_large,
otlp_traces_endpoint=otlp_traces_endpoint,
enable_chunked_prefill=enable_chunked_prefill)
llm_args_extra_dict = {}
if extra_llm_api_options is not None:
with open(extra_llm_api_options, 'r') as f:
llm_args_extra_dict = yaml.safe_load(f)
llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict)
def _serve_llm():
nonlocal server_role
llm_args, _ = get_llm_args(
model=model,
tokenizer=tokenizer,
custom_tokenizer=custom_tokenizer,
backend=backend,
max_beam_width=max_beam_width,
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
context_parallel_size=context_parallel_size,
moe_expert_parallel_size=moe_expert_parallel_size,
moe_cluster_parallel_size=moe_cluster_parallel_size,
gpus_per_node=gpus_per_node,
free_gpu_memory_fraction=free_gpu_memory_fraction,
num_postprocess_workers=num_postprocess_workers,
trust_remote_code=trust_remote_code,
revision=revision,
reasoning_parser=reasoning_parser,
fail_fast_on_attention_window_too_large=
fail_fast_on_attention_window_too_large,
otlp_traces_endpoint=otlp_traces_endpoint,
enable_chunked_prefill=enable_chunked_prefill)
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
llm_args_extra_dict = {}
if extra_llm_api_options is not None:
with open(extra_llm_api_options, 'r') as f:
llm_args_extra_dict = yaml.safe_load(f)
llm_args = update_llm_args_with_extra_dict(llm_args,
llm_args_extra_dict)
# Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
# but disagg_cluster_uri takes precedence over cluster uri in config file
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
if disagg_cluster_config:
disagg_cluster_config = extract_disagg_cluster_config(
disagg_cluster_config, disagg_cluster_uri)
elif disagg_cluster_uri:
disagg_cluster_config = DisaggClusterConfig(
cluster_uri=disagg_cluster_uri)
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
if metadata_server_cfg is not None or disagg_cluster_config is not None:
assert (
server_role is not None
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
try:
server_role = ServerRole[server_role.upper()]
except ValueError:
raise ValueError(f"Invalid server role: {server_role}. " \
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
# Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
# but disagg_cluster_uri takes precedence over cluster uri in config file
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
if disagg_cluster_config:
disagg_cluster_config = extract_disagg_cluster_config(
disagg_cluster_config, disagg_cluster_uri)
elif disagg_cluster_uri:
disagg_cluster_config = DisaggClusterConfig(
cluster_uri=disagg_cluster_uri)
# Parse media_io_kwargs from JSON string to dict if provided
parsed_media_io_kwargs = None
if media_io_kwargs is not None:
try:
parsed_media_io_kwargs = json.loads(media_io_kwargs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
if metadata_server_cfg is not None or disagg_cluster_config is not None:
assert (
server_role is not None
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
try:
server_role = ServerRole[server_role.upper()]
except ValueError:
raise ValueError(f"Invalid server role: {server_role}. " \
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
# Parse media_io_kwargs from JSON string to dict if provided
parsed_media_io_kwargs = None
if media_io_kwargs is not None:
try:
parsed_media_io_kwargs = json.loads(media_io_kwargs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
if grpc:
# gRPC mode: launch gRPC server instead of OpenAI HTTP server
# Check for unsupported arguments that are silently ignored in gRPC mode
unsupported_args = {
"tool_parser": tool_parser,
"chat_template": chat_template,
"metadata_server_config_file": metadata_server_config_file,
"server_role": server_role,
"disagg_cluster_config": disagg_cluster_config,
if grpc:
# gRPC mode: launch gRPC server instead of OpenAI HTTP server
# Check for unsupported arguments that are silently ignored in gRPC mode
unsupported_args = {
"tool_parser": tool_parser,
"chat_template": chat_template,
"metadata_server_config_file": metadata_server_config_file,
"server_role": server_role,
"disagg_cluster_config": disagg_cluster_config,
}
for name, value in unsupported_args.items():
if value is not None:
raise ValueError(
f"Argument '{name}' is not supported when running in gRPC mode. "
f"The gRPC server is designed for use with external routers that handle "
f"these features (e.g., tool parsing, chat templates).")
launch_grpc_server(host, port, llm_args)
else:
# Default: launch OpenAI HTTP server
launch_server(host, port, llm_args, tool_parser, chat_template,
metadata_server_cfg, server_role,
disagg_cluster_config, multimodal_server_config)
def _serve_visual_gen():
visual_gen_config = {
"model": model,
"model_type": get_visual_gen_model_type(model),
}
for name, value in unsupported_args.items():
if value is not None:
raise ValueError(
f"Argument '{name}' is not supported when running in gRPC mode. "
f"The gRPC server is designed for use with external routers that handle "
f"these features (e.g., tool parsing, chat templates).")
launch_grpc_server(host, port, llm_args)
visual_gen_extra_args = {}
if extra_visual_gen_options is not None:
with open(extra_visual_gen_options, 'r') as f:
visual_gen_extra_args = yaml.safe_load(f)
visual_gen_config.update(visual_gen_extra_args)
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
launch_visual_gen_server(host, port, visual_gen_config,
metadata_server_cfg)
if get_is_diffusion_model(model):
_serve_visual_gen()
else:
# Default: launch OpenAI HTTP server
launch_server(host, port, llm_args, tool_parser, chat_template,
metadata_server_cfg, server_role, disagg_cluster_config,
multimodal_server_config)
_serve_llm()
@click.command("mm_embedding_serve")

View File

@ -0,0 +1,132 @@
# Adapted from https://github.com/sgl-project/sglang/blob/030496eb06472f76fcb11de53d93f10cefb4604f/python/sglang/cli/utils.py#L27
import json
import logging
import os
from tensorrt_llm.llmapi.utils import download_hf_partial
logger = logging.getLogger(__name__)
def _maybe_download_model(
model_name_or_path: str, local_dir: str | None = None, download: bool = True
) -> str:
"""Resolve a model path. If it's a local directory, return it.
If it's a Hugging Face Hub ID, download only the config file
(`model_index.json` or `config.json`) and return its directory.
Args:
model_name_or_path: Local path or Hugging Face Hub model ID
local_dir: Local directory to save the downloaded file (if any)
download: Whether to download from Hugging Face Hub when needed
Returns:
Local directory path that contains the downloaded config file, or the original local directory.
"""
if os.path.exists(model_name_or_path):
logger.info("Model already exists locally")
return model_name_or_path
if not download:
return model_name_or_path
try:
logger.info(
"Downloading model_index.json from HF Hub for %s...",
model_name_or_path,
)
file_path = download_hf_partial(
model=model_name_or_path,
allow_patterns=["model_index.json", "config.json"],
)
logger.info("Downloaded to %s", file_path)
return str(file_path)
except Exception as e:
raise ValueError(
(
"Could not find model locally at %s and failed to download "
"model_index.json/config.json from HF Hub: %s"
)
% (model_name_or_path, e)
) from e
# Copied and adapted from hf_diffusers_utils.py
def is_diffusers_model_path(model_path: str) -> bool:
"""Verify if the model directory contains a valid diffusers configuration.
Args:
model_path: Path to the model directory
Returns:
True if the model directory contains a valid diffusers configuration, False otherwise
"""
# Prefer model_index.json which indicates a diffusers pipeline
config_path = os.path.join(model_path, "model_index.json")
if not os.path.exists(config_path):
return False
# Load the config
with open(config_path) as f:
config = json.load(f)
# Verify diffusers version exists
if "_diffusers_version" not in config:
return False
return True
def get_is_diffusion_model(model_path: str):
model_path = _maybe_download_model(model_path)
is_diffusion_model = is_diffusers_model_path(model_path)
if is_diffusion_model:
logger.info("Diffusion model detected")
return is_diffusion_model
def get_model_path(extra_argv):
# Find the model_path argument
model_path = None
for i, arg in enumerate(extra_argv):
if arg == "--model-path":
if i + 1 < len(extra_argv):
model_path = extra_argv[i + 1]
break
elif arg.startswith("--model-path="):
model_path = arg.split("=", 1)[1]
break
if model_path is None:
# Fallback for --help or other cases where model-path is not provided
if any(h in extra_argv for h in ["-h", "--help"]):
raise Exception(
"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
"This command can launch either a standard language model server or a diffusion model server.\n"
"The server type is determined by the model path.\n"
"For specific arguments, please provide a model_path."
)
else:
raise Exception(
"Error: --model-path is required. Please provide the path to the model."
)
return model_path
VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE = {
"FLUX.2": "flux2",
"LTX-2": "ltx2",
"Wan2": "wan2",
}
def get_visual_gen_model_type(model_path: str):
for partial_model_name, model_type in VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.items():
if partial_model_name.lower() in model_path.lower():
return model_type
raise ValueError(
f"Unknown VISUAL_GEN model type for model path: {model_path},"
f"available models: {VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.keys()}"
)
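A short sketch of the detection helpers above; the import path matches the one used by the serve command in this change, and the checkpoint contents are fabricated for illustration.

```python
# Sketch: the diffusion-model detection used by `serve` to pick the server type.
import json
import os
import tempfile

from tensorrt_llm.commands.utils import get_is_diffusion_model, get_visual_gen_model_type

with tempfile.TemporaryDirectory() as model_dir:
    # A diffusers checkpoint is recognized by model_index.json with _diffusers_version.
    with open(os.path.join(model_dir, "model_index.json"), "w") as f:
        json.dump({"_class_name": "WanPipeline", "_diffusers_version": "0.31.0"}, f)

    print(get_is_diffusion_model(model_dir))                     # True
    print(get_visual_gen_model_type("/models/Wan2.1-T2V-1.3B"))  # "wan2"
```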

View File

@ -83,8 +83,7 @@ class ZeroMqQueue:
"Server and client should not receive HMAC key when encryption is disabled"
)
if (socket_type == zmq.PAIR and self.is_server
) or socket_type == zmq.PULL or socket_type == zmq.ROUTER:
if self.should_bind_socket():
self.socket.bind(
self.address_endpoint
) # Binds to the address and occupies a port immediately
@ -101,6 +100,31 @@ class ZeroMqQueue:
self.address = (self.address_endpoint, self.hmac_key)
def should_bind_socket(self) -> bool:
"""
Determine if socket should bind vs connect based on type and role.
ZMQ binding conventions:
- PAIR: server binds, client connects (1-to-1 bidirectional)
- PULL: server binds to receive from multiple PUSH sockets
- PUSH: server binds when acting as message source
- ROUTER: always binds to handle multiple clients
Returns:
True if socket should bind, False if it should connect
"""
# Server binds for PAIR, PULL, PUSH patterns
if self.is_server and self.socket_type in (zmq.PAIR, zmq.PULL,
zmq.PUSH):
return True
# ROUTER always binds (multi-client pattern)
if self.socket_type == zmq.ROUTER:
return True
# Client connects for all other cases
return False
def setup_lazily(self):
# Early return if setup is already done
if self._setup_done:

View File

@ -1,6 +1,6 @@
# Adapt from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Sequence, Union
from typing_extensions import NotRequired, TypedDict
@ -85,3 +85,80 @@ def prompt_inputs(inputs: PromptInputs, ) -> Union[TextPrompt, TokensPrompt]:
f"Invalid type of inputs for llm.generate: {type(inputs)}")
return prompt_inputs
class VisualGenTextPrompt(TypedDict):
prompt: str
negative_prompt: NotRequired[str]
class VisualGenTokensPrompt(TypedDict):
prompt_token_ids: List[int]
negative_prompt_token_ids: NotRequired[List[int]]
VisualGenPromptInputs = Union[
str,
List[int],
VisualGenTextPrompt,
VisualGenTokensPrompt,
]
VisualGenInputs = Union[
VisualGenPromptInputs,
Sequence[VisualGenPromptInputs],
]
def visual_gen_inputs(
inputs: "VisualGenPromptInputs",
) -> Union["VisualGenTextPrompt", "VisualGenTokensPrompt"]:
# str -> text prompt
if isinstance(inputs, str):
return VisualGenTextPrompt(prompt=inputs)
# list[int] -> token prompt
if isinstance(inputs, list):
if len(inputs) == 0:
raise ValueError("`inputs` token list cannot be empty.")
if not all(isinstance(t, int) for t in inputs):
raise TypeError(
"`inputs` list must contain only ints when used as token IDs.")
return VisualGenTokensPrompt(prompt_token_ids=inputs)
# dict form
if isinstance(inputs, dict):
has_prompt = "prompt" in inputs
has_prompt_token_ids = "prompt_token_ids" in inputs
if has_prompt == has_prompt_token_ids:
raise ValueError(
"VisualGen prompt dict must contain exactly one of "
"`prompt` or `prompt_token_ids`.")
if has_prompt:
prompt = inputs.get("prompt")
if not isinstance(prompt, str) or prompt == "":
raise TypeError("`prompt` must be a non-empty string.")
if "negative_prompt" in inputs and not isinstance(
inputs["negative_prompt"], str):
raise TypeError("`negative_prompt` must be a string.")
return inputs # VisualGenTextPrompt
token_ids = inputs.get("prompt_token_ids")
if not isinstance(token_ids, list) or len(token_ids) == 0:
raise TypeError("`prompt_token_ids` must be a non-empty list[int].")
if not all(isinstance(t, int) for t in token_ids):
raise TypeError("`prompt_token_ids` must contain only ints.")
if "negative_prompt_token_ids" in inputs:
neg_ids = inputs["negative_prompt_token_ids"]
if not isinstance(neg_ids, list) or not all(
isinstance(t, int) for t in neg_ids):
raise TypeError(
"`negative_prompt_token_ids` must be a list[int].")
return inputs # VisualGenTokensPrompt
raise TypeError(
"Invalid `inputs` for VisualGen.generate. "
"Expected one of: str, list[int], VisualGenTextPrompt, VisualGenTokensPrompt."
)
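A few normalization examples for `visual_gen_inputs`; the import follows the `tensorrt_llm.inputs.data` module this function is defined in.

```python
# Sketch: input normalization performed by visual_gen_inputs.
from tensorrt_llm.inputs.data import visual_gen_inputs

print(visual_gen_inputs("A cute cat playing piano"))
# -> {'prompt': 'A cute cat playing piano'}

print(visual_gen_inputs([101, 2023, 102]))
# -> {'prompt_token_ids': [101, 2023, 102]}

print(visual_gen_inputs({"prompt": "A cute cat", "negative_prompt": "blurry"}))
# -> {'prompt': 'A cute cat', 'negative_prompt': 'blurry'}
```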

View File

@ -22,10 +22,13 @@ from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mm_encoder import MultimodalEncoder
from .mpi_session import MpiCommSession
from .visual_gen import VisualGen, VisualGenParams
__all__ = [
'LLM',
'AsyncLLM',
'VisualGen',
'VisualGenParams',
'MultimodalEncoder',
'CompletionOutput',
'RequestOutput',

View File

@ -24,6 +24,7 @@ class ServerRole(IntEnum):
CONTEXT = 0
GENERATION = 1
MM_ENCODER = 2
VISUAL_GEN = 3
@dataclass

View File

@ -236,18 +236,34 @@ def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
return Path(hf_folder)
def download_hf_pretrained_config(model: str,
revision: Optional[str] = None) -> Path:
def download_hf_partial(model: str,
allow_patterns: List[str],
revision: Optional[str] = None) -> Path:
"""Download a partial model from HuggingFace.
Args:
model: The model name or path.
revision: The revision to use for the model.
allow_patterns: The patterns to allow for the model.
Returns:
The path to the downloaded model.
"""
with get_file_lock(model):
hf_folder = snapshot_download(
model,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
allow_patterns=["config.json"],
allow_patterns=allow_patterns,
tqdm_class=DisabledTqdm)
return Path(hf_folder)
def download_hf_pretrained_config(model: str,
revision: Optional[str] = None) -> Path:
return download_hf_partial(model, ["config.json"], revision)
def append_docstring(docstring: str):
''' A decorator to append a docstring to a function. '''

View File

@ -0,0 +1,544 @@
import asyncio
import queue
import socket
import threading
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch.multiprocessing as mp
import zmq
from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
from tensorrt_llm._torch.visual_gen.output import MediaOutput
__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.inputs.data import VisualGenInputs
from tensorrt_llm.logger import logger
# Timeouts (seconds)
POLL_TIMEOUT = 0.01
AWAIT_TIMEOUT = 0.05
THREAD_TIMEOUT = 5.0
WORKER_TIMEOUT = 2.0
READY_TIMEOUT = 1200 # 20 minutes for large models (Wan 2.2 with transformer_2)
def find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_ip_address() -> str:
"""Get local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("10.255.255.255", 1))
return s.getsockname()[0]
except Exception:
return "127.0.0.1"
finally:
s.close()
class DiffusionRemoteClient:
"""Client proxy for remote DiffusionExecutor in worker processes."""
def __init__(
self,
model_path: Union[str, Path],
n_workers: int = 1,
diffusion_config: Optional[dict] = None,
):
self.model_path = str(model_path)
self.n_workers = n_workers
self.diffusion_config = diffusion_config
# Setup distributed env
self.master_addr = "127.0.0.1"
self.master_port = find_free_port()
# Setup IPC addresses
self.host_ip = get_ip_address()
req_port, resp_port = find_free_port(), find_free_port()
self.request_queue_addr = f"tcp://0.0.0.0:{req_port}"
self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}"
self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"
# IPC setup
self.requests_ipc = None
self.responses_ipc = None
self.pending_requests = queue.Queue()
self.completed_responses: Dict[int, DiffusionResponse] = {}
# We'll create asyncio primitives in the background thread's event loop
self._event_loop = None
self.response_event = None
self.lock = None
self.shutdown_event = threading.Event()
self.event_loop_ready = threading.Event()
# Start background thread (it will create its own event loop)
self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True)
self.background_thread.start()
# Wait for the background thread to initialize the event loop
self.event_loop_ready.wait()
# Launch workers
logger.info(f"DiffusionClient: Launching {n_workers} workers")
ctx = mp.get_context("spawn")
self.worker_processes = []
for rank in range(n_workers):
p = ctx.Process(
target=run_diffusion_worker,
kwargs={
"rank": rank,
"world_size": n_workers,
"master_addr": self.master_addr,
"master_port": self.master_port,
"model_path": self.model_path,
"request_queue_addr": self.req_addr_connect,
"response_queue_addr": self.resp_addr_connect,
"diffusion_config": self.diffusion_config,
},
)
p.start()
self.worker_processes.append(p)
self._wait_ready()
@staticmethod
def _close_socket(ipc_queue):
if ipc_queue and ipc_queue.socket:
ipc_queue.socket.setsockopt(zmq.LINGER, 0)
ipc_queue.close()
def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]:
"""Enqueue requests and return their IDs."""
req_ids = []
for req in requests:
self.pending_requests.put(req)
req_ids.append(req.request_id)
return req_ids
async def await_responses(
self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
) -> Union[DiffusionResponse, List[DiffusionResponse]]:
"""Wait for responses by request IDs.
Args:
request_ids: Single request ID or list of request IDs to wait for
timeout: Maximum total wait time in seconds (None = wait indefinitely)
Returns:
Single response or list of responses (None if request timed out)
"""
is_single = isinstance(request_ids, int)
ids = [request_ids] if is_single else request_ids
start_time = time.time()
results = {}
while len(results) < len(ids):
async with self.lock:
for req_id in ids:
if req_id in self.completed_responses:
results[req_id] = self.completed_responses.pop(req_id)
# All responses collected
if len(results) == len(ids):
break
# Check if overall timeout exceeded
if timeout is not None:
elapsed = time.time() - start_time
if elapsed >= timeout:
break
# Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter
wait_time = min(timeout - elapsed, AWAIT_TIMEOUT)
else:
wait_time = AWAIT_TIMEOUT
try:
await asyncio.wait_for(self.response_event.wait(), timeout=wait_time)
except asyncio.TimeoutError:
pass
self.response_event.clear()
out = [results.get(rid) for rid in ids]
return out[0] if is_single else out
def await_responses_sync(
self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
) -> Union[DiffusionResponse, List[DiffusionResponse]]:
"""Sync wrapper to await responses from the main thread."""
future = asyncio.run_coroutine_threadsafe(
self.await_responses(request_ids, timeout), self._event_loop
)
return future.result(timeout=timeout if timeout else None)
def _init_ipc(self) -> bool:
"""Initialize IPC queues."""
try:
logger.info("DiffusionClient: Initializing IPC")
self.requests_ipc = ZeroMqQueue(
(self.request_queue_addr, None),
is_server=True,
socket_type=zmq.PUSH,
use_hmac_encryption=False,
)
self.responses_ipc = ZeroMqQueue(
(self.response_queue_addr, None),
is_server=True,
socket_type=zmq.PULL,
use_hmac_encryption=False,
)
logger.info("DiffusionClient: IPC ready")
return True
except Exception as e:
logger.error(f"DiffusionClient: IPC init failed: {e}")
return False
def _send_shutdown(self):
"""Send shutdown signal."""
logger.info("DiffusionClient: Sending shutdown signal")
if self.requests_ipc:
self.requests_ipc.put(None)
self._close_socket(self.requests_ipc)
def _process_requests(self):
"""Process pending requests."""
try:
req = self.pending_requests.get(timeout=POLL_TIMEOUT)
if req is None:
self._send_shutdown()
self.shutdown_event.set()
return
logger.info(f"DiffusionClient: Sending request {req.request_id}")
self.requests_ipc.put(req)
except queue.Empty:
pass
except Exception as e:
logger.error(f"DiffusionClient: Error sending request: {e}")
logger.error(traceback.format_exc())
def _process_responses(self):
"""Poll and process responses."""
try:
if self.responses_ipc.poll(timeout=POLL_TIMEOUT):
response = self.responses_ipc.get()
if isinstance(response, DiffusionResponse):
if response.request_id == -1:
logger.info("DiffusionClient: Received READY signal")
# Schedule the lock acquisition and event setting in the event loop
asyncio.run_coroutine_threadsafe(
self._store_response(response), self._event_loop
)
except Exception as e:
logger.error(f"DiffusionClient: Error processing response: {e}")
async def _store_response(self, response: DiffusionResponse):
"""Store response in the completed_responses dict (async helper)."""
async with self.lock:
self.completed_responses[response.request_id] = response
self.response_event.set()
def _cleanup_ipc(self):
"""Cleanup IPC."""
logger.info("DiffusionClient: Cleaning up IPC")
self._close_socket(self.requests_ipc)
self._close_socket(self.responses_ipc)
def _serve_forever_thread(self):
"""Background thread wrapper that creates and runs an event loop."""
logger.info("DiffusionClient: Background thread started")
# Create a new event loop for this thread
self._event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._event_loop)
# Create async primitives in this thread's event loop
self.response_event = asyncio.Event()
self.lock = asyncio.Lock()
# Signal that the event loop is ready
self.event_loop_ready.set()
# Run the async serve_forever
try:
self._event_loop.run_until_complete(self._serve_forever())
finally:
self._event_loop.close()
logger.info("DiffusionClient: Background thread stopped")
async def _serve_forever(self):
"""Background thread main loop (async version)."""
if not self._init_ipc():
return
while not self.shutdown_event.is_set():
self._process_requests()
self._process_responses()
await asyncio.sleep(0.001) # Yield control to allow other coroutines to run
self._cleanup_ipc()
def shutdown(self):
"""Shutdown client and workers."""
logger.info("DiffusionClient: Shutting down")
self.pending_requests.put(None)
self.background_thread.join(timeout=THREAD_TIMEOUT)
if self.background_thread.is_alive():
logger.warning("DiffusionClient: Force stopping background thread")
self.shutdown_event.set()
self.background_thread.join(timeout=1.0)
# Shutdown workers
logger.info("DiffusionClient: Stopping workers")
for p in self.worker_processes:
p.join(timeout=WORKER_TIMEOUT)
if p.is_alive():
logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM")
p.terminate()
p.join(timeout=WORKER_TIMEOUT)
if p.is_alive():
logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL")
p.kill()
p.join(timeout=WORKER_TIMEOUT)
def _wait_ready(self, timeout: float = READY_TIMEOUT):
"""Wait for workers to be ready (sync wrapper for async operation)."""
logger.info("DiffusionClient: Waiting for workers")
# Run the async wait in the background thread's event loop
future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop)
return future.result(timeout=timeout)
async def _wait_ready_async(self, timeout: float = READY_TIMEOUT):
"""Wait for workers to be ready (async version)."""
start_time = time.time()
while True:
async with self.lock:
if -1 in self.completed_responses:
self.completed_responses.pop(-1)
logger.info("DiffusionClient: Workers ready")
return
if time.time() - start_time > timeout:
raise RuntimeError("DiffusionClient: Timeout waiting for workers")
try:
await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT)
except asyncio.TimeoutError:
pass
self.response_event.clear()
class DiffusionGenerationResult:
"""Future-like object for async generation."""
def __init__(self, request_id: int, executor: DiffusionRemoteClient):
self.request_id = request_id
self.executor = executor
self._result = None
self._finished = False
self._error = None
async def result(self, timeout: Optional[float] = None) -> Any:
"""Wait for and return result (async version).
Can be awaited from any async context (e.g., FastAPI background tasks).
"""
if self._finished:
if self._error:
raise RuntimeError(self._error)
return self._result
# Use run_coroutine_threadsafe to execute in the background thread's event loop
future = asyncio.run_coroutine_threadsafe(
self.executor.await_responses(self.request_id, timeout=timeout),
self.executor._event_loop,
)
# Await the future in the current event loop
response = await asyncio.wrap_future(future)
if response.error_msg:
self._error = response.error_msg
self._finished = True
raise RuntimeError(f"Generation failed: {response.error_msg}")
self._result = response.output
self._finished = True
return self._result
def cancel(self):
raise NotImplementedError("Cancelling a request is not yet implemented.")
@dataclass
class VisualGenParams:
"""Parameters for visual generation.
Attributes:
height: Output height in pixels
width: Output width in pixels
num_inference_steps: Number of denoising steps
guidance_scale: Classifier-free guidance scale
max_sequence_length: Maximum sequence length for text encoding
seed: Random seed for reproducibility
# Video-specific parameters
num_frames: Number of video frames to generate
frame_rate: Frame rate for video output in fps
# Image-specific parameters
num_images_per_prompt: Number of images to generate per prompt (for image models)
# Advanced parameters
guidance_rescale: Guidance rescale factor (for some models)
output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images)
"""
height: int = 720
width: int = 1280
num_inference_steps: int = 50
guidance_scale: float = 5.0
max_sequence_length: int = 512
seed: int = 42
# Video-specific parameters
num_frames: int = 81
frame_rate: float = 24.0
input_reference: Optional[str] = None
# Image-specific parameters
num_images_per_prompt: int = 1
## Image edit parameters
image: Optional[List[str]] = None
mask: Optional[str] = None
# Advanced parameters
guidance_rescale: float = 0.0
output_type: str = "pt"
# Wan-specific parameters
guidance_scale_2: Optional[float] = None
boundary_ratio: Optional[float] = None
last_image: Optional[str] = None
class VisualGen:
"""High-level API for visual generation."""
def __init__(
self,
model_path: Union[str, Path],
n_workers: int = 1,
diffusion_config: Optional[dict] = None,
):
self.model_path = str(model_path)
self.n_workers = n_workers
self.diffusion_config = diffusion_config
self.executor = DiffusionRemoteClient(
model_path=self.model_path,
n_workers=self.n_workers,
diffusion_config=self.diffusion_config,
)
self.req_counter = 0
def generate(
self,
inputs: VisualGenInputs,
params: VisualGenParams,
) -> MediaOutput:
"""Synchronous generation. Blocks until complete.
Args:
inputs: Prompt input (text or prompt dict; see VisualGenInputs).
params: Generation parameters.
Returns:
MediaOutput: Generated media with model-specific fields populated:
- FLUX2: MediaOutput(image=torch.Tensor)
- WAN: MediaOutput(video=torch.Tensor)
- LTX2: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
"""
future = self.generate_async(
inputs=inputs,
params=params,
)
# Use the sync wrapper to get result
response = self.executor.await_responses_sync(future.request_id, timeout=None)
if response.error_msg:
raise RuntimeError(f"Generation failed: {response.error_msg}")
return response.output
def generate_async(
self,
inputs: VisualGenInputs,
params: VisualGenParams,
) -> DiffusionGenerationResult:
"""Async generation. Returns immediately with future-like object.
Args:
inputs: Prompt input (text or prompt dict; see VisualGenInputs).
params: Generation parameters.
Returns:
DiffusionGenerationResult: Await result() to get the generated MediaOutput.
"""
req_id = self.req_counter
self.req_counter += 1
if isinstance(inputs, dict):
prompt = inputs.get("prompt")
negative_prompt = inputs.get("negative_prompt", None)
elif isinstance(inputs, str):
prompt = inputs
negative_prompt = None
else:
# TODO: Support batch generation
raise ValueError(f"Invalid inputs type: {type(inputs)}")
request = DiffusionRequest(
request_id=req_id,
prompt=prompt,
negative_prompt=negative_prompt,
height=params.height,
width=params.width,
num_inference_steps=params.num_inference_steps,
guidance_scale=params.guidance_scale,
max_sequence_length=params.max_sequence_length,
seed=params.seed,
num_frames=params.num_frames,
frame_rate=params.frame_rate,
num_images_per_prompt=params.num_images_per_prompt,
guidance_rescale=params.guidance_rescale,
output_type=params.output_type,
image=params.input_reference,
guidance_scale_2=params.guidance_scale_2,
boundary_ratio=params.boundary_ratio,
last_image=params.last_image,
)
self.executor.enqueue_requests([request])
return DiffusionGenerationResult(req_id, self.executor)
def shutdown(self):
"""Shutdown executor and cleanup."""
logger.info("VisualGen: Shutting down")
self.executor.shutdown()
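Putting the pieces together, a hedged end-to-end sketch of this API; the model path and prompt are placeholders, and the return value is a `MediaOutput` as documented in `generate`.

```python
# Minimal sketch of the VisualGen flow defined above (paths/prompts are placeholders).
gen = VisualGen(model_path="/llm-models/Wan2.1-T2V-1.3B-Diffusers", n_workers=1)
params = VisualGenParams(height=480, width=832, num_frames=33,
                         num_inference_steps=50, seed=42)
try:
    out = gen.generate(inputs={"prompt": "A cute cat playing piano"}, params=params)
    # For WAN, MediaOutput carries a video tensor; MediaStorage (below) can save it.
    print(type(out))
finally:
    gen.shutdown()
```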

View File

@ -16,10 +16,8 @@ from functools import wraps as _wraps
from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
if _mpi_disabled():
raise RuntimeError(
"Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
)
# Don't raise error on import - only when Ray functionality is actually used
_RAY_NOT_INSTALLED_MSG = "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
def remote(*args, **kwargs):
@ -42,6 +40,7 @@ def remote(*args, **kwargs):
def __getattr__(name):
raise RuntimeError(
f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
)
msg = f'Ray not installed, so "ray.{name}" is unavailable.'
if _mpi_disabled():
msg = _RAY_NOT_INSTALLED_MSG
raise RuntimeError(msg)
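The change above defers the "Ray not installed" error from import time to first attribute access. A generic sketch of that module-level `__getattr__` pattern (PEP 562); the names here are illustrative, not the TensorRT LLM ones.

```python
# ray_stub.py (illustrative): importing this module never raises;
# only touching an attribute does.
_NOT_INSTALLED_MSG = "optional dependency not installed; install it to use this feature."

def __getattr__(name: str):
    raise RuntimeError(f'"{name}" is unavailable: {_NOT_INSTALLED_MSG}')
```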

View File

@ -0,0 +1,426 @@
#!/usr/bin/env python
"""Media Storage for generated images and videos.
This module provides storage handlers for persisting generated media assets
(videos, images) and their associated metadata.
"""
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
import torch
from PIL import Image
from tensorrt_llm.logger import logger
class MediaStorage:
"""Handler for storing images and videos in various formats."""
@staticmethod
def save_image(
image: Any, output_path: str, format: Optional[str] = None, quality: int = 95
) -> str:
"""Save image to file.
Args:
image: torch.Tensor (H, W, C) uint8
output_path: Path to save the image
format: Image format (png, jpg, webp). If None, infer from extension
quality: Quality for lossy formats (1-100, higher is better)
Returns:
Path where the image was saved
"""
# Ensure output directory exists
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Convert to PIL Image if needed
pil_image = MediaStorage._to_pil_image(image)
# Determine format
if format is None:
ext = os.path.splitext(output_path)[1].lower()
if ext in [".png"]:
format = "PNG"
elif ext in [".jpg", ".jpeg"]:
format = "JPEG"
elif ext in [".webp"]:
format = "WEBP"
else:
logger.warning(f"Unknown image extension {ext}, defaulting to PNG")
format = "PNG"
output_path = output_path.rsplit(".", 1)[0] + ".png"
# Save image with format-specific handling
MediaStorage._save_pil_image(pil_image, output_path, format, quality)
logger.info(f"Saved image to {output_path} (format={format})")
return output_path
@staticmethod
def convert_image_to_bytes(image: Any, format: str = "PNG", quality: int = 95) -> bytes:
"""Convert image to bytes buffer.
Args:
image: torch.Tensor (H, W, C) uint8
format: Image format (PNG, JPEG, WEBP)
quality: Quality for lossy formats (1-100)
Returns:
Image bytes
"""
pil_image = MediaStorage._to_pil_image(image)
# Save to bytes buffer
buffer = BytesIO()
MediaStorage._save_pil_image(pil_image, buffer, format, quality)
return buffer.getvalue()
@staticmethod
def _to_pil_image(image: torch.Tensor) -> Image.Image:
"""Convert torch.Tensor to PIL Image.
Args:
image: torch.Tensor (H, W, C) uint8
Returns:
PIL Image
"""
if not isinstance(image, torch.Tensor):
raise ValueError(f"Expected torch.Tensor, got {type(image)}")
# Convert to numpy for PIL
image_np = image.cpu().numpy()
return Image.fromarray(image_np)
@staticmethod
def _save_pil_image(
pil_image: Image.Image,
output: Any, # Can be path string or BytesIO
format: str,
quality: int,
):
"""Save PIL Image to file or buffer.
Args:
pil_image: PIL Image to save
output: Output path (str) or BytesIO buffer
format: Image format (PNG, JPEG, WEBP)
quality: Quality for lossy formats (1-100)
"""
format_upper = format.upper()
if format_upper in ["JPEG", "JPG"]:
# Convert RGBA to RGB for JPEG
if pil_image.mode in ("RGBA", "LA", "P"):
background = Image.new("RGB", pil_image.size, (255, 255, 255))
if pil_image.mode == "P":
pil_image = pil_image.convert("RGBA")
background.paste(
pil_image, mask=pil_image.split()[-1] if pil_image.mode == "RGBA" else None
)
pil_image = background
pil_image.save(output, format="JPEG", quality=quality, optimize=True)
elif format_upper == "WEBP":
pil_image.save(output, format="WEBP", quality=quality)
else: # PNG or default
pil_image.save(output, format="PNG", optimize=True)
@staticmethod
def save_video(
video: Any,
output_path: str,
audio: Optional[Any] = None,
frame_rate: float = 24.0,
format: Optional[str] = None,
) -> str:
"""Save video to file with optional audio.
Args:
video: Video frames as torch.Tensor (T, H, W, C) uint8
output_path: Path to save the video
audio: Optional audio as torch.Tensor
frame_rate: Frames per second (default: 24.0)
format: Video format (mp4, gif, png). If None, infer from extension
Returns:
Path where the video was saved
"""
# Ensure output directory exists
if isinstance(output_path, Path):
output_path = str(output_path)
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Determine format
if format is None:
ext = os.path.splitext(output_path)[1].lower()
format = ext[1:] if ext else "mp4"
format = format.lower()
# Save based on format
if format == "mp4":
MediaStorage._save_mp4(video, audio, output_path, frame_rate)
elif format == "gif":
MediaStorage._save_gif(video, output_path, frame_rate)
elif format == "png":
MediaStorage._save_middle_frame(video, output_path)
else:
logger.warning(f"Unsupported video format: {format}, defaulting to mp4")
output_path = output_path.rsplit(".", 1)[0] + ".mp4"
MediaStorage._save_mp4(video, audio, output_path, frame_rate)
return output_path
@staticmethod
def convert_video_to_bytes(
video: Any, audio: Optional[Any] = None, frame_rate: float = 24.0, format: str = "mp4"
) -> bytes:
"""Convert video to bytes buffer.
Args:
video: Video frames as torch.Tensor (T, H, W, C) uint8
audio: Optional audio as torch.Tensor
frame_rate: Frames per second
format: Video format (mp4, gif)
Returns:
Video bytes
"""
import tempfile
# Create temporary file
with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp_file:
tmp_path = tmp_file.name
try:
# Save to temporary file
MediaStorage.save_video(video, tmp_path, audio, frame_rate, format)
# Read bytes
with open(tmp_path, "rb") as f:
video_bytes = f.read()
return video_bytes
finally:
# Clean up temporary file
if os.path.exists(tmp_path):
os.unlink(tmp_path)
@staticmethod
def _save_mp4(
video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
) -> str:
"""Save video with optional audio as MP4.
Args:
video: Video frames as torch.Tensor (T, H, W, C) uint8
audio: Optional audio as torch.Tensor
output_path: Output path for MP4
frame_rate: Frames per second
Returns:
Path where the video was saved
"""
try:
from fractions import Fraction
import av
if not isinstance(video, torch.Tensor):
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
# Convert video tensor to numpy: (T, H, W, C) uint8
video_np = video.cpu().numpy()
num_frames, height, width, channels = video_np.shape
# Ensure RGB format (3 channels)
if channels != 3:
raise ValueError(f"Expected 3-channel RGB video, got {channels} channels")
# Open output container
container = av.open(output_path, mode="w")
# Add video stream (H.264 codec)
video_stream = container.add_stream("libx264", rate=int(frame_rate))
video_stream.width = width
video_stream.height = height
video_stream.pix_fmt = "yuv420p"
video_stream.options = {"preset": "medium", "crf": "23"}
# Pre-process audio and add audio stream BEFORE any muxing.
# All streams must be registered before the first mux() call
# (which triggers container header writing).
audio_stream = None
audio_tensor = None
audio_sample_rate = 24000 # Default sample rate
if audio is not None:
if not isinstance(audio, torch.Tensor):
raise ValueError(f"Expected torch.Tensor for audio, got {type(audio)}")
# Prepare audio tensor: convert to (samples, channels) format
audio_tensor = audio
# Handle different audio tensor dimensions
if audio_tensor.ndim == 1:
# Mono audio: (samples,) -> (samples, 1)
audio_tensor = audio_tensor[:, None]
elif audio_tensor.ndim == 2:
# If shape[1] != 2 and shape[0] == 2, transpose to (samples, channels)
if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
audio_tensor = audio_tensor.T
if audio_tensor.shape[1] > 2:
audio_tensor = audio_tensor[:, :2]
elif audio_tensor.ndim == 3:
if audio_tensor.shape[0] == 1:
audio_tensor = audio_tensor.squeeze(0)
else:
audio_tensor = audio_tensor[0]
if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
audio_tensor = audio_tensor.T
if audio_tensor.shape[1] > 2:
audio_tensor = audio_tensor[:, :2]
else:
raise ValueError(
f"Unsupported audio tensor shape: {audio_tensor.shape}. "
f"Expected 1D, 2D, or 3D tensor."
)
if audio_tensor.shape[1] > 2:
audio_tensor = audio_tensor[:, :2]
# Convert to int16 if needed
if audio_tensor.dtype != torch.int16:
audio_tensor = torch.clip(audio_tensor, -1.0, 1.0)
audio_tensor = (audio_tensor * 32767.0).to(torch.int16)
# Add audio stream now (before any muxing)
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
# --- Encode video frames ---
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in video_stream.encode(frame):
container.mux(packet)
# Flush video encoder
for packet in video_stream.encode():
container.mux(packet)
# --- Encode audio (after video is done) ---
if audio_stream is not None and audio_tensor is not None:
# Build packed int16 frame: (1, samples*channels)
audio_np = audio_tensor.contiguous().reshape(1, -1).cpu().numpy()
frame_in = av.AudioFrame.from_ndarray(audio_np, format="s16", layout="stereo")
frame_in.sample_rate = audio_sample_rate
# Use AudioResampler to convert s16→fltp (AAC's native format)
cc = audio_stream.codec_context
audio_resampler = av.audio.resampler.AudioResampler(
format=cc.format or "fltp",
layout=cc.layout or "stereo",
rate=cc.sample_rate or audio_sample_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = audio_sample_rate
container.mux(audio_stream.encode(rframe))
# Flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
# Close container
container.close()
logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
return output_path
except ImportError:
logger.warning(
"PyAV (av) library not available. "
"Falling back to saving middle frame as PNG. "
"Install with: pip install av"
)
png_path = output_path.replace(".mp4", ".png")
return MediaStorage._save_middle_frame(video, png_path)
except Exception as e:
logger.error(f"Error encoding video with PyAV: {e}")
import traceback
logger.error(traceback.format_exc())
logger.warning("Falling back to saving middle frame as PNG.")
png_path = output_path.replace(".mp4", ".png")
return MediaStorage._save_middle_frame(video, png_path)
@staticmethod
def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float) -> str:
"""Save video as animated GIF.
Args:
video: Video frames as torch.Tensor (T, H, W, C) uint8
output_path: Output path for GIF
frame_rate: Frames per second
Returns:
Path where the GIF was saved
"""
if not isinstance(video, torch.Tensor):
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
# Convert to numpy and then to list of PIL Images
video_np = video.cpu().numpy()
frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
# Save as GIF
duration_ms = int(1000 / frame_rate)
frames[0].save(
output_path,
save_all=True,
append_images=frames[1:],
optimize=False,
duration=duration_ms,
loop=0,
)
logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
return output_path
@staticmethod
def _save_middle_frame(video: torch.Tensor, output_path: str) -> str:
"""Save middle frame of video as PNG.
Args:
video: Video frames as torch.Tensor (T, H, W, C) uint8
output_path: Output path for PNG
Returns:
Path where the frame was saved
"""
if not isinstance(video, torch.Tensor):
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
# Extract middle frame
video_np = video.cpu().numpy()
frame_idx = video_np.shape[0] // 2
image = Image.fromarray(video_np[frame_idx])
image.save(output_path)
logger.info(f"Saved frame {frame_idx} to {output_path}")
return output_path
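A hedged usage sketch of the storage helpers above; the tensor below is random placeholder data standing in for real generator output.

```python
import torch

# Fake clip: 16 RGB frames of 64x64, uint8 (T, H, W, C), as the helpers expect.
frames = torch.randint(0, 256, (16, 64, 64, 3), dtype=torch.uint8)

MediaStorage.save_video(frames, "out/demo.mp4", frame_rate=24.0)  # H.264 via PyAV
MediaStorage.save_video(frames, "out/demo.gif", frame_rate=8.0)   # animated GIF
MediaStorage.save_image(frames[0], "out/first_frame.png")         # single frame
png_bytes = MediaStorage.convert_image_to_bytes(frames[0], format="PNG")
```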

View File

@ -1,12 +1,14 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/4db5176d9758b720b05460c50ace3c01026eb158/vllm/entrypoints/openai/protocol.py
import base64
import re
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union
import torch
import xgrammar
from fastapi import UploadFile
from openai.types.chat import ChatCompletionAssistantMessageParam
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
@ -1120,5 +1122,218 @@ def to_llm_disaggregated_params(
)
# ============================================================================
# Diffusion API Protocol Classes
# ============================================================================
class ImageGenerationRequest(OpenAIBaseModel):
"""OpenAI-compatible image generation request.
Follows the OpenAI Images API specification:
https://platform.openai.com/docs/api-reference/images/create
"""
prompt: str
model: Optional[str] = None
n: int = Field(default=1, ge=1, le=10)
output_format: Literal["png", "webp", "jpeg"] = "png"
size: Optional[str] = Field(
default="auto",
description=(
"The size of the generated images. Must be in 'WxH' format like "
"1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
"Use 'auto' for model default size."))
quality: Literal["standard", "hd"] = "standard"
response_format: Literal["url", "b64_json"] = "url"
style: Optional[Literal["vivid", "natural"]] = "vivid"
user: Optional[str] = None
# Extended parameters for diffusion control
num_inference_steps: Optional[int] = Field(
default=None,
description=
"Number of denoising steps. More steps = higher quality but slower.")
guidance_scale: Optional[float] = Field(
default=None,
description=
"Classifier-free guidance scale. Higher values follow prompt more closely."
)
guidance_rescale: Optional[float] = Field(
default=None, description="Classifier-free guidance rescale.")
negative_prompt: Optional[str] = Field(
default=None,
description="Text describing what to avoid in the generated image.")
seed: Optional[int] = Field(default=None,
description="Random seed for reproducibility.")
@field_validator("size")
@classmethod
def validate_size(cls, v):
"""Validate size format is 'WxH' or 'auto'."""
if v is None or v == "auto":
return v
if not isinstance(v, str):
raise ValueError("size must be a string in 'WxH' format or 'auto'")
# Check format: should be like "1024x1024"
import re
if not re.match(r'^\d+x\d+$', v):
raise ValueError(
f"Invalid size format '{v}'. Must be in 'WxH' format "
"(e.g., '1024x1024', '1536x1024') or 'auto'.")
return v
class ImageObject(OpenAIBaseModel):
"""Generated image object in the response."""
b64_json: Optional[str] = None
url: Optional[str] = None
revised_prompt: Optional[str] = None
class ImageGenerationResponse(OpenAIBaseModel):
"""Response from image generation endpoint."""
created: int = Field(default_factory=lambda: int(time.time()))
data: List[ImageObject]
output_format: Literal["png", "webp", "jpeg"] = "png"
quality: Literal["low", "medium", "high"] = "medium"
size: Optional[str] = None
class ImageEditRequest(OpenAIBaseModel):
"""Request for image editing endpoint.
Follows the OpenAI Images API specification:
https://platform.openai.com/docs/api-reference/images/createEdit
"""
image: Union[List[str], str] = Field(
description="Base64-encoded source image(s) to edit")
prompt: str = Field(description="Text description of desired edits")
model: Optional[str] = None
mask: Optional[str] = Field(
default=None,
description=
"Base64-encoded mask image (optional, black areas will be edited)")
n: int = Field(default=1, ge=1, le=10)
size: Optional[str] = Field(
default="auto",
description=(
"The size of the edited images. Must be in 'WxH' format like "
"1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
"Use 'auto' to match source image size."))
response_format: Literal["url", "b64_json"] = "url"
user: Optional[str] = None
# Extended parameters for diffusion control
num_inference_steps: Optional[int] = Field(
default=None, description="Number of denoising steps.")
guidance_scale: Optional[float] = Field(
default=None, description="Classifier-free guidance scale.")
guidance_rescale: Optional[float] = Field(
default=None, description="Classifier-free guidance rescale.")
negative_prompt: Optional[str] = Field(
default=None,
description="Text describing what to avoid in the edited image.")
seed: Optional[int] = Field(default=None,
description="Random seed for reproducibility.")
@field_validator("size")
@classmethod
def validate_size(cls, v):
"""Validate size format is 'WxH' or 'auto'."""
if v != "auto" and not re.match(r"^\d+x\d+$", v):
raise ValueError(
"Size must be 'auto' or in 'WxH' format (e.g., '1024x1024')")
return v
class VideoGenerationRequest(OpenAIBaseModel):
"""Video generation request (extended API).
This is an extension to the OpenAI API for video generation support.
"""
prompt: str
input_reference: Optional[Union[str, UploadFile]] = Field(
default=None,
description="Optional image reference that guides generation.")
model: Optional[str] = None
size: Optional[str] = Field(
default="auto",
description=
("The size of the generated video frames. Must be in 'WxH' format like "
"512x512, 1024x576 (landscape), 576x1024 (portrait), etc. "
"Use 'auto' for model default size."))
seconds: float = Field(default=2.0,
ge=1.0,
le=16.0,
description="Video duration in seconds.")
# Extended parameters for diffusion control
n: int = Field(default=1, ge=1, le=4)
fps: int = Field(default=24, ge=8, le=60, description="Frames per second.")
num_inference_steps: Optional[int] = Field(
default=None, description="Number of denoising steps.")
guidance_scale: Optional[float] = Field(
default=None, description="Classifier-free guidance scale.")
guidance_rescale: Optional[float] = Field(
default=None, description="Classifier-free guidance rescale.")
negative_prompt: Optional[str] = Field(
default=None,
description="Text describing what to avoid in the generated video.")
seed: Optional[int] = Field(default=None,
description="Random seed for reproducibility.")
@field_validator("size")
@classmethod
def validate_size(cls, v):
"""Validate size format is 'WxH' or 'auto'."""
if v is None or v == "auto":
return v
if not isinstance(v, str):
raise ValueError("size must be a string in 'WxH' format or 'auto'")
import re
if not re.match(r'^\d+x\d+$', v):
raise ValueError(
f"Invalid size format '{v}'. Must be in 'WxH' format "
"(e.g., '512x512', '1024x576') or 'auto'.")
return v
class VideoJob(OpenAIBaseModel):
"""Metadata for an asynchronous video generation job.
Follows the OpenAI Videos API specification:
https://platform.openai.com/docs/api-reference/videos
"""
completed_at: Optional[int] = Field(
default=None, description="Unix timestamp of completion")
created_at: int = Field(description="Unix timestamp of creation")
error: Optional[str] = Field(default=None,
description="Error message if failed")
expires_at: Optional[int] = Field(
default=None, description="Unix timestamp of expiration")
id: str = Field(description="Unique identifier for the video")
model: str = Field(description="The model used for generation")
object: str = Field(default="video", description="Object type")
progress: Optional[int] = Field(
default=None,
description="Progress of the video generation job (0-100)")
prompt: str = Field(description="The prompt used to generate the video")
status: Literal["queued", "in_progress", "completed", "failed"] = Field(
description="Current status of the video generation job")
# Video properties
duration: Optional[float] = Field(default=None,
description="Video duration in seconds")
fps: Optional[int] = Field(default=None, description="Frames per second")
size: Optional[str] = Field(default=None,
description="Video dimensions in 'WxH' format")
class VideoJobList(OpenAIBaseModel):
"""Response from listing video jobs endpoint."""
data: List[VideoJob] = Field(description="List of video jobs")
object: str = Field(default="list", description="Object type")
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
UCompletionResponse = Union[CompletionResponse, ChatCompletionResponse]

File diff suppressed because it is too large.

View File

@ -0,0 +1,112 @@
import asyncio
import base64
import os
import shutil
from typing import Any, Dict, List, Optional
from tensorrt_llm.llmapi.visual_gen import VisualGenParams
from tensorrt_llm.serve.openai_protocol import (
ImageEditRequest,
ImageGenerationRequest,
VideoGenerationRequest,
)
def parse_visual_gen_params(
request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest,
id: str,
media_storage_path: Optional[str] = None,
) -> VisualGenParams:
params = VisualGenParams()
params.prompt = request.prompt
if request.negative_prompt is not None:
params.negative_prompt = request.negative_prompt
if request.size is not None and request.size != "auto":
params.width, params.height = map(int, request.size.split("x"))
if request.guidance_scale is not None:
params.guidance_scale = request.guidance_scale
if request.guidance_rescale is not None:
params.guidance_rescale = request.guidance_rescale
if isinstance(request, ImageGenerationRequest) or isinstance(request, ImageEditRequest):
if request.num_inference_steps is not None:
params.num_inference_steps = request.num_inference_steps
elif isinstance(request, ImageGenerationRequest) and request.quality == "hd":
params.num_inference_steps = 30
if request.n is not None:
params.num_images_per_prompt = request.n
if isinstance(request, ImageEditRequest):
if request.image is not None:
if isinstance(request.image, list):
params.image = [base64.b64decode(image) for image in request.image]
else:
params.image = [base64.b64decode(request.image)]
if request.mask is not None:
if isinstance(request.mask, list):
params.mask = [base64.b64decode(mask) for mask in request.mask]
else:
params.mask = base64.b64decode(request.mask)
elif isinstance(request, VideoGenerationRequest):
if request.num_inference_steps is not None:
params.num_inference_steps = request.num_inference_steps
if request.input_reference is not None:
if media_storage_path is None:
raise ValueError("media_storage_path is required when input_reference is provided")
params.input_reference = os.path.join(media_storage_path, f"{id}_reference.png")
if isinstance(request.input_reference, str):
with open(params.input_reference, "wb") as f:
f.write(base64.b64decode(request.input_reference))
else:
with open(params.input_reference, "wb") as f:
shutil.copyfileobj(request.input_reference.file, f)
params.frame_rate = request.fps
params.num_frames = int(request.seconds * request.fps)
if request.seed is not None:
params.seed = int(request.seed)
return params
class AsyncDictStore:
"""A small async-safe in-memory key-value store for dict items.
This encapsulates the usual pattern of a module-level dict guarded by
an asyncio.Lock and provides simple CRUD methods that are safe to call
concurrently from FastAPI request handlers and background tasks.
"""
def __init__(self) -> None:
self._items: Dict[str, Dict[str, Any]] = {}
self._lock = asyncio.Lock()
async def upsert(self, key: str, value: Dict[str, Any]) -> None:
async with self._lock:
self._items[key] = value
async def update_fields(self, key: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
async with self._lock:
item = self._items.get(key)
if item is None:
return None
item.update(updates)
return item
async def get(self, key: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self._items.get(key)
async def pop(self, key: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self._items.pop(key, None)
async def list_values(self) -> List[Dict[str, Any]]:
async with self._lock:
return list(self._items.values())
# Global stores shared by OpenAI entrypoints
# Maps request id -> job metadata dict
VIDEO_STORE = AsyncDictStore()
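A brief async sketch of how the store above could track a video job; the fields in the dict are illustrative, and only `AsyncDictStore`/`VIDEO_STORE` come from this file.

```python
import asyncio

async def demo():
    await VIDEO_STORE.upsert("job-1", {"id": "job-1", "status": "queued", "progress": 0})
    await VIDEO_STORE.update_fields("job-1", {"status": "in_progress", "progress": 50})
    job = await VIDEO_STORE.get("job-1")
    print(job["status"], job["progress"])
    await VIDEO_STORE.pop("job-1")

asyncio.run(demo())
```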

View File

@ -0,0 +1,288 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests: VBench dimension scores for WAN and LTX-2 (TRT-LLM vs diffusers reference)."""
import glob
import json
import os
import pytest
from defs.common import venv_check_call
from defs.conftest import llm_models_root
from defs.trt_test_alternative import check_call
WAN_T2V_MODEL_SUBPATH = "Wan2.1-T2V-1.3B-Diffusers"
VISUAL_GEN_OUTPUT_VIDEO = "trtllm_output.mp4"
DIFFUSERS_REFERENCE_VIDEO = "diffusers_reference.mp4"
WAN_T2V_PROMPT = "A cute cat playing piano"
WAN_T2V_HEIGHT = 480
WAN_T2V_WIDTH = 832
WAN_T2V_NUM_FRAMES = 165
# Dimensions to evaluate
VBENCH_DIMENSIONS = [
"subject_consistency",
"background_consistency",
"motion_smoothness",
"dynamic_degree",
"aesthetic_quality",
"imaging_quality",
]
# Golden VBench scores from HF reference video (WAN); TRT-LLM is compared against these.
VBENCH_WAN_GOLDEN_SCORES = {
"subject_consistency": 0.9381,
"background_consistency": 0.9535,
"motion_smoothness": 0.9923,
"dynamic_degree": 1.0000,
"aesthetic_quality": 0.5033,
"imaging_quality": 0.3033,
}
VBENCH_REPO = "https://github.com/Vchitect/VBench.git"
VBENCH_BRANCH = "master"
# Pin to a fixed commit for reproducible runs
VBENCH_COMMIT = "98b19513678e99c80d8377fda25ba53b81a491a6"
@pytest.fixture(scope="session")
def vbench_repo_root(llm_venv):
"""Clone VBench repo into workspace and install; return repo root path."""
workspace = llm_venv.get_working_directory()
repo_path = os.path.join(workspace, "VBench_repo")
if os.path.exists(repo_path):
return repo_path
# Clone without --depth=1 so we can checkout a specific commit
check_call(
["git", "clone", "--single-branch", "--branch", VBENCH_BRANCH, VBENCH_REPO, repo_path],
shell=False,
)
check_call(["git", "-C", repo_path, "checkout", VBENCH_COMMIT], shell=False)
# # Install VBench dependencies explicitly
# llm_venv.run_cmd([
# "-m", "pip", "install",
# "tqdm>=4.60.0",
# "openai-clip>=1.0",
# "pyiqa>=0.1.0", # install this might also install transformers=4.37.2, which is incompatible
# "easydict",
# "decord>=0.6.0",
# ])
return repo_path
@pytest.fixture(scope="session")
def wan_trtllm_video_path(llm_venv, llm_root):
"""Generate input video via visual_gen_wan_t2v.py and return path to trtllm_output.mp4."""
scratch_space = llm_models_root()
model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
if not os.path.isdir(model_path):
pytest.skip(
f"Wan T2V model not found: {model_path} "
f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
)
out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
os.makedirs(out_dir, exist_ok=True)
output_path = os.path.join(out_dir, VISUAL_GEN_OUTPUT_VIDEO)
if os.path.isfile(output_path):
return output_path
# Install av and diffusers from main branch
llm_venv.run_cmd(["-m", "pip", "install", "av"])
llm_venv.run_cmd(
[
"-m",
"pip",
"install",
"git+https://github.com/huggingface/diffusers.git",
]
)
script_path = os.path.join(llm_root, "examples", "visual_gen", "visual_gen_wan_t2v.py")
assert os.path.isfile(script_path), f"Visual gen script not found: {script_path}"
venv_check_call(
llm_venv,
[
script_path,
"--height",
str(WAN_T2V_HEIGHT),
"--width",
str(WAN_T2V_WIDTH),
"--num_frames",
str(WAN_T2V_NUM_FRAMES),
"--model_path",
model_path,
"--prompt",
WAN_T2V_PROMPT,
"--output_path",
output_path,
],
)
assert os.path.isfile(output_path), f"Visual gen did not produce {output_path}"
return output_path
@pytest.fixture(scope="session")
def wan_reference_video_path(llm_venv, llm_root):
"""Generate reference video via diffusers (hf_wan.py) using the same model checkpoint."""
scratch_space = llm_models_root()
model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
if not os.path.isdir(model_path):
pytest.skip(
f"Wan T2V model not found: {model_path} "
f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
)
out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
os.makedirs(out_dir, exist_ok=True)
reference_path = os.path.join(out_dir, DIFFUSERS_REFERENCE_VIDEO)
if os.path.isfile(reference_path):
return reference_path
hf_script = os.path.join(llm_root, "examples", "visual_gen", "hf_wan.py")
assert os.path.isfile(hf_script), f"Diffusers script not found: {hf_script}"
venv_check_call(
llm_venv,
[
hf_script,
"--model_path",
model_path,
"--prompt",
WAN_T2V_PROMPT,
"--output_path",
reference_path,
"--height",
str(WAN_T2V_HEIGHT),
"--width",
str(WAN_T2V_WIDTH),
"--num_frames",
str(WAN_T2V_NUM_FRAMES),
],
)
assert os.path.isfile(reference_path), f"Diffusers did not produce {reference_path}"
return reference_path
def _visual_gen_out_dir(llm_venv, subdir=""):
"""Output directory for generated media; subdir e.g. 'ltx2' for model-specific outputs."""
base = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
return os.path.join(base, subdir) if subdir else base
def _normalize_score(val):
"""Normalize to 0-1 scale (e.g. imaging_quality can be 0-100)."""
if isinstance(val, bool):
return float(val)
if isinstance(val, (int, float)) and val > 1.5:
return val / 100.0
return float(val)
def _get_per_video_scores(results, video_path_substr):
"""From VBench results, get per-dimension score for the video whose path contains video_path_substr."""
scores = {}
for dim in VBENCH_DIMENSIONS:
dim_result = results[dim]
assert isinstance(dim_result, list) and len(dim_result) >= 2, (
f"Dimension '{dim}' result must be [overall_score, video_results]; got {type(dim_result)}"
)
video_results = dim_result[1]
for entry in video_results:
if video_path_substr in entry.get("video_path", ""):
raw = entry.get("video_results")
scores[dim] = _normalize_score(raw)
break
else:
raise AssertionError(
f"No video matching '{video_path_substr}' in dimension '{dim}'; "
f"paths: {[e.get('video_path') for e in video_results]}"
)
return scores
def _run_vbench_and_compare_to_golden(
vbench_repo_root,
videos_dir,
trtllm_filename,
golden_scores,
llm_venv,
title,
max_score_diff=0.1,
):
"""Run VBench on videos_dir (TRT-LLM output only), compare to golden HF reference scores."""
output_path = os.path.join(
llm_venv.get_working_directory(), "vbench_eval_output", title.replace(" ", "_").lower()
)
os.makedirs(output_path, exist_ok=True)
evaluate_script = os.path.join(vbench_repo_root, "evaluate.py")
cmd = [
evaluate_script,
"--videos_path",
videos_dir,
"--output_path",
output_path,
"--mode",
"custom_input",
]
cmd.extend(["--dimension"] + VBENCH_DIMENSIONS)
venv_check_call(llm_venv, cmd)
pattern = os.path.join(output_path, "*_eval_results.json")
result_files = glob.glob(pattern)
assert result_files, (
f"No eval results found matching {pattern}; output dir: {os.listdir(output_path)}"
)
with open(result_files[0], "r") as f:
results = json.load(f)
for dim in VBENCH_DIMENSIONS:
assert dim in results, (
f"Expected dimension '{dim}' in results; keys: {list(results.keys())}"
)
scores_trtllm = _get_per_video_scores(results, trtllm_filename)
scores_ref = golden_scores
max_len = max(len(d) for d in VBENCH_DIMENSIONS)
header = f"{'Dimension':<{max_len}} | {'TRT-LLM':>10} | {'HF Ref':>10} | {'Diff':>8}"
sep = "-" * len(header)
print("\n" + "=" * len(header))
print(f"VBench dimension scores ({title}): TRT-LLM vs golden HF reference scores")
print("=" * len(header))
print(header)
print(sep)
max_diff_val = 0.0
for dim in VBENCH_DIMENSIONS:
t, r = scores_trtllm[dim], scores_ref[dim]
diff = abs(t - r)
max_diff_val = max(max_diff_val, diff)
print(f"{dim:<{max_len}} | {t:>10.4f} | {r:>10.4f} | {diff:>8.4f}")
print(sep)
print(
f"{' (all dimensions)':<{max_len}} | (TRT-LLM) | (golden) | max_diff={max_diff_val:.4f}"
)
print("=" * len(header) + "\n")
for dim in VBENCH_DIMENSIONS:
diff = abs(scores_trtllm[dim] - scores_ref[dim])
assert diff < max_score_diff or scores_trtllm[dim] >= scores_ref[dim], (
f"Dimension '{dim}' score difference {diff:.4f} >= {max_score_diff} "
f"(TRT-LLM={scores_trtllm[dim]:.4f}, golden={scores_ref[dim]:.4f})"
)
def test_vbench_dimension_score_wan(vbench_repo_root, wan_trtllm_video_path, llm_venv):
"""Run VBench on WAN TRT-LLM video; compare to golden HF reference scores (diff < 0.05 or TRT-LLM >= golden)."""
videos_dir = os.path.dirname(wan_trtllm_video_path)
assert os.path.isfile(wan_trtllm_video_path), "TRT-LLM video must exist"
_run_vbench_and_compare_to_golden(
vbench_repo_root,
videos_dir,
VISUAL_GEN_OUTPUT_VIDEO,
VBENCH_WAN_GOLDEN_SCORES,
llm_venv,
title="WAN",
max_score_diff=0.05,
)

View File

@ -91,6 +91,17 @@ l0_b200:
- unittest/tools/test_layer_wise_benchmarks.py::test_performance_alignment[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/kv_cache_manager_v2_tests/
# ------------- Visual Gen tests ---------------
- unittest/_torch/visual_gen/test_fused_qkv.py
- unittest/_torch/visual_gen/test_quant_ops.py
- unittest/_torch/visual_gen/test_attention_integration.py
- unittest/_torch/visual_gen/test_attention_perf.py
- unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
- unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
- unittest/_torch/visual_gen/test_wan.py -k "not TestWanTwoStageTransformer"
- unittest/_torch/visual_gen/test_wan_i2v.py
- unittest/_torch/visual_gen/test_model_loader.py
# - examples/test_visual_gen.py
- condition:
ranges:
system_gpu_count:
@ -161,6 +172,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype
- unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer
# ------------- AutoDeploy Backend Stages ---------------
- condition:
ranges:

View File

@ -30,6 +30,12 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
# ------------- VisualGen multi-GPU tests ---------------
- unittest/_torch/visual_gen/multi_gpu
- unittest/_torch/visual_gen/test_wan.py::TestWanParallelism::test_cfg_2gpu_correctness
- unittest/_torch/visual_gen/test_wan.py::TestWanCombinedOptimizations::test_all_optimizations_combined
- unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VParallelism::test_cfg_2gpu_correctness
- unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VCombinedOptimizations::test_all_optimizations_combined
- condition:
ranges:
system_gpu_count:

View File

@ -0,0 +1 @@
"""Multi-GPU tests for visual generation modules."""

View File

@ -0,0 +1,505 @@
"""Multi-GPU tests for Ulysses Attention.
These tests use torch.multiprocessing.spawn to launch multiple processes internally.
Run with:
pytest tests/visual_gen/multi_gpu/test_ulysses_attention.py -v
"""
import os
os.environ["TLLM_DISABLE_MPI"] = "1"
import math
from typing import Callable
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
# Try to import the modules - skip tests if not available
try:
from tensorrt_llm._torch.attention_backend.interface import PredefinedAttentionMask
from tensorrt_llm._torch.distributed import all_to_all_4d
from tensorrt_llm._torch.visual_gen.attention_backend import UlyssesAttention, VanillaAttention
from tensorrt_llm._utils import get_free_port
MODULES_AVAILABLE = True
except ImportError:
MODULES_AVAILABLE = False
@pytest.fixture(autouse=True, scope="module")
def _cleanup_mpi_env():
"""Clean up TLLM_DISABLE_MPI env var after tests complete."""
yield
os.environ.pop("TLLM_DISABLE_MPI", None)
def init_distributed_worker(rank: int, world_size: int, backend: str = "gloo", port: int = 29500):
"""Initialize distributed environment for a worker process."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Use gloo backend for CPU, nccl for GPU
if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(rank % torch.cuda.device_count())
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
def cleanup_distributed():
"""Clean up distributed environment."""
if dist.is_initialized():
dist.destroy_process_group()
def _distributed_worker(rank, world_size, backend, test_fn, port):
"""Worker function that runs in each process. Module-level for pickling."""
try:
init_distributed_worker(rank, world_size, backend, port)
test_fn(rank, world_size)
except Exception as e:
print(f"Rank {rank} failed with error: {e}")
raise
finally:
cleanup_distributed()
def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = True):
"""Run a test function in a distributed environment with multiple processes.
Args:
world_size: Number of processes to spawn
test_fn: Test function to run (must be module-level for pickling).
Should accept (rank, world_size) as arguments.
use_cuda: Whether to use CUDA (requires sufficient GPUs)
"""
if not MODULES_AVAILABLE:
pytest.skip("Required modules not available")
if use_cuda and torch.cuda.device_count() < world_size:
pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available")
backend = "nccl" if use_cuda else "gloo"
port = get_free_port()
# Spawn processes
mp.spawn(
_distributed_worker, args=(world_size, backend, test_fn, port), nprocs=world_size, join=True
)
# =============================================================================
# Test logic functions (module-level so they can be pickled by mp.spawn)
# =============================================================================
def _logic_a2a_seq_to_head(rank, world_size):
"""all_to_all_4d: sequence sharding to head sharding."""
batch = 2
seq_per_rank = 4
heads = 8
head_dim = 64
if heads % world_size != 0:
heads = world_size * 2
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
input_tensor = (
torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
+ rank * 100
)
output = all_to_all_4d(
input_tensor,
scatter_dim=2,
gather_dim=1,
process_group=None,
)
expected_shape = (batch, seq_per_rank * world_size, heads // world_size, head_dim)
assert output.shape == expected_shape, (
f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
)
assert output.device == device
def _logic_a2a_head_to_seq(rank, world_size):
"""all_to_all_4d: head sharding to sequence sharding."""
batch = 2
seq = 16
heads_per_rank = 2
head_dim = 64
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
input_tensor = torch.randn(
batch, seq, heads_per_rank, head_dim, device=device, dtype=torch.float32
)
output = all_to_all_4d(
input_tensor,
scatter_dim=1,
gather_dim=2,
process_group=None,
)
expected_shape = (batch, seq // world_size, heads_per_rank * world_size, head_dim)
assert output.shape == expected_shape, (
f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
)
def _logic_a2a_roundtrip(rank, world_size):
"""all_to_all_4d: forward and backward are inverses."""
batch = 2
seq_per_rank = 4
heads = world_size * 4
head_dim = 64
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
original = torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
intermediate = all_to_all_4d(original, scatter_dim=2, gather_dim=1, process_group=None)
reconstructed = all_to_all_4d(intermediate, scatter_dim=1, gather_dim=2, process_group=None)
assert reconstructed.shape == original.shape
torch.testing.assert_close(reconstructed, original, rtol=1e-5, atol=1e-5)
def _logic_a2a_single_process(rank, world_size):
"""all_to_all_4d: single process returns input unchanged."""
batch, seq, heads, head_dim = 2, 8, 4, 64
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
input_tensor = torch.randn(batch, seq, heads, head_dim, device=device)
output = all_to_all_4d(input_tensor, scatter_dim=2, gather_dim=1, process_group=None)
torch.testing.assert_close(output, input_tensor)
def _logic_ulysses_init(rank, world_size):
"""UlyssesAttention initialization."""
num_heads = world_size * 4
head_dim = 64
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
)
assert attention.num_heads == num_heads
assert attention.head_dim == head_dim
assert attention.world_size == world_size
assert rank >= 0 and rank < world_size
def _logic_ulysses_forward(rank, world_size):
"""UlyssesAttention forward pass."""
batch = 2
seq_per_rank = 8
num_heads = world_size * 4
head_dim = 64
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
assert output.shape == q.shape, f"Rank {rank}: Expected shape {q.shape}, got {output.shape}"
assert output.device == device
def _logic_ulysses_with_mask(rank, world_size):
"""UlyssesAttention with attention mask."""
batch = 2
seq_per_rank = 8
seq_full = seq_per_rank * world_size
num_heads = world_size * 4
head_dim = 64
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
mask = PredefinedAttentionMask.CAUSAL
output = attention(q, k, v, batch_size=batch, seq_len=seq_full, attention_mask=mask)
assert output.shape == q.shape
def _logic_ulysses_vs_standard_multi_gpu(rank, world_size):
"""UlyssesAttention across multiple GPUs matches standard attention on the full sequence."""
batch = 2
seq_per_rank = 8
seq_full = seq_per_rank * world_size
num_heads = world_size * 4
head_dim = 64
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
# Every rank generates identical full tensors using the same seed.
torch.manual_seed(42)
q_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
k_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
v_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
# Each rank takes its sequence shard.
q_shard = q_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
k_shard = k_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
v_shard = v_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
# Ulysses attention on shards.
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
ulysses_output = attention(q_shard, k_shard, v_shard, batch_size=batch, seq_len=seq_full)
# Standard attention on the full tensors.
q_std = q_full.transpose(1, 2) # [B, H, S, D]
k_std = k_full.transpose(1, 2)
v_std = v_full.transpose(1, 2)
std_output = F.scaled_dot_product_attention(
q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
)
std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
# Compare the shard slice.
expected_shard = std_output[:, rank * seq_per_rank : (rank + 1) * seq_per_rank]
torch.testing.assert_close(
ulysses_output,
expected_shard,
rtol=1e-4,
atol=1e-4,
msg=f"Rank {rank}: Ulysses multi-GPU output differs from standard attention",
)
def _logic_ulysses_invalid_heads(rank, world_size):
"""Invalid head count (not divisible by world_size) cannot be sharded."""
assert rank >= 0 and rank < world_size
num_heads = world_size * 4 + 1 # Not divisible
head_dim = 64
# With the decorator pattern, the caller is responsible for sharding heads.
# num_heads // world_size truncates, so the wrapper's computed full head
# count won't match the original.
sharded_heads = num_heads // world_size
inner = VanillaAttention(num_heads=sharded_heads, head_dim=head_dim)
attention = UlyssesAttention(inner_backend=inner, process_group=None)
assert attention.num_heads != num_heads # Truncation means mismatch
def _logic_different_batch_sizes(rank, world_size):
"""Various batch sizes."""
num_heads = world_size * 4
head_dim = 64
seq_per_rank = 8
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
for batch_size in [1, 2, 4, 8]:
q = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
k = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
v = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
output = attention(q, k, v, batch_size=batch_size, seq_len=seq_per_rank * world_size)
assert output.shape == q.shape
def _logic_different_head_dims(rank, world_size):
"""Various head dims."""
batch = 2
seq_per_rank = 8
num_heads = world_size * 4
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
for head_dim in [32, 64, 128]:
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
assert output.shape == q.shape
def _logic_world_size_4(rank, world_size):
"""4-GPU test."""
batch = 2
seq_per_rank = 16
num_heads = world_size * 8 # 32 heads total
head_dim = 64
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
attention = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
assert output.shape == q.shape
# =============================================================================
# Test classes
# =============================================================================
class TestAllToAll4D:
"""Tests for all_to_all_4d function."""
def test_all_to_all_4d_sequence_to_head(self):
"""Test sequence sharding to head sharding transformation."""
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_seq_to_head, use_cuda=True)
def test_all_to_all_4d_head_to_sequence(self):
"""Test head sharding to sequence sharding transformation."""
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_head_to_seq, use_cuda=True)
def test_all_to_all_4d_roundtrip(self):
"""Test that forward and backward all-to-all are inverses."""
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_roundtrip, use_cuda=True)
def test_all_to_all_4d_single_process(self):
"""Test that single process returns input unchanged."""
run_test_in_distributed(world_size=1, test_fn=_logic_a2a_single_process, use_cuda=True)
class TestUlyssesAttention:
"""Tests for UlyssesAttention module."""
def test_ulysses_attention_initialization(self):
"""Test UlyssesAttention initialization."""
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_init, use_cuda=True)
def test_ulysses_attention_forward(self):
"""Test UlyssesAttention forward pass."""
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_forward, use_cuda=True)
def test_ulysses_attention_with_mask(self):
"""Test UlyssesAttention with attention mask."""
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_with_mask, use_cuda=True)
def test_ulysses_vs_standard_attention_single_gpu(self):
"""Compare UlyssesAttention with standard attention on single GPU."""
if not MODULES_AVAILABLE:
pytest.skip("Required modules not available")
if not torch.cuda.is_available():
pytest.skip("Test requires CUDA")
batch = 2
seq = 16
num_heads = 8
head_dim = 64
device = torch.device("cuda:0")
inner = VanillaAttention(num_heads=num_heads, head_dim=head_dim)
ulysses_attn = UlyssesAttention(
inner_backend=inner,
process_group=None,
).to(device)
torch.manual_seed(42)
q = torch.randn(batch, seq, num_heads, head_dim, device=device)
k = torch.randn(batch, seq, num_heads, head_dim, device=device)
v = torch.randn(batch, seq, num_heads, head_dim, device=device)
ulysses_output = ulysses_attn(q, k, v, batch_size=batch, seq_len=seq)
q_std = q.transpose(1, 2) # [B, H, S, D]
k_std = k.transpose(1, 2)
v_std = v.transpose(1, 2)
std_output = F.scaled_dot_product_attention(
q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
)
std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
torch.testing.assert_close(
ulysses_output,
std_output,
rtol=1e-4,
atol=1e-4,
msg="Ulysses attention output differs from standard attention",
)
def test_ulysses_vs_standard_attention_multi_gpu(self):
"""Compare UlyssesAttention across GPUs with standard attention on full sequence."""
run_test_in_distributed(
world_size=2, test_fn=_logic_ulysses_vs_standard_multi_gpu, use_cuda=True
)
def test_ulysses_attention_invalid_heads(self):
"""Test that invalid head count raises error."""
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_invalid_heads, use_cuda=False)
class TestUlyssesAttentionEdgeCases:
"""Edge case tests for UlyssesAttention."""
def test_different_batch_sizes(self):
"""Test with various batch sizes."""
run_test_in_distributed(world_size=2, test_fn=_logic_different_batch_sizes, use_cuda=True)
def test_different_head_dims(self):
"""Test with various head dims."""
run_test_in_distributed(world_size=2, test_fn=_logic_different_head_dims, use_cuda=True)
def test_world_size_4(self):
"""Test with 4 GPUs."""
run_test_in_distributed(world_size=4, test_fn=_logic_world_size_4, use_cuda=True)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
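For orientation alongside the distributed tests above, here is a minimal single-process sketch of the data movement that `all_to_all_4d(scatter_dim=1, gather_dim=2)` performs, emulated with plain `torch.chunk`/`torch.cat` and no process group; the emulation and tensor names are illustrative only, not the library implementation.

```python
import torch

# Illustrative sketch only: emulate all_to_all_4d(scatter_dim=1, gather_dim=2)
# on one process. In the real collective each rank splits its sequence dim into
# world_size pieces, sends piece j to rank j, and concatenates what it receives
# along the head dim; here the local pieces are reused just to show the shapes.
world_size = 2
batch, seq, heads_per_rank, head_dim = 2, 16, 4, 64

x = torch.randn(batch, seq, heads_per_rank, head_dim)  # head-sharded input
pieces = x.chunk(world_size, dim=1)                     # world_size x [B, S/ws, H/ws, D]
y = torch.cat(pieces, dim=2)                            # [B, S/ws, H/ws * ws, D]

assert y.shape == (batch, seq // world_size, heads_per_rank * world_size, head_dim)
```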

View File

@@ -0,0 +1,540 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Test WAN Attention Integration.
Compares the new integrated attention (using TRT-LLM backend) with the original
naive implementation to ensure numerical equivalence.
"""
from types import SimpleNamespace
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
# Import new integrated versions
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb
# ============================================================================
# Original naive implementations for comparison
# ============================================================================
class NaiveWanSelfAttention(nn.Module):
"""Original naive self-attention implementation (for comparison)."""
def __init__(
self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.hidden_size = hidden_size
# fused QKV projection
self.to_qkv = nn.Linear(hidden_size, 3 * hidden_size, dtype=dtype)
self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
def forward(self, hidden_states, freqs_cos, freqs_sin):
B, S = hidden_states.shape[:2]
q, k, v = self.to_qkv(hidden_states).chunk(3, dim=-1)
q = self.norm_q(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = self.norm_k(k).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
if freqs_cos is not None and freqs_sin is not None:
q = apply_rotary_emb(q, freqs_cos, freqs_sin)
k = apply_rotary_emb(k, freqs_cos, freqs_sin)
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
out = out.transpose(1, 2).flatten(2)
out = self.to_out[0](out)
return out
class NaiveWanCrossAttention(nn.Module):
"""Original naive cross-attention implementation (for comparison)."""
def __init__(
self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.hidden_size = hidden_size
self.to_q = nn.Linear(hidden_size, hidden_size, dtype=dtype)
self.to_k = nn.Linear(hidden_size, hidden_size, dtype=dtype)
self.to_v = nn.Linear(hidden_size, hidden_size, dtype=dtype)
self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
def forward(self, hidden_states, encoder_hidden_states):
B, S = hidden_states.shape[:2]
q = self.norm_q(self.to_q(hidden_states))
k = self.norm_k(self.to_k(encoder_hidden_states))
v = self.to_v(encoder_hidden_states)
q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
out = out.transpose(1, 2).flatten(2)
out = self.to_out[0](out)
return out
# ============================================================================
# Test utilities
# ============================================================================
def create_model_config(
hidden_size: int,
num_heads: int,
head_dim: int,
eps: float = 1e-6,
attn_backend: str = "VANILLA",
):
"""Create a mock DiffusionModelConfig for testing."""
pretrained_config = SimpleNamespace(
hidden_size=hidden_size,
num_attention_heads=num_heads,
attention_head_dim=head_dim,
eps=eps,
)
# Create a minimal config without quantization
config = DiffusionModelConfig(
pretrained_config=pretrained_config,
attention=AttentionConfig(backend=attn_backend),
skip_create_weights_in_init=False,
)
return config
def copy_weights_self_attention(naive: NaiveWanSelfAttention, integrated: Attention):
"""Copy weights from naive to integrated self-attention."""
# QKV projection: naive has to_qkv, integrated has qkv_proj
integrated.qkv_proj.weight.data.copy_(naive.to_qkv.weight.data)
if naive.to_qkv.bias is not None and integrated.qkv_proj.bias is not None:
integrated.qkv_proj.bias.data.copy_(naive.to_qkv.bias.data)
# QK norms
integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
# Output projection
integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
def copy_weights_cross_attention(naive: NaiveWanCrossAttention, integrated: Attention):
"""Copy weights from naive to integrated cross-attention."""
# Q, K, V projections
integrated.to_q.weight.data.copy_(naive.to_q.weight.data)
integrated.to_k.weight.data.copy_(naive.to_k.weight.data)
integrated.to_v.weight.data.copy_(naive.to_v.weight.data)
if naive.to_q.bias is not None and integrated.to_q.bias is not None:
integrated.to_q.bias.data.copy_(naive.to_q.bias.data)
if naive.to_k.bias is not None and integrated.to_k.bias is not None:
integrated.to_k.bias.data.copy_(naive.to_k.bias.data)
if naive.to_v.bias is not None and integrated.to_v.bias is not None:
integrated.to_v.bias.data.copy_(naive.to_v.bias.data)
# QK norms
integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
# Output projection
integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
def generate_rope_embeddings(
seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
):
"""Generate RoPE embeddings with full head_dim.
apply_rotary_emb expects freqs with full head_dim, then slices with [..., 0::2] and [..., 1::2].
Args:
is_HSD: If True, returns [1, 1, S, D] for broadcasting with [B, H, S, D] (naive)
If False, returns [1, S, 1, D] for broadcasting with [B, S, H, D] (integrated)
"""
position = torch.arange(seq_len, device=device).unsqueeze(1)
# Use full head_dim - apply_rotary_emb will slice with 0::2 and 1::2
div_term = torch.exp(
torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
)
if is_HSD:
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
else:
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
return freqs_cos, freqs_sin
# ============================================================================
# Test functions
# ============================================================================
@pytest.mark.parametrize("attn_backend", ["VANILLA", "TRTLLM"])
def test_self_attention_equivalence(attn_backend: str):
"""Test that integrated self-attention produces same output as naive."""
print("\n" + "=" * 60)
print("Testing Self-Attention Equivalence")
print("=" * 60)
# Config
batch_size = 2
seq_len = 16
hidden_size = 128
num_heads = 4
head_dim = hidden_size // num_heads
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
print(f"Device: {device}, dtype: {dtype}")
# Create models
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
integrated = Attention(
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
).to(device) # self attention
# Copy weights
copy_weights_self_attention(naive, integrated)
# Set to eval mode
naive.eval()
integrated.eval()
# Create inputs
torch.manual_seed(42)
hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
# Naive uses [1, 1, S, D] (HSD format) - broadcasts with [B, H, S, D]
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
# Integrated uses [1, S, 1, D] (SHD format) - broadcasts with [B, S, H, D]
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
# Forward pass
with torch.no_grad():
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
# Compare (using looser tolerance for bf16)
max_diff = (out_naive - out_integrated).abs().max().item()
mean_diff = (out_naive - out_integrated).abs().mean().item()
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
print("\nResults:")
print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
print(f" Max absolute difference: {max_diff:.2e}")
print(f" Mean absolute difference: {mean_diff:.2e}")
print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
if is_close:
print(" ✅ PASS: Self-attention outputs match!")
else:
print(" ❌ FAIL: Self-attention outputs differ!")
assert is_close, (
f"Self-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
)
return is_close
@pytest.mark.parametrize("attn_backend", ["VANILLA"])
def test_cross_attention_equivalence(attn_backend: str):
"""Test that integrated cross-attention produces same output as naive."""
print("\n" + "=" * 60)
print("Testing Cross-Attention Equivalence")
print("=" * 60)
# Config
batch_size = 2
seq_len = 16
encoder_seq_len = 24 # Different from query seq_len
hidden_size = 128
num_heads = 4
head_dim = hidden_size // num_heads
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
print(
f"Config: B={batch_size}, S_q={seq_len}, S_kv={encoder_seq_len}, H={hidden_size}, heads={num_heads}"
)
print(f"Device: {device}, dtype: {dtype}")
# Create models
naive = NaiveWanCrossAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
integrated = Attention(
hidden_size, num_heads, qkv_mode=QKVMode.SEPARATE_QKV, config=model_config
).to(device) # cross attention
# Copy weights
copy_weights_cross_attention(naive, integrated)
# Set to eval mode
naive.eval()
integrated.eval()
# Create inputs
torch.manual_seed(42)
hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
encoder_hidden_states = torch.randn(
batch_size, encoder_seq_len, hidden_size, device=device, dtype=dtype
)
# Forward pass
with torch.no_grad():
out_naive = naive(hidden_states, encoder_hidden_states)
out_integrated = integrated(hidden_states, encoder_hidden_states)
# Compare (using looser tolerance for bf16)
max_diff = (out_naive - out_integrated).abs().max().item()
mean_diff = (out_naive - out_integrated).abs().mean().item()
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
print("\nResults:")
print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
print(f" Max absolute difference: {max_diff:.2e}")
print(f" Mean absolute difference: {mean_diff:.2e}")
print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
if is_close:
print(" ✅ PASS: Cross-attention outputs match!")
else:
print(" ❌ FAIL: Cross-attention outputs differ!")
assert is_close, (
f"Cross-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
)
return is_close
def test_trtllm_cached_prepare():
"""Test that TRTLLM attention cached prepare works correctly.
This test verifies that when running multiple forward passes with same B/S
but different q/k/v values, the cached prepare phase doesn't cause incorrect
results (i.e., outputs should differ when inputs differ).
"""
print("\n" + "=" * 60)
print("Testing TRTLLM Cached Prepare Phase")
print("=" * 60)
# Config - same B, S for all iterations
batch_size = 2
seq_len = 16
hidden_size = 128
num_heads = 4
head_dim = hidden_size // num_heads
num_iterations = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
print(f"Running {num_iterations} iterations with same B/S but different inputs")
# Create models - single instance to test caching
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
integrated = Attention(
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
).to(device) # self attention
# Copy weights
copy_weights_self_attention(naive, integrated)
naive.eval()
integrated.eval()
# Generate freqs (same for all iterations since S is same)
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
all_passed = True
outputs_integrated = []
with torch.no_grad():
for i in range(num_iterations):
# Different random inputs for each iteration
torch.manual_seed(42 + i) # Different seed each time
hidden_states = torch.randn(
batch_size, seq_len, hidden_size, device=device, dtype=dtype
)
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
# Check this iteration matches naive
max_diff = (out_naive - out_integrated).abs().max().item()
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
status = "" if is_close else ""
print(f" Iteration {i + 1}: max_diff={max_diff:.2e} {status}")
if not is_close:
all_passed = False
outputs_integrated.append(out_integrated.clone())
# Additional check: outputs should be DIFFERENT across iterations
# (since inputs were different)
print("\n Checking outputs differ across iterations (inputs were different):")
outputs_differ = True
for i in range(1, num_iterations):
diff = (outputs_integrated[i] - outputs_integrated[0]).abs().max().item()
if diff < 1e-6:
print(
f" ⚠️ Iteration {i + 1} output same as iteration 1 (diff={diff:.2e}) - possible caching bug!"
)
outputs_differ = False
else:
print(f" Iteration {i + 1} vs 1: diff={diff:.2e}")
if all_passed and outputs_differ:
print("\n ✅ PASS: Cached prepare works correctly!")
else:
print("\n ❌ FAIL: Cached prepare may have issues!")
all_passed = False
assert all_passed, "Cached prepare: outputs did not match naive reference"
assert outputs_differ, (
"Cached prepare: outputs should differ across iterations with different inputs"
)
return all_passed
def test_trtllm_varying_seq_len():
"""Test TRTLLM attention with varying sequence lengths.
This tests that the prepare phase correctly handles different seq_lens
and doesn't incorrectly reuse cached metadata.
"""
print("\n" + "=" * 60)
print("Testing TRTLLM with Varying Sequence Lengths")
print("=" * 60)
batch_size = 2
hidden_size = 128
num_heads = 4
head_dim = hidden_size // num_heads
seq_lens = [8, 16, 32, 16, 8] # Vary seq_len, including repeats
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
print(f"Config: B={batch_size}, H={hidden_size}, heads={num_heads}")
print(f"Testing seq_lens: {seq_lens}")
# Create models - single instance to test caching across different seq_lens
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
integrated = Attention(
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
).to(device) # self attention
copy_weights_self_attention(naive, integrated)
naive.eval()
integrated.eval()
all_passed = True
with torch.no_grad():
for i, seq_len in enumerate(seq_lens):
torch.manual_seed(42 + i)
hidden_states = torch.randn(
batch_size, seq_len, hidden_size, device=device, dtype=dtype
)
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(
seq_len, head_dim, device, is_HSD=True
)
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(
seq_len, head_dim, device, is_HSD=False
)
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
max_diff = (out_naive - out_integrated).abs().max().item()
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
status = "" if is_close else ""
print(f" seq_len={seq_len:3d}: max_diff={max_diff:.2e} {status}")
if not is_close:
all_passed = False
if all_passed:
print("\n ✅ PASS: Varying seq_len handled correctly!")
else:
print("\n ❌ FAIL: Issues with varying seq_len!")
assert all_passed, "Varying seq_len: outputs did not match naive reference"
return all_passed
def run_all_tests():
"""Run all tests and report results."""
print("\n" + "=" * 60)
print("WAN Attention Integration Tests")
print("=" * 60)
results = {}
# Run self-attention tests with different backends
for backend in ["VANILLA", "TRTLLM"]:
results[f"self_attention_{backend}"] = test_self_attention_equivalence(backend)
# Run cross-attention test (VANILLA only)
results["cross_attention_VANILLA"] = test_cross_attention_equivalence("VANILLA")
# Run TRTLLM-specific caching tests
results["trtllm_cached_prepare"] = test_trtllm_cached_prepare()
results["trtllm_varying_seq_len"] = test_trtllm_varying_seq_len()
print("\n" + "=" * 60)
print("Summary")
print("=" * 60)
all_passed = all(results.values())
for name, passed in results.items():
status = "✅ PASS" if passed else "❌ FAIL"
print(f" {name}: {status}")
print()
if all_passed:
print("All tests passed! ✅")
else:
print("Some tests failed! ❌")
return all_passed
if __name__ == "__main__":
run_all_tests()
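As a note on the `is_HSD` flag in `generate_rope_embeddings` above, the sketch below (made-up shapes, with a plain multiply standing in for the rotary update) shows that the two layouts hold the same table and differ only in where the broadcast dimensions sit: `[1, 1, S, D]` pairs with `[B, H, S, D]`, while `[1, S, 1, D]` pairs with `[B, S, H, D]`.

```python
import torch

B, H, S, D = 2, 4, 16, 32
table = torch.randn(S, D)            # stand-in for a cos or sin table

freqs_hsd = table.view(1, 1, S, D)   # broadcasts against [B, H, S, D]
freqs_shd = table.view(1, S, 1, D)   # broadcasts against [B, S, H, D]

x_bshd = torch.randn(B, S, H, D)
x_bhsd = x_bshd.transpose(1, 2)

# Multiplication stands in for the rotary update; both layouts give the same
# result up to a transpose of the head and sequence dims.
torch.testing.assert_close((x_bhsd * freqs_hsd).transpose(1, 2), x_bshd * freqs_shd)
```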

View File

@@ -0,0 +1,622 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""WAN Attention Performance Benchmark.
Compares VANILLA vs TRTLLM attention backends for visual generation models.
Uses CUDA events for precise GPU timing and supports NVTX profiling.
Usage:
# Run all tests
python test_attention_perf.py
# With Nsight Systems profiling
nsys profile -t cuda,nvtx --nvtx-capture=range -o wan_attn_perf python test_attention_perf.py
# Run specific tests with pytest
pytest test_attention_perf.py -v -k "test_self_attention_perf"
"""
import time
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Dict, Optional, Tuple
import pytest
import torch
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
# NVTX support for profiling
try:
import nvtx
NVTX_AVAILABLE = True
if hasattr(nvtx, "annotate"):
NVTX_METHOD = "annotate"
elif hasattr(nvtx, "range_start") and hasattr(nvtx, "range_end"):
NVTX_METHOD = "range"
else:
NVTX_METHOD = None
NVTX_AVAILABLE = False
except ImportError:
NVTX_AVAILABLE = False
NVTX_METHOD = None
# Torch profiler support
try:
from torch.profiler import record_function
TORCH_PROFILER_AVAILABLE = True
except ImportError:
TORCH_PROFILER_AVAILABLE = False
# ============================================================================
# Timing utilities
# ============================================================================
@contextmanager
def cuda_timer(device: torch.device):
"""Context manager for precise GPU timing using CUDA events."""
if device.type == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
def get_elapsed_time():
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event)
yield get_elapsed_time
else:
start_time = time.perf_counter()
def get_elapsed_time():
return (time.perf_counter() - start_time) * 1000
yield get_elapsed_time
@contextmanager
def nvtx_range(name: str):
"""Context manager for NVTX range profiling."""
if NVTX_AVAILABLE and NVTX_METHOD:
if NVTX_METHOD == "annotate":
with nvtx.annotate(name):
yield
elif NVTX_METHOD == "range":
range_id = nvtx.range_start(name)
try:
yield
finally:
nvtx.range_end(range_id)
else:
yield
else:
yield
@contextmanager
def torch_profiler_range(name: str):
"""Context manager for torch profiler range."""
if TORCH_PROFILER_AVAILABLE:
with record_function(name):
yield
else:
yield
# ============================================================================
# Test utilities
# ============================================================================
def create_model_config(
hidden_size: int,
num_heads: int,
head_dim: int,
eps: float = 1e-6,
attn_backend: str = "VANILLA",
) -> DiffusionModelConfig:
"""Create a mock DiffusionModelConfig for testing."""
pretrained_config = SimpleNamespace(
hidden_size=hidden_size,
num_attention_heads=num_heads,
attention_head_dim=head_dim,
eps=eps,
)
config = DiffusionModelConfig(
pretrained_config=pretrained_config,
attention=AttentionConfig(backend=attn_backend),
skip_create_weights_in_init=False,
)
return config
def generate_rope_embeddings(
seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate RoPE embeddings.
Args:
seq_len: Sequence length
head_dim: Head dimension
device: Target device
is_HSD: If True, returns [1, 1, S, D] for HSD format, else [1, S, 1, D] for SHD
Returns:
Tuple of (freqs_cos, freqs_sin)
"""
position = torch.arange(seq_len, device=device).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
)
if is_HSD:
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0)
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0)
else:
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2)
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2)
return freqs_cos, freqs_sin
# ============================================================================
# Performance benchmark class
# ============================================================================
class WanAttentionPerformanceBenchmark:
"""Performance benchmark for WAN attention backends."""
# WAN model configurations: (batch_size, num_heads, seq_len, head_dim, description)
TEST_SIZES = [
# Wan2.1-T2V-1.3B configurations
(1, 24, 14040, 64, "Wan-1.3B 480p 2s"),
(1, 24, 3510, 64, "Wan-1.3B 480p 2s ring4"),
(1, 24, 7020, 64, "Wan-1.3B 480p 2s ring2"),
# Wan2.1-T2V-14B configurations
(1, 40, 75600, 128, "Wan-14B 720p 5s"),
(1, 40, 37800, 128, "Wan-14B 720p 5s ring2"),
(1, 40, 18900, 128, "Wan-14B 720p 5s ring4"),
(1, 40, 9450, 128, "Wan-14B 720p 5s ring8"),
# Ulysses parallelism configurations
(1, 20, 75600, 128, "Wan-14B 720p ulysses2"),
(1, 10, 75600, 128, "Wan-14B 720p ulysses4"),
(1, 5, 75600, 128, "Wan-14B 720p ulysses8"),
# Smaller test cases for quick validation
(2, 24, 1024, 64, "Small batch2"),
(1, 24, 4096, 64, "Medium 4k"),
(1, 40, 8192, 128, "Large 8k"),
]
# Quick test sizes for CI/pytest
QUICK_TEST_SIZES = [
(1, 24, 1024, 64, "Quick 1k"),
(1, 24, 2048, 64, "Quick 2k"),
(2, 24, 1024, 64, "Quick batch2"),
]
def __init__(
self,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.bfloat16,
warmup_iterations: int = 10,
benchmark_iterations: int = 50,
):
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = dtype
self.warmup_iterations = warmup_iterations
self.benchmark_iterations = benchmark_iterations
self.backends = ["VANILLA", "TRTLLM"]
def create_attention_model(
self, hidden_size: int, num_heads: int, head_dim: int, backend: str
) -> Attention:
"""Create a WAN self-attention model with specified backend."""
config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=backend)
model = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=config).to(
self.device
)
model.eval()
return model
def create_test_data(
self, batch_size: int, seq_len: int, hidden_size: int, head_dim: int
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Create test input data and RoPE embeddings."""
hidden_states = torch.randn(
batch_size, seq_len, hidden_size, device=self.device, dtype=self.dtype
)
freqs = generate_rope_embeddings(seq_len, head_dim, self.device, is_HSD=False)
return hidden_states, freqs
def estimate_memory_gb(
self, batch_size: int, num_heads: int, seq_len: int, head_dim: int
) -> float:
"""Estimate tensor memory usage in GB."""
hidden_size = num_heads * head_dim
# Input: [B, S, H] + Q, K, V: [B, S, num_heads, head_dim] each
bytes_per_element = 2 # bf16
input_bytes = batch_size * seq_len * hidden_size * bytes_per_element
qkv_bytes = 3 * batch_size * seq_len * num_heads * head_dim * bytes_per_element
output_bytes = batch_size * seq_len * hidden_size * bytes_per_element
# Attention matrix can be O(S^2) but flash attention avoids materializing it
return (input_bytes + qkv_bytes + output_bytes) / (1024**3)
def benchmark_single(
self,
batch_size: int,
num_heads: int,
seq_len: int,
head_dim: int,
backend: str,
verbose: bool = True,
) -> Optional[Dict]:
"""Benchmark a single configuration.
Returns:
Dict with timing statistics or None if test failed/skipped
"""
hidden_size = num_heads * head_dim
# Memory check
est_memory = self.estimate_memory_gb(batch_size, num_heads, seq_len, head_dim)
if est_memory > 8.0:
if verbose:
print(f" Skipping - estimated memory {est_memory:.2f}GB > 8GB limit")
return None
try:
# Create model and data
model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim)
# Warmup
with nvtx_range(f"warmup_{backend}"):
with torch_profiler_range(f"warmup_{backend}"):
with torch.no_grad():
for _ in range(self.warmup_iterations):
_ = model(hidden_states, freqs=freqs)
if self.device.type == "cuda":
torch.cuda.synchronize()
# Benchmark
times = []
with nvtx_range(f"benchmark_{backend}"):
with torch_profiler_range(f"benchmark_{backend}"):
with torch.no_grad():
for i in range(self.benchmark_iterations):
with nvtx_range(f"iter_{backend}_{i}"):
with cuda_timer(self.device) as get_time:
_ = model(hidden_states, freqs=freqs)
times.append(get_time())
# Statistics
times_tensor = torch.tensor(times)
stats = {
"avg_ms": times_tensor.mean().item(),
"min_ms": times_tensor.min().item(),
"max_ms": times_tensor.max().item(),
"std_ms": times_tensor.std().item(),
"median_ms": times_tensor.median().item(),
"p95_ms": torch.quantile(times_tensor, 0.95).item(),
"p99_ms": torch.quantile(times_tensor, 0.99).item(),
}
# Calculate throughput (approximate TOPS; counts the MACs of one S x S x D attention matmul)
total_ops = batch_size * num_heads * seq_len * seq_len * head_dim
stats["throughput_tops"] = (total_ops / 1e12) / (stats["avg_ms"] / 1000)
if verbose:
print(
f" {backend}: avg={stats['avg_ms']:.3f}ms, "
f"median={stats['median_ms']:.3f}ms, "
f"throughput={stats['throughput_tops']:.2f} TOPS"
)
return stats
except Exception as e:
if verbose:
print(f" {backend}: ERROR - {e}")
return None
def benchmark_comparison(
self,
batch_size: int,
num_heads: int,
seq_len: int,
head_dim: int,
description: str = "",
verbose: bool = True,
) -> Dict[str, Optional[Dict]]:
"""Benchmark and compare all backends for a given configuration."""
if verbose:
print(
f"\nBenchmarking: ({batch_size}, {num_heads}, {seq_len}, {head_dim}) {description}"
)
print(f" Device: {self.device}, dtype: {self.dtype}")
print(f" Warmup: {self.warmup_iterations}, Iterations: {self.benchmark_iterations}")
results = {}
for backend in self.backends:
results[backend] = self.benchmark_single(
batch_size, num_heads, seq_len, head_dim, backend, verbose
)
# Print comparison
if verbose and results.get("VANILLA") and results.get("TRTLLM"):
vanilla_avg = results["VANILLA"]["avg_ms"]
trtllm_avg = results["TRTLLM"]["avg_ms"]
speedup = vanilla_avg / trtllm_avg
print(f" TRTLLM vs VANILLA: {speedup:.2f}x {'faster' if speedup > 1 else 'slower'}")
return results
def run_full_benchmark(self, use_quick_sizes: bool = False) -> Dict:
"""Run benchmark on all configured sizes."""
test_sizes = self.QUICK_TEST_SIZES if use_quick_sizes else self.TEST_SIZES
print("\n" + "=" * 70)
print("WAN ATTENTION PERFORMANCE BENCHMARK")
print("=" * 70)
print(f"Device: {self.device}")
print(f"dtype: {self.dtype}")
print(f"Backends: {self.backends}")
print(f"NVTX: {'Enabled' if NVTX_AVAILABLE else 'Disabled'}")
print(f"Torch Profiler: {'Enabled' if TORCH_PROFILER_AVAILABLE else 'Disabled'}")
all_results = {}
with nvtx_range("wan_attention_benchmark"):
with torch_profiler_range("wan_attention_benchmark"):
for batch_size, num_heads, seq_len, head_dim, desc in test_sizes:
key = f"{desc}_{batch_size}x{num_heads}x{seq_len}x{head_dim}"
results = self.benchmark_comparison(
batch_size, num_heads, seq_len, head_dim, desc
)
all_results[key] = {
"config": {
"batch_size": batch_size,
"num_heads": num_heads,
"seq_len": seq_len,
"head_dim": head_dim,
"description": desc,
},
"results": results,
}
# Print summary
self._print_summary(all_results)
return all_results
def _print_summary(self, all_results: Dict) -> None:
"""Print benchmark summary table."""
print("\n" + "=" * 70)
print("BENCHMARK SUMMARY")
print("=" * 70)
print(f"{'Configuration':<40} {'VANILLA (ms)':<15} {'TRTLLM (ms)':<15} {'Speedup':<10}")
print("-" * 70)
for key, data in all_results.items():
desc = data["config"]["description"]
results = data["results"]
vanilla = results.get("VANILLA")
trtllm = results.get("TRTLLM")
vanilla_str = f"{vanilla['avg_ms']:.2f}" if vanilla else "N/A"
trtllm_str = f"{trtllm['avg_ms']:.2f}" if trtllm else "N/A"
if vanilla and trtllm:
speedup = vanilla["avg_ms"] / trtllm["avg_ms"]
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
print(f"{desc:<40} {vanilla_str:<15} {trtllm_str:<15} {speedup_str:<10}")
def test_memory_usage(
self,
batch_size: int = 1,
num_heads: int = 24,
seq_len: int = 4096,
head_dim: int = 64,
) -> Dict[str, Dict]:
"""Test memory usage of different backends."""
if self.device.type != "cuda":
print("Memory test requires CUDA device")
return {}
print("\n" + "=" * 70)
print("MEMORY USAGE TEST")
print("=" * 70)
print(f"Config: ({batch_size}, {num_heads}, {seq_len}, {head_dim})")
hidden_size = num_heads * head_dim
memory_results = {}
for backend in self.backends:
print(f"\nTesting {backend}...")
try:
# Clear cache
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create model and data
model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
hidden_states, freqs = self.create_test_data(
batch_size, seq_len, hidden_size, head_dim
)
# Warmup
with torch.no_grad():
_ = model(hidden_states, freqs=freqs)
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
# Forward pass
with nvtx_range(f"memory_test_{backend}"):
with torch.no_grad():
_ = model(hidden_states, freqs=freqs)
torch.cuda.synchronize()
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
current_memory_gb = torch.cuda.memory_allocated() / (1024**3)
memory_results[backend] = {
"peak_memory_gb": peak_memory_gb,
"current_memory_gb": current_memory_gb,
}
print(f" Peak memory: {peak_memory_gb:.3f} GB")
print(f" Current memory: {current_memory_gb:.3f} GB")
except Exception as e:
print(f" ERROR: {e}")
memory_results[backend] = None
return memory_results
# ============================================================================
# Pytest test functions
# ============================================================================
class TestWanAttentionPerformance:
"""Pytest test class for WAN attention performance."""
@pytest.fixture(autouse=True)
def setup(self):
"""Setup test environment."""
self.benchmark = WanAttentionPerformanceBenchmark(
warmup_iterations=5,
benchmark_iterations=20,
)
@pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"])
def test_self_attention_perf(self, backend: str):
"""Test that attention backend runs without errors."""
batch_size, num_heads, seq_len, head_dim = 1, 24, 1024, 64
result = self.benchmark.benchmark_single(
batch_size, num_heads, seq_len, head_dim, backend, verbose=True
)
if result is not None:
assert result["avg_ms"] > 0, "Average time should be positive"
assert result["min_ms"] <= result["avg_ms"], "Min should be <= avg"
assert result["max_ms"] >= result["avg_ms"], "Max should be >= avg"
print(f" {backend}: avg={result['avg_ms']:.3f}ms OK")
@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,head_dim",
[
(1, 24, 1024, 64),
(1, 24, 2048, 64),
(2, 24, 1024, 64),
],
)
def test_backend_comparison(self, batch_size: int, num_heads: int, seq_len: int, head_dim: int):
"""Test VANILLA vs TRTLLM comparison."""
results = self.benchmark.benchmark_comparison(
batch_size, num_heads, seq_len, head_dim, verbose=True
)
# At least one backend should work
assert any(r is not None for r in results.values()), "All backends failed"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_memory_usage(self):
"""Test memory usage tracking."""
memory_results = self.benchmark.test_memory_usage(
batch_size=1, num_heads=24, seq_len=2048, head_dim=64
)
for backend, result in memory_results.items():
if result is not None:
assert result["peak_memory_gb"] > 0, f"{backend} peak memory should be positive"
def test_quick_benchmark(self):
"""Run quick benchmark for CI validation."""
results = self.benchmark.run_full_benchmark(use_quick_sizes=True)
assert len(results) > 0, "Should have benchmark results"
# ============================================================================
# Main entry point
# ============================================================================
def main():
"""Run full benchmark suite."""
print("\n" + "=" * 70)
print("WAN ATTENTION PERFORMANCE BENCHMARK SUITE")
print("=" * 70)
if not torch.cuda.is_available():
print("WARNING: CUDA not available, results will not be meaningful")
# Print profiling instructions
if torch.cuda.is_available():
print("\nPROFILING INSTRUCTIONS:")
print("-" * 50)
if NVTX_AVAILABLE:
print("NVTX Profiling (Nsight Systems):")
print(" nsys profile -t cuda,nvtx --nvtx-capture=range \\")
print(" -o wan_attn_perf python test_attention_perf.py")
else:
print("NVTX not available. Install with: pip install nvtx")
print("\nPyTorch Profiler:")
print(" The benchmark includes record_function() calls for profiling")
print("-" * 50)
# Create benchmark instance
benchmark = WanAttentionPerformanceBenchmark(
warmup_iterations=10,
benchmark_iterations=50,
)
# Run full benchmark
print("\n" + "=" * 70)
print("FULL BENCHMARK")
print("=" * 70)
all_results = benchmark.run_full_benchmark(use_quick_sizes=False)
# Memory test
if torch.cuda.is_available():
benchmark.test_memory_usage(batch_size=1, num_heads=24, seq_len=4096, head_dim=64)
print("\n" + "=" * 70)
print("BENCHMARK COMPLETE")
print("=" * 70)
return all_results
if __name__ == "__main__":
main()
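The `cuda_timer` helper above wraps the usual CUDA-event timing pattern; a stripped-down standalone sketch of that pattern, with a placeholder matmul as the workload, looks like this.

```python
import torch

# Minimal sketch of CUDA-event timing, assuming a CUDA device is available.
# The matmul is only a placeholder workload.
if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    _ = x @ x                    # work being timed
    end.record()
    torch.cuda.synchronize()     # wait for the recorded events before reading them
    print(f"elapsed: {start.elapsed_time(end):.3f} ms")
```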

View File

@@ -0,0 +1,126 @@
"""Tests for fused QKV support in diffusion models.
Tests:
1. Model structure with fuse_qkv=True (default) vs fuse_qkv=False
2. Weight loading works for fused QKV layers
"""
import unittest
from types import SimpleNamespace
from typing import Dict
import torch
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
def _create_test_config(hidden_size: int = 64) -> DiffusionModelConfig:
"""Create a test DiffusionModelConfig."""
num_heads = hidden_size // 8 # e.g., 64 // 8 = 8 heads
head_dim = 8
return DiffusionModelConfig(
pretrained_config=SimpleNamespace(
hidden_size=hidden_size,
num_attention_heads=num_heads,
attention_head_dim=head_dim,
num_layers=2,
ffn_dim=256,
out_channels=16,
patch_size=[1, 2, 2],
in_channels=16,
text_dim=64,
freq_dim=32,
),
)
class TestFusedQKVWeightLoading(unittest.TestCase):
"""Test weight loading for fused QKV layers."""
def setUp(self):
"""Set up test fixtures."""
torch.manual_seed(42)
self.hidden_size = 64
def _create_mock_checkpoint_weights(self) -> Dict[str, torch.Tensor]:
"""Create mock checkpoint weights with separate to_q, to_k, to_v."""
dtype = torch.bfloat16 # Match model dtype
weights = {}
for block_idx in range(2):
for attn_name in ["attn1", "attn2"]:
prefix = f"blocks.{block_idx}.{attn_name}"
# Separate QKV weights (as in checkpoint)
weights[f"{prefix}.to_q.weight"] = torch.randn(
self.hidden_size, self.hidden_size, dtype=dtype
)
weights[f"{prefix}.to_q.bias"] = torch.randn(self.hidden_size, dtype=dtype)
weights[f"{prefix}.to_k.weight"] = torch.randn(
self.hidden_size, self.hidden_size, dtype=dtype
)
weights[f"{prefix}.to_k.bias"] = torch.randn(self.hidden_size, dtype=dtype)
weights[f"{prefix}.to_v.weight"] = torch.randn(
self.hidden_size, self.hidden_size, dtype=dtype
)
weights[f"{prefix}.to_v.bias"] = torch.randn(self.hidden_size, dtype=dtype)
# Output projection
weights[f"{prefix}.to_out.0.weight"] = torch.randn(
self.hidden_size, self.hidden_size, dtype=dtype
)
weights[f"{prefix}.to_out.0.bias"] = torch.randn(self.hidden_size, dtype=dtype)
# FFN weights
ffn_dim = 256
prefix = f"blocks.{block_idx}.ffn"
weights[f"{prefix}.net.0.proj.weight"] = torch.randn(
ffn_dim, self.hidden_size, dtype=dtype
)
weights[f"{prefix}.net.0.proj.bias"] = torch.randn(ffn_dim, dtype=dtype)
weights[f"{prefix}.net.2.weight"] = torch.randn(self.hidden_size, ffn_dim, dtype=dtype)
weights[f"{prefix}.net.2.bias"] = torch.randn(self.hidden_size, dtype=dtype)
# proj_out
weights["proj_out.weight"] = torch.randn(64, self.hidden_size, dtype=dtype)
weights["proj_out.bias"] = torch.randn(64, dtype=dtype)
return weights
def test_load_weights_fused(self):
"""Test loading weights with fused QKV (default for self-attention)."""
from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel
config = _create_test_config(self.hidden_size)
# Create model - self-attention (attn1) uses fused QKV by default
model = WanTransformer3DModel(model_config=config)
weights = self._create_mock_checkpoint_weights()
# Load weights (model handles fused QKV internally via DynamicLinearWeightLoader)
model.load_weights(weights)
# Verify fused weights were loaded correctly for self-attention
attn1 = model.blocks[0].attn1
qkv_weight = attn1.qkv_proj.weight.data
# Expected: concatenation of to_q, to_k, to_v weights
expected_weight = torch.cat(
[
weights["blocks.0.attn1.to_q.weight"],
weights["blocks.0.attn1.to_k.weight"],
weights["blocks.0.attn1.to_v.weight"],
],
dim=0,
)
self.assertEqual(qkv_weight.shape, expected_weight.shape)
self.assertTrue(torch.allclose(qkv_weight, expected_weight))
# Also verify cross-attention (attn2) uses separate Q/K/V
attn2 = model.blocks[0].attn2
self.assertTrue(hasattr(attn2, "to_q"), "Cross-attention should have separate to_q")
self.assertTrue(
torch.allclose(attn2.to_q.weight.data, weights["blocks.0.attn2.to_q.weight"])
)
if __name__ == "__main__":
unittest.main()
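The test above relies on the fused QKV layout being a concatenation of the checkpoint's `to_q`/`to_k`/`to_v` weights along the output dimension. The sketch below (arbitrary shapes and weights) illustrates why that layout is equivalent: one projection with the concatenated matrix matches the three separate projections with their outputs concatenated.

```python
import torch

hidden = 64
x = torch.randn(2, 16, hidden)

w_q, w_k, w_v = (torch.randn(hidden, hidden) for _ in range(3))
q, k, v = x @ w_q.T, x @ w_k.T, x @ w_v.T   # three separate projections

w_qkv = torch.cat([w_q, w_k, w_v], dim=0)   # fused layout: [3 * hidden, hidden]
qkv = x @ w_qkv.T                            # one fused projection

torch.testing.assert_close(qkv, torch.cat([q, k, v], dim=-1))
```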

View File

@@ -0,0 +1,494 @@
"""Test PipelineLoader with DiffusionArgs API."""
import os
from pathlib import Path
import pytest
import torch
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
def _llm_models_root() -> str:
"""Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
if "LLM_MODELS_ROOT" in os.environ:
root = Path(os.environ["LLM_MODELS_ROOT"])
if not root.exists():
root = Path("/scratch.trt_llm_data/llm-models/")
assert root.exists(), (
"You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
)
return str(root)
# Skip if checkpoint not available
# Set DIFFUSION_MODEL_PATH env var to run integration tests
CHECKPOINT_PATH = os.environ.get(
"DIFFUSION_MODEL_PATH",
os.path.join(_llm_models_root(), "Wan2.1-T2V-1.3B-Diffusers"),
)
# Skip heavy components (text_encoder ~44GB, vae ~300MB) to speed up tests
# These components are loaded via diffusers and don't need quantization testing
SKIP_HEAVY_COMPONENTS = [
PipelineComponent.TEXT_ENCODER,
PipelineComponent.VAE,
PipelineComponent.TOKENIZER,
PipelineComponent.SCHEDULER,
]
@pytest.fixture
def checkpoint_exists():
return CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH)
def test_meta_init_mode_creates_meta_tensors(checkpoint_exists):
"""Test that MetaInitMode creates tensors on meta device (no GPU memory)."""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
from tensorrt_llm._torch.visual_gen import DiffusionArgs
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
from tensorrt_llm._torch.visual_gen.models import AutoPipeline
# Load config directly
args = DiffusionArgs(checkpoint_path=CHECKPOINT_PATH)
config = DiffusionModelConfig.from_pretrained(
CHECKPOINT_PATH,
args=args,
)
# Create pipeline WITH MetaInitMode
with MetaInitMode():
pipeline = AutoPipeline.from_config(config, CHECKPOINT_PATH)
# Verify tensors are on meta device (no GPU memory allocated)
param = next(pipeline.transformer.parameters())
assert param.device.type == "meta", f"Expected meta device, got {param.device}"
def test_load_wan_pipeline_basic(checkpoint_exists):
"""Test basic loading without quantization using DiffusionArgs."""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
# Simple one-liner with DiffusionArgs
# Skip text_encoder/vae to speed up test (focus on transformer)
args = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline = PipelineLoader(args).load()
# Verify pipeline type
assert pipeline.__class__.__name__ == "WanPipeline"
assert pipeline.transformer is not None
# Verify text_encoder/vae were skipped
assert pipeline.text_encoder is None, "text_encoder should be skipped"
assert pipeline.vae is None, "vae should be skipped"
# Verify weights are loaded (not meta tensors)
param = next(pipeline.transformer.parameters())
assert param.device.type == "cuda"
assert param.dtype in [torch.float32, torch.bfloat16, torch.float16]
def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists):
"""Test loading with FP8 dynamic quantization using DiffusionArgs.
Verifies the dynamic quantization flow:
1. Config has dynamic_weight_quant=True when quant_config requests dynamic FP8 quantization
2. Model Linear layers have FP8 weight buffers
3. BF16 checkpoint weights are quantized on-the-fly
4. Quantized weights are in FP8 format
"""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
# Use DiffusionArgs with FP8 quantization
# Skip text_encoder/vae to speed up test (focus on transformer quantization)
args = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
quant_config={"quant_algo": "FP8", "dynamic": True},
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline = PipelineLoader(args).load()
# Verify model config has dynamic_weight_quant enabled
assert pipeline.model_config.dynamic_weight_quant is True, (
"dynamic_weight_quant should be True when linear.type specifies FP8"
)
# Verify FP8 weights in transformer Linear layers
found_fp8_linear = False
for name, module in pipeline.transformer.named_modules():
if isinstance(module, Linear):
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.dtype == torch.float8_e4m3fn, (
f"Linear {name} weight dtype is {module.weight.dtype}, expected float8_e4m3fn"
)
assert hasattr(module, "weight_scale") and module.weight_scale is not None, (
f"Linear {name} missing weight_scale buffer"
)
found_fp8_linear = True
break
assert found_fp8_linear, "No FP8 Linear modules found in transformer"
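# Illustrative sketch only (not the loader implementation): per-tensor FP8 dynamic
# quantization of a BF16 weight, roughly the on-the-fly step verified above.
# The helper name is hypothetical; 448.0 is the float8_e4m3fn max finite value.
def _fp8_per_tensor_quant_sketch(weight_bf16: torch.Tensor):
    amax = weight_bf16.abs().amax().float()
    scale = amax / 448.0
    weight_fp8 = (weight_bf16.float() / scale).to(torch.float8_e4m3fn)
    return weight_fp8, scale  # dequantize with weight_fp8.float() * scale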
def test_load_wan_pipeline_with_fp8_blockwise(checkpoint_exists):
"""Test loading with FP8 blockwise quantization using DiffusionArgs."""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
# Skip text_encoder/vae to speed up test (focus on transformer quantization)
args = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline = PipelineLoader(args).load()
# Verify FP8 weights
for name, module in pipeline.transformer.named_modules():
if isinstance(module, Linear):
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.dtype == torch.float8_e4m3fn, (
f"Linear {name} should have FP8 weight"
)
break
def test_diffusion_args_to_quant_config():
"""Test that DiffusionArgs correctly parses quant_config dict to QuantConfig."""
from tensorrt_llm._torch.visual_gen import DiffusionArgs
from tensorrt_llm.quantization.mode import QuantAlgo
# Default - no quantization
args = DiffusionArgs(checkpoint_path="/fake/path")
assert args.quant_config.quant_algo is None
# FP8 per-tensor (dict is coerced to QuantConfig by model_validator)
args = DiffusionArgs(
checkpoint_path="/fake/path",
quant_config={"quant_algo": "FP8", "dynamic": True},
)
qc = args.quant_config
assert qc is not None
assert qc.quant_algo == QuantAlgo.FP8
assert args.dynamic_weight_quant is True
# FP8 blockwise
args = DiffusionArgs(
checkpoint_path="/fake/path",
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
)
qc = args.quant_config
assert qc.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
# NVFP4
args = DiffusionArgs(
checkpoint_path="/fake/path",
quant_config={"quant_algo": "NVFP4", "dynamic": True},
)
qc = args.quant_config
assert qc.quant_algo == QuantAlgo.NVFP4
# With ignore patterns (exclude_modules)
args = DiffusionArgs(
checkpoint_path="/fake/path",
quant_config={
"quant_algo": "FP8",
"ignore": ["blocks.0.attn1.*", "proj_out"],
"config_groups": {
"group_0": {
"weights": {"dynamic": True, "num_bits": 8, "type": "float"},
"targets": ["Linear"],
}
},
},
)
qc = args.quant_config
assert qc is not None
assert qc.quant_algo == QuantAlgo.FP8
assert qc.exclude_modules == ["blocks.0.attn1.*", "proj_out"]
assert args.dynamic_weight_quant is True
def test_diffusion_args_to_mapping():
"""Test that DiffusionArgs correctly generates Mapping from ParallelConfig."""
from tensorrt_llm._torch.visual_gen import DiffusionArgs, ParallelConfig
# ParallelConfig validator requires WORLD_SIZE >= total parallel (tp*cp = 4)
old_world = os.environ.get("WORLD_SIZE")
try:
os.environ["WORLD_SIZE"] = "4"
args = DiffusionArgs(
checkpoint_path="/fake/path",
parallel=ParallelConfig(dit_tp_size=2, dit_cp_size=2),
)
mapping = args.to_mapping()
assert mapping.tp_size == 2
assert mapping.cp_size == 2
# world_size = tp_size * pp_size * cp_size (DP is handled separately)
assert mapping.world_size == 4
finally:
if old_world is not None:
os.environ["WORLD_SIZE"] = old_world
elif "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]
def test_load_without_quant_config_no_fp8(checkpoint_exists):
"""Test that loading without quant_config does NOT produce FP8 weights."""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
# No quantization specified
# Skip text_encoder/vae to speed up test (focus on transformer)
args = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline = PipelineLoader(args).load()
# Verify dynamic_weight_quant is False
assert pipeline.model_config.dynamic_weight_quant is False, (
"dynamic_weight_quant should be False when no quant_config"
)
# Verify NO FP8 weights
for name, module in pipeline.transformer.named_modules():
if isinstance(module, Linear):
if hasattr(module, "weight") and module.weight is not None:
assert module.weight.dtype != torch.float8_e4m3fn, (
f"Linear {name} should NOT be FP8 without quant_config"
)
break
def test_diffusion_args_from_dict():
"""Test DiffusionArgs can be created from a dictionary."""
from tensorrt_llm._torch.visual_gen import DiffusionArgs
from tensorrt_llm.quantization.mode import QuantAlgo
config_dict = {
"checkpoint_path": "/path/to/model",
"quant_config": {"quant_algo": "FP8", "dynamic": True},
"parallel": {"dit_tp_size": 2},
"pipeline": {"fuse_qkv": True},
}
# ParallelConfig validator requires WORLD_SIZE >= total parallel (dit_tp_size=2)
old_world = os.environ.get("WORLD_SIZE")
try:
os.environ["WORLD_SIZE"] = "2"
args = DiffusionArgs.from_dict(config_dict)
assert args.checkpoint_path == "/path/to/model"
assert args.quant_config.quant_algo == QuantAlgo.FP8
assert args.dynamic_weight_quant is True
assert args.parallel.dit_tp_size == 2
assert args.pipeline.fuse_qkv is True
finally:
if old_world is not None:
os.environ["WORLD_SIZE"] = old_world
elif "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]
# =============================================================================
# Memory and Performance Tests
# =============================================================================
def _get_module_memory_gb(module):
"""Get GPU memory usage of a module in GB."""
return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3
def _get_cuda_memory_gb():
"""Get current CUDA memory allocated in GB."""
return torch.cuda.memory_allocated() / 1024**3
def _get_cuda_peak_memory_gb():
"""Get peak CUDA memory allocated in GB."""
return torch.cuda.max_memory_allocated() / 1024**3
def test_fp8_vs_bf16_memory_comparison(checkpoint_exists):
"""Test FP8 dynamic quant uses ~2x less memory than BF16, including peak memory.
This test verifies that dynamic quantization doesn't create unnecessary
intermediate buffers that would negate the memory savings.
Expected for Wan 1.3B transformer:
- BF16: ~2.6 GB model memory, similar peak during loading
- FP8: ~1.3 GB model memory, peak should be < 2x BF16 peak
"""
if not checkpoint_exists:
pytest.skip("Checkpoint not available")
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
# =========================================================================
# Test 1: Load BF16 (no quantization)
# =========================================================================
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
args_bf16 = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline_bf16 = PipelineLoader(args_bf16).load()
bf16_model_mem = _get_module_memory_gb(pipeline_bf16.transformer)
bf16_total_mem = _get_cuda_memory_gb()
bf16_peak_mem = _get_cuda_peak_memory_gb()
print(f"\n[BF16] Transformer model memory: {bf16_model_mem:.2f} GB")
print(f"[BF16] Total CUDA memory: {bf16_total_mem:.2f} GB")
print(f"[BF16] Peak CUDA memory: {bf16_peak_mem:.2f} GB")
# Cleanup BF16
del pipeline_bf16
torch.cuda.empty_cache()
# =========================================================================
# Test 2: Load FP8 (dynamic quantization)
# =========================================================================
torch.cuda.reset_peak_memory_stats()
args_fp8 = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
quant_config={"quant_algo": "FP8", "dynamic": True},
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline_fp8 = PipelineLoader(args_fp8).load()
fp8_model_mem = _get_module_memory_gb(pipeline_fp8.transformer)
fp8_total_mem = _get_cuda_memory_gb()
fp8_peak_mem = _get_cuda_peak_memory_gb()
print(f"\n[FP8] Transformer model memory: {fp8_model_mem:.2f} GB")
print(f"[FP8] Total CUDA memory: {fp8_total_mem:.2f} GB")
print(f"[FP8] Peak CUDA memory: {fp8_peak_mem:.2f} GB")
# =========================================================================
# Verify memory savings
# =========================================================================
model_mem_ratio = bf16_model_mem / fp8_model_mem
peak_mem_ratio = bf16_peak_mem / fp8_peak_mem
print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x")
print(f"[Comparison] Peak memory ratio (BF16/FP8): {peak_mem_ratio:.2f}x")
# Model memory should be ~2x smaller for FP8
assert model_mem_ratio > 1.8, (
f"FP8 model memory should be ~2x smaller than BF16, got {model_mem_ratio:.2f}x"
)
# Peak memory during loading should also show savings
# Allow some overhead for dynamic quant, but should still be significantly better
assert peak_mem_ratio > 1.5, (
f"FP8 peak memory should be significantly smaller than BF16, got {peak_mem_ratio:.2f}x. "
f"This may indicate unnecessary intermediate buffers during dynamic quantization."
)
# FP8 peak should not be much higher than FP8 final (no large temp buffers)
fp8_peak_overhead = fp8_peak_mem / fp8_total_mem
print(f"[FP8 Per-Tensor] Peak/Final memory ratio: {fp8_peak_overhead:.2f}x")
# Peak should be close to final (< 1.5x overhead during loading)
assert fp8_peak_overhead < 2.0, (
f"FP8 peak memory ({fp8_peak_mem:.2f} GB) is too high compared to final "
f"({fp8_total_mem:.2f} GB). Ratio: {fp8_peak_overhead:.2f}x. "
f"This suggests unnecessary buffer allocation during dynamic quantization."
)
# Cleanup
del pipeline_fp8
torch.cuda.empty_cache()
# =========================================================================
# Test 3: Load FP8 Blockwise (dynamic quantization with block scales)
# =========================================================================
torch.cuda.reset_peak_memory_stats()
args_fp8_block = DiffusionArgs(
checkpoint_path=CHECKPOINT_PATH,
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
skip_components=SKIP_HEAVY_COMPONENTS,
)
pipeline_fp8_block = PipelineLoader(args_fp8_block).load()
fp8_block_model_mem = _get_module_memory_gb(pipeline_fp8_block.transformer)
fp8_block_total_mem = _get_cuda_memory_gb()
fp8_block_peak_mem = _get_cuda_peak_memory_gb()
print(f"\n[FP8 Blockwise] Transformer model memory: {fp8_block_model_mem:.2f} GB")
print(f"[FP8 Blockwise] Total CUDA memory: {fp8_block_total_mem:.2f} GB")
print(f"[FP8 Blockwise] Peak CUDA memory: {fp8_block_peak_mem:.2f} GB")
# =========================================================================
# Verify FP8 Blockwise memory savings
# =========================================================================
block_model_mem_ratio = bf16_model_mem / fp8_block_model_mem
block_peak_mem_ratio = bf16_peak_mem / fp8_block_peak_mem
print(f"\n[Comparison] Model memory ratio (BF16/FP8-Block): {block_model_mem_ratio:.2f}x")
print(f"[Comparison] Peak memory ratio (BF16/FP8-Block): {block_peak_mem_ratio:.2f}x")
# FP8 Blockwise has additional scale tensors, so slightly less than 2x savings
# But should still be significantly better than BF16
assert block_model_mem_ratio > 1.5, (
f"FP8 Blockwise model memory should be significantly smaller than BF16, got {block_model_mem_ratio:.2f}x"
)
# Peak memory check
assert block_peak_mem_ratio > 1.3, (
f"FP8 Blockwise peak memory should be smaller than BF16, got {block_peak_mem_ratio:.2f}x"
)
fp8_block_peak_overhead = fp8_block_peak_mem / fp8_block_total_mem
print(f"[FP8 Blockwise] Peak/Final memory ratio: {fp8_block_peak_overhead:.2f}x")
assert fp8_block_peak_overhead < 2.0, (
f"FP8 Blockwise peak memory ({fp8_block_peak_mem:.2f} GB) is too high compared to final "
f"({fp8_block_total_mem:.2f} GB). Ratio: {fp8_block_peak_overhead:.2f}x."
)
# Cleanup
del pipeline_fp8_block
torch.cuda.empty_cache()
# =========================================================================
# Summary
# =========================================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"BF16: {bf16_model_mem:.2f} GB model, {bf16_peak_mem:.2f} GB peak")
print(
f"FP8 Per-Tensor: {fp8_model_mem:.2f} GB model, {fp8_peak_mem:.2f} GB peak "
f"({model_mem_ratio:.2f}x savings)"
)
print(
f"FP8 Blockwise: {fp8_block_model_mem:.2f} GB model, {fp8_block_peak_mem:.2f} GB peak "
f"({block_model_mem_ratio:.2f}x savings)"
)
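# A sketch (not used above) of the back-of-the-envelope arithmetic behind the
# expected numbers in the docstring; the ~1.3e9 parameter count for the Wan 1.3B
# transformer is an assumption used only for illustration.
def _expected_weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Idealized weight memory in GB for a given element size (ignores scales/buffers)."""
    return num_params * bytes_per_param / 1024**3


# e.g. _expected_weight_memory_gb(1.3e9, 2) is ~2.4 GB for BF16 and
# _expected_weight_memory_gb(1.3e9, 1) is ~1.2 GB for FP8: the 2x ratio that the
# assertions above rely on, independent of the exact parameter count.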

View File

@ -0,0 +1,120 @@
"""Unit tests for diffusion quantization operations."""
import unittest
import torch
from tensorrt_llm._torch.visual_gen.quantization.ops import (
quantize_fp8_blockwise,
quantize_fp8_per_tensor,
)
def _dequant_fp8_per_tensor(qweight, scale):
"""Dequantize per-tensor FP8 weight."""
return qweight.to(torch.float32) * scale
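# Counterpart sketch for the blockwise path (not exercised by the tests below,
# which only check dtypes and scale shapes); it assumes the
# (num_blocks_out, num_blocks_in) scale layout over square blocks that the
# tests expect.
def _dequant_fp8_blockwise(qweight, scales, block_size=128):
    """Dequantize a blockwise-quantized FP8 weight back to float32."""
    out_dim, in_dim = qweight.shape
    dequant = qweight.to(torch.float32)
    for bo in range(scales.shape[0]):
        for bi in range(scales.shape[1]):
            rows = slice(bo * block_size, min((bo + 1) * block_size, out_dim))
            cols = slice(bi * block_size, min((bi + 1) * block_size, in_dim))
            dequant[rows, cols] = dequant[rows, cols] * scales[bo, bi]
    return dequant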
class TestQuantOps(unittest.TestCase):
"""Test quantization operations."""
def setUp(self):
"""Set random seed for reproducibility."""
torch.manual_seed(42)
if not torch.cuda.is_available():
self.skipTest("CUDA not available")
def test_fp8_per_tensor(self):
"""Test FP8 per-tensor quantization using CUDA kernel."""
weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
qweight, scale = quantize_fp8_per_tensor(weight)
# Check output types
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertEqual(qweight.shape, weight.shape)
self.assertEqual(scale.dtype, torch.float32)
self.assertEqual(scale.shape, (1, 1))
# Verify dequantization (approximate)
dequant = _dequant_fp8_per_tensor(qweight, scale)
error = (dequant - weight.to(torch.float32)).abs().mean()
self.assertLess(error, 0.15) # Reasonable quantization error
def test_fp8_per_tensor_different_shapes(self):
"""Test FP8 per-tensor quantization with various shapes."""
shapes = [(128, 256), (256, 512), (512, 1024), (1024, 2048)]
for shape in shapes:
with self.subTest(shape=shape):
weight = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
qweight, scale = quantize_fp8_per_tensor(weight)
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertEqual(qweight.shape, weight.shape)
self.assertEqual(scale.dtype, torch.float32)
def test_fp8_blockwise(self):
"""Test FP8 128x128 blockwise quantization."""
weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
block_size = 128
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
# Check output types
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertEqual(qweight.shape, weight.shape)
self.assertEqual(scales.dtype, torch.float32)
# Check scales shape: (num_blocks_out, num_blocks_in) for 128x128 blocks
num_blocks_out = (512 + block_size - 1) // block_size # 4
num_blocks_in = (512 + block_size - 1) // block_size # 4
self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
def test_fp8_blockwise_non_divisible(self):
"""Test FP8 blockwise quantization with non-divisible dimensions."""
weight = torch.randn(300, 500, dtype=torch.bfloat16, device="cuda")
block_size = 128
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
# Check output types
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertEqual(qweight.shape, weight.shape)
# Check scales shape (should handle non-divisible dimensions)
num_blocks_out = (300 + block_size - 1) // block_size # 3
num_blocks_in = (500 + block_size - 1) // block_size # 4
self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
def test_fp8_blockwise_different_block_sizes(self):
"""Test FP8 blockwise quantization with different block sizes."""
weight = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
for block_size in [64, 128, 256]:
with self.subTest(block_size=block_size):
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertEqual(qweight.shape, weight.shape)
num_blocks = (256 + block_size - 1) // block_size
self.assertEqual(scales.shape, (num_blocks, num_blocks))
def test_fp8_per_tensor_zero_weight(self):
"""Test FP8 per-tensor quantization with zero weight."""
weight = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
qweight, scale = quantize_fp8_per_tensor(weight)
# Should handle zero weights gracefully
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
def test_fp8_blockwise_zero_weight(self):
"""Test FP8 blockwise quantization with zero weight."""
weight = torch.zeros(256, 256, dtype=torch.bfloat16, device="cuda")
qweight, scales = quantize_fp8_blockwise(weight, block_size=128)
# Should handle zero weights gracefully
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,398 @@
"""End-to-end tests for trtllm-serve visual_gen with real models.
Tests text-to-video (t2v) and text+image-to-video (ti2v) generation through
the full ``trtllm-serve`` stack backed by real VisualGen models.
The server is launched as a subprocess (same pattern as
``tests/unittest/llmapi/apps/openai_server.py``), so each test class gets an
isolated ``trtllm-serve`` process.
Usage::
# Run all real-model tests (requires GPU + models in $HOME/llm-models-ci)
pytest tests/visual_gen/test_trtllm_serve_e2e.py -v
# Run only t2v tests
pytest tests/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanT2V
# Run only ti2v tests
pytest tests/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanI2V
"""
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional
import pytest
import requests
import yaml
from tensorrt_llm._utils import get_free_port
# ---------------------------------------------------------------------------
# Model paths
# ---------------------------------------------------------------------------
def _llm_models_root() -> str:
"""Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
if "LLM_MODELS_ROOT" in os.environ:
root = Path(os.environ["LLM_MODELS_ROOT"])
if not root.exists():
root = Path("/scratch.trt_llm_data/llm-models/")
assert root.exists(), (
"You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
)
return str(root)
_WAN_T2V_PATH = Path(_llm_models_root()) / "Wan2.1-T2V-1.3B-Diffusers"
_WAN_I2V_PATH = Path(_llm_models_root()) / "Wan2.2-I2V-A14B-Diffusers"
# Reference image used for image-to-video (ti2v) tests
_PROJECT_ROOT = Path(__file__).resolve().parents[4] # repo root
_REF_IMAGE_PATH = _PROJECT_ROOT / "examples" / "visual_gen" / "cat_piano.png"
# ---------------------------------------------------------------------------
# Remote server helper (follows RemoteOpenAIServer pattern)
# ---------------------------------------------------------------------------
class RemoteVisualGenServer:
"""Launch ``trtllm-serve`` for a visual-gen model as a subprocess.
Mirrors the interface of ``tests.unittest.llmapi.apps.openai_server.RemoteOpenAIServer``
adapted for diffusion / visual-gen models.
"""
MAX_SERVER_START_WAIT_S = 1200 # 20 min; large models need time to load
def __init__(
self,
model: str,
extra_visual_gen_options: Optional[dict] = None,
cli_args: Optional[List[str]] = None,
host: str = "localhost",
port: Optional[int] = None,
env: Optional[dict] = None,
) -> None:
self.host = host
self.port = port if port is not None else get_free_port()
self._config_file: Optional[str] = None
self.proc: Optional[subprocess.Popen] = None
args = ["--host", self.host, "--port", str(self.port)]
if cli_args:
args += cli_args
# Write the visual-gen YAML config to a temp file
if extra_visual_gen_options:
fd, self._config_file = tempfile.mkstemp(suffix=".yml", prefix="vg_cfg_")
with os.fdopen(fd, "w") as f:
yaml.dump(extra_visual_gen_options, f)
args += ["--extra_visual_gen_options", self._config_file]
launch_cmd = ["trtllm-serve", model] + args
if env is None:
env = os.environ.copy()
self.proc = subprocess.Popen(
launch_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S)
# -- lifecycle ---------------------------------------------------------
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.terminate()
def terminate(self):
if self.proc is None:
return
self.proc.terminate()
try:
self.proc.wait(timeout=30)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc.wait(timeout=30)
self.proc = None
if self._config_file:
try:
os.remove(self._config_file)
except OSError:
pass
self._config_file = None
# -- readiness ---------------------------------------------------------
def _wait_for_server(self, timeout: float):
url = self.url_for("health")
start = time.time()
while True:
try:
if requests.get(url).status_code == 200:
return
except Exception as err:
result = self.proc.poll()
if result is not None and result != 0:
raise RuntimeError("Visual-gen server exited unexpectedly.") from err
time.sleep(1)
if time.time() - start > timeout:
self.terminate()
raise RuntimeError(f"Visual-gen server failed to start within {timeout}s.")
# -- URL helpers -------------------------------------------------------
@property
def url_root(self) -> str:
return f"http://{self.host}:{self.port}"
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
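# Direct usage sketch (the pytest fixtures below wrap this); the model path is
# an assumption and the health endpoint matches what _wait_for_server polls:
#
#   with RemoteVisualGenServer(model="/path/to/Wan2.1-T2V-1.3B-Diffusers") as srv:
#       assert requests.get(srv.url_for("health")).status_code == 200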
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _model_available(path: Path) -> bool:
return path.is_dir()
def _av_available() -> bool:
"""Check if PyAV is installed (required for video encoding in E2E tests)."""
try:
import av # noqa: F401
return True
except ImportError:
return False
def _make_visual_gen_options(**extra) -> dict:
"""Build the YAML dict passed via ``--extra_visual_gen_options``."""
config = {
"linear": {"type": "default"},
"parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1},
}
config.update(extra)
return config
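# Sketch of how the same options could be dumped to the YAML file that
# ``trtllm-serve`` reads via ``--extra_visual_gen_options`` (this mirrors what
# RemoteVisualGenServer does internally); not called by the tests below.
def _write_visual_gen_options(path: str, **extra) -> str:
    with open(path, "w") as f:
        yaml.dump(_make_visual_gen_options(**extra), f)
    return path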
# =========================================================================
# WAN 2.1 Text-to-Video (t2v)
# =========================================================================
@pytest.mark.skipif(
not _model_available(_WAN_T2V_PATH), reason=f"Wan2.1-T2V model not found at {_WAN_T2V_PATH}"
)
@pytest.mark.skipif(
not _av_available(), reason="PyAV (av) not installed — required for video encoding in E2E tests"
)
class TestWanTextToVideo:
"""Test Wan2.1-T2V-1.3B-Diffusers text-to-video generation via serve API."""
@pytest.fixture(scope="class")
def server(self):
with RemoteVisualGenServer(
model=str(_WAN_T2V_PATH),
extra_visual_gen_options=_make_visual_gen_options(),
) as srv:
yield srv
# ------------------------------------------------------------------
def test_health(self, server):
resp = requests.get(server.url_for("health"))
assert resp.status_code == 200
def test_t2v_sync(self, server):
"""Synchronous text-to-video via POST /v1/videos/generations."""
resp = requests.post(
server.url_for("v1", "videos", "generations"),
json={
"prompt": "A cute cat playing piano",
"size": "480x320",
"seconds": 1.0,
"fps": 8,
"num_inference_steps": 4,
"seed": 42,
},
)
assert resp.status_code == 200, resp.text
assert resp.headers["content-type"] == "video/mp4"
assert len(resp.content) > 1000, "Video file too small"
def test_t2v_async_lifecycle(self, server):
"""Async video generation: create job → poll → download → delete."""
base = server.url_for("v1", "videos")
# 1. Create job
create_resp = requests.post(
base,
json={
"prompt": "A rocket launching into a starry sky",
"size": "480x320",
"seconds": 1.0,
"fps": 8,
"num_inference_steps": 4,
"seed": 42,
},
)
assert create_resp.status_code == 202, create_resp.text
job = create_resp.json()
video_id = job["id"]
assert video_id.startswith("video_")
assert job["status"] == "queued"
# 2. Poll until completed (or timeout)
deadline = time.time() + 600 # 10 min
status = "queued"
while status not in ("completed", "failed") and time.time() < deadline:
time.sleep(2)
meta_resp = requests.get(f"{base}/{video_id}")
assert meta_resp.status_code == 200
status = meta_resp.json()["status"]
assert status == "completed", f"Video generation did not complete: {status}"
# 3. Download video content
content_resp = requests.get(f"{base}/{video_id}/content")
assert content_resp.status_code == 200
assert "video/mp4" in content_resp.headers.get("content-type", "")
assert len(content_resp.content) > 1000
# 4. Verify it appears in list
list_resp = requests.get(base)
assert list_resp.status_code == 200
ids = [v["id"] for v in list_resp.json()["data"]]
assert video_id in ids
# 5. Delete
del_resp = requests.delete(f"{base}/{video_id}")
assert del_resp.status_code == 200
assert del_resp.json()["deleted"] is True
# 6. Confirm gone
gone_resp = requests.get(f"{base}/{video_id}")
assert gone_resp.status_code == 404
# =========================================================================
# WAN 2.2 Image-to-Video (ti2v)
# =========================================================================
@pytest.mark.skipif(
not _model_available(_WAN_I2V_PATH), reason=f"Wan2.2-I2V model not found at {_WAN_I2V_PATH}"
)
@pytest.mark.skipif(
not _REF_IMAGE_PATH.is_file(), reason=f"Reference image not found at {_REF_IMAGE_PATH}"
)
@pytest.mark.skipif(
not _av_available(), reason="PyAV (av) not installed — required for video encoding in E2E tests"
)
class TestWanImageToVideo:
"""Test Wan2.2-I2V-A14B-Diffusers image-to-video generation via serve API."""
@pytest.fixture(scope="class")
def server(self):
with RemoteVisualGenServer(
model=str(_WAN_I2V_PATH),
extra_visual_gen_options=_make_visual_gen_options(),
) as srv:
yield srv
# ------------------------------------------------------------------
def test_health(self, server):
resp = requests.get(server.url_for("health"))
assert resp.status_code == 200
def test_ti2v_sync(self, server):
"""Synchronous image-to-video via multipart POST /v1/videos/generations."""
with open(_REF_IMAGE_PATH, "rb") as f:
resp = requests.post(
server.url_for("v1", "videos", "generations"),
data={
"prompt": "The cat starts playing piano, keys moving",
"size": "480x320",
"seconds": "1.0",
"fps": "8",
"num_inference_steps": "4",
"seed": "42",
},
files={
"input_reference": ("cat_piano.png", f, "image/png"),
},
)
assert resp.status_code == 200, resp.text
assert resp.headers["content-type"] == "video/mp4"
assert len(resp.content) > 1000, "Video file too small"
def test_ti2v_async_lifecycle(self, server):
"""Async i2v: create job with image → poll → download → delete."""
base = server.url_for("v1", "videos")
# 1. Create job via multipart
with open(_REF_IMAGE_PATH, "rb") as f:
create_resp = requests.post(
base,
data={
"prompt": "Snow falls on the piano and the cat",
"size": "480x320",
"seconds": "1.0",
"fps": "8",
"num_inference_steps": "4",
"seed": "42",
},
files={
"input_reference": ("cat_piano.png", f, "image/png"),
},
)
assert create_resp.status_code == 202, create_resp.text
job = create_resp.json()
video_id = job["id"]
assert job["status"] == "queued"
# 2. Poll until completed
deadline = time.time() + 600
status = "queued"
while status not in ("completed", "failed") and time.time() < deadline:
time.sleep(2)
meta_resp = requests.get(f"{base}/{video_id}")
assert meta_resp.status_code == 200
status = meta_resp.json()["status"]
assert status == "completed", f"Video generation did not complete: {status}"
# 3. Download
content_resp = requests.get(f"{base}/{video_id}/content")
assert content_resp.status_code == 200
assert "video/mp4" in content_resp.headers.get("content-type", "")
assert len(content_resp.content) > 1000
# 4. Delete
del_resp = requests.delete(f"{base}/{video_id}")
assert del_resp.status_code == 200
assert del_resp.json()["deleted"] is True
# 5. Confirm gone
gone_resp = requests.get(f"{base}/{video_id}")
assert gone_resp.status_code == 404

View File

@ -0,0 +1,876 @@
"""trtllm-serve visual_gen endpoints tests.
Tests all endpoints registered for the VISUAL_GEN server role
in OpenAIServer.register_visual_gen_routes():
POST /v1/images/generations
POST /v1/images/edits
POST /v1/videos/generations (sync)
POST /v1/videos (async)
GET /v1/videos (list)
GET /v1/videos/{video_id} (metadata)
GET /v1/videos/{video_id}/content (download)
DELETE /v1/videos/{video_id} (delete)
"""
import asyncio
import base64
import os
from io import BytesIO
from typing import Optional
from unittest.mock import patch
import pytest
import torch
from fastapi.testclient import TestClient
from PIL import Image
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm.serve.media_storage import MediaStorage
from tensorrt_llm.serve.openai_protocol import VideoJob
from tensorrt_llm.serve.visual_gen_utils import VIDEO_STORE
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_dummy_image_tensor(height: int = 64, width: int = 64) -> torch.Tensor:
"""Create a small dummy uint8 image tensor (H, W, C)."""
return torch.randint(0, 256, (height, width, 3), dtype=torch.uint8)
def _make_dummy_video_tensor(
num_frames: int = 4, height: int = 64, width: int = 64
) -> torch.Tensor:
"""Create a small dummy uint8 video tensor (T, H, W, C)."""
return torch.randint(0, 256, (num_frames, height, width, 3), dtype=torch.uint8)
def _make_dummy_audio_tensor(length: int = 16000) -> torch.Tensor:
"""Create a small dummy float32 audio tensor."""
return torch.randn(1, length, dtype=torch.float32)
def _b64_white_png_1x1() -> str:
"""Return a base64-encoded 1x1 white PNG for image edit tests."""
buf = BytesIO()
Image.new("RGB", (1, 1), (255, 255, 255)).save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode("utf-8")
def _run_async(coro):
"""Run an async coroutine in a new event loop (for test helpers)."""
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
# ---------------------------------------------------------------------------
# Mock VisualGen
# ---------------------------------------------------------------------------
class MockVisualGen:
"""Lightweight stand-in for VisualGen that avoids GPU / model loading."""
def __init__(
self,
image_output: Optional[torch.Tensor] = None,
video_output: Optional[torch.Tensor] = None,
audio_output: Optional[torch.Tensor] = None,
should_fail: bool = False,
):
self._image = image_output
self._video = video_output
self._audio = audio_output
self._should_fail = should_fail
self._healthy = True
self.req_counter = 0
# --- VisualGen interface ---
def generate(self, inputs=None, params=None) -> MediaOutput:
if self._should_fail:
raise RuntimeError("Generation intentionally failed")
return MediaOutput(
image=self._image,
video=self._video,
audio=self._audio,
)
def generate_async(self, inputs=None, params=None) -> "MockDiffusionGenerationResult":
return MockDiffusionGenerationResult(
image=self._image,
video=self._video,
audio=self._audio,
should_fail=self._should_fail,
)
def _check_health(self) -> bool:
return self._healthy
async def get_stats_async(self, timeout: int):
return
def shutdown(self):
pass
class MockDiffusionGenerationResult:
"""Mock future-like result for generate_async."""
def __init__(
self,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
audio: Optional[torch.Tensor] = None,
should_fail: bool = False,
):
self._image = image
self._video = video
self._audio = audio
self._should_fail = should_fail
async def result(self, timeout=None):
if self._should_fail:
raise RuntimeError("Async generation intentionally failed")
return MediaOutput(
image=self._image,
video=self._video,
audio=self._audio,
)
# ---------------------------------------------------------------------------
# Server factory
# ---------------------------------------------------------------------------
def _create_server(generator: MockVisualGen, model_name: str = "test-model") -> TestClient:
"""Instantiate an OpenAIServer for VISUAL_GEN with a mocked generator.
We patch the ``VisualGen`` name inside the ``openai_server`` module so that
``isinstance(generator, VisualGen)`` returns True for our mock.
"""
from tensorrt_llm.llmapi.disagg_utils import ServerRole
from tensorrt_llm.serve.openai_server import OpenAIServer
with patch("tensorrt_llm.serve.openai_server.VisualGen", MockVisualGen):
server = OpenAIServer(
generator=generator,
model=model_name,
tool_parser=None,
server_role=ServerRole.VISUAL_GEN,
metadata_server_cfg=None,
)
return TestClient(server.app)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def image_client(tmp_path):
"""TestClient backed by a MockVisualGen that produces images."""
gen = MockVisualGen(image_output=_make_dummy_image_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
yield client
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
@pytest.fixture()
def video_client(tmp_path):
"""TestClient backed by a MockVisualGen that produces videos."""
gen = MockVisualGen(video_output=_make_dummy_video_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
yield client
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
@pytest.fixture()
def video_audio_client(tmp_path):
"""TestClient backed by a MockVisualGen that produces videos with audio."""
gen = MockVisualGen(
video_output=_make_dummy_video_tensor(),
audio_output=_make_dummy_audio_tensor(),
)
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
yield client
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
@pytest.fixture()
def failing_client(tmp_path):
"""TestClient backed by a MockVisualGen that always fails."""
gen = MockVisualGen(should_fail=True)
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
yield client
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
@pytest.fixture(autouse=True)
def _clear_video_store():
"""Reset the global VIDEO_STORE before each test."""
VIDEO_STORE._items.clear()
yield
VIDEO_STORE._items.clear()
@pytest.fixture(autouse=True)
def _mock_video_encoding():
"""Mock MP4 encoding to avoid PyAV dependency in unit tests.
Replaces MediaStorage._save_mp4 with a stub that writes a small
dummy file so FileResponse can serve it.
"""
def _dummy_save_mp4(video, audio, output_path, frame_rate):
os.makedirs(os.path.dirname(str(output_path)) or ".", exist_ok=True)
with open(str(output_path), "wb") as f:
f.write(b"\x00\x00\x00\x1cftypisom" + b"\x00" * 32)
return str(output_path)
with patch.object(MediaStorage, "_save_mp4", staticmethod(_dummy_save_mp4)):
yield
# =========================================================================
# POST /v1/images/generations
# =========================================================================
class TestImageGeneration:
def test_basic_image_generation_b64(self, image_client):
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "A cat sitting on a mat",
"response_format": "b64_json",
"size": "64x64",
},
)
assert resp.status_code == 200
data = resp.json()
assert "data" in data
assert len(data["data"]) >= 1
img_obj = data["data"][0]
assert img_obj["b64_json"] is not None
# Verify it decodes to valid bytes
decoded = base64.b64decode(img_obj["b64_json"])
assert len(decoded) > 0
assert img_obj["revised_prompt"] == "A cat sitting on a mat"
def test_image_generation_with_optional_params(self, image_client):
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "Sunset over ocean",
"response_format": "b64_json",
"size": "128x64",
"num_inference_steps": 20,
"guidance_scale": 7.5,
"seed": 123,
"negative_prompt": "blurry",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["size"] == "128x64"
def test_image_generation_url_format_not_supported(self, image_client):
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "A dog",
"response_format": "url",
},
)
assert resp.status_code == 501
def test_image_generation_auto_size(self, image_client):
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "A tree",
"response_format": "b64_json",
"size": "auto",
},
)
assert resp.status_code == 200
def test_image_generation_failure(self, failing_client):
resp = failing_client.post(
"/v1/images/generations",
json={
"prompt": "A bird",
"response_format": "b64_json",
},
)
assert resp.status_code == 400
def test_image_generation_invalid_size(self, image_client):
"""Invalid size triggers RequestValidationError → custom handler → 400."""
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "A mountain",
"response_format": "b64_json",
"size": "invalid",
},
)
assert resp.status_code == 400
def test_image_generation_null_output(self, tmp_path):
"""Generator returns MediaOutput with image=None."""
gen = MockVisualGen(image_output=None)
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
resp = client.post(
"/v1/images/generations",
json={
"prompt": "null image",
"response_format": "b64_json",
},
)
assert resp.status_code == 500
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
def test_image_generation_multiple_n(self, image_client):
"""Request n=2 images in one call."""
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "Flowers",
"response_format": "b64_json",
"size": "64x64",
"n": 2,
},
)
assert resp.status_code == 200
def test_image_generation_hd_quality(self, image_client):
resp = image_client.post(
"/v1/images/generations",
json={
"prompt": "HD landscape",
"response_format": "b64_json",
"quality": "hd",
},
)
assert resp.status_code == 200
def test_missing_prompt_image_generation(self, image_client):
"""Missing required field → RequestValidationError → custom handler → 400."""
resp = image_client.post(
"/v1/images/generations",
json={},
)
assert resp.status_code == 400
# =========================================================================
# POST /v1/images/edits
# =========================================================================
class TestImageEdit:
def test_basic_image_edit(self, image_client):
b64_img = _b64_white_png_1x1()
resp = image_client.post(
"/v1/images/edits",
json={
"image": b64_img,
"prompt": "Make it blue",
"num_inference_steps": 10,
},
)
assert resp.status_code == 200
data = resp.json()
assert "data" in data
assert len(data["data"]) >= 1
assert data["data"][0]["b64_json"] is not None
def test_image_edit_with_list_images(self, image_client):
b64_img = _b64_white_png_1x1()
resp = image_client.post(
"/v1/images/edits",
json={
"image": [b64_img, b64_img],
"prompt": "Merge them",
"num_inference_steps": 10,
},
)
assert resp.status_code == 200
def test_image_edit_with_mask(self, image_client):
b64_img = _b64_white_png_1x1()
b64_mask = _b64_white_png_1x1()
resp = image_client.post(
"/v1/images/edits",
json={
"image": b64_img,
"prompt": "Remove object",
"mask": b64_mask,
"num_inference_steps": 10,
},
)
assert resp.status_code == 200
def test_image_edit_with_optional_params(self, image_client):
b64_img = _b64_white_png_1x1()
resp = image_client.post(
"/v1/images/edits",
json={
"image": b64_img,
"prompt": "Enhance colors",
"size": "128x128",
"guidance_scale": 8.0,
"num_inference_steps": 15,
"seed": 42,
"negative_prompt": "dark",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["size"] == "128x128"
def test_image_edit_failure(self, failing_client):
b64_img = _b64_white_png_1x1()
resp = failing_client.post(
"/v1/images/edits",
json={
"image": b64_img,
"prompt": "Edit this",
"num_inference_steps": 10,
},
)
assert resp.status_code == 500
def test_missing_image_for_edit(self, image_client):
"""Missing required field → RequestValidationError → custom handler → 400."""
resp = image_client.post(
"/v1/images/edits",
json={
"prompt": "Edit without image",
},
)
assert resp.status_code == 400
# =========================================================================
# POST /v1/videos/generations (synchronous)
# =========================================================================
@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads
class TestVideoGenerationSync:
def test_basic_sync_video_generation(self, video_client):
resp = video_client.post(
"/v1/videos/generations",
json={
"prompt": "A rocket launching",
"size": "64x64",
"seconds": 1.0,
"fps": 8,
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 200
assert resp.headers["content-type"] == "video/mp4"
assert len(resp.content) > 0
def test_sync_video_generation_with_params(self, video_client):
resp = video_client.post(
"/v1/videos/generations",
json={
"prompt": "Ocean waves",
"size": "64x64",
"seconds": 2.0,
"fps": 8,
"num_inference_steps": 10,
"guidance_scale": 5.0,
"seed": 42,
"negative_prompt": "blurry",
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 200
assert len(resp.content) > 0
def test_sync_video_generation_multipart(self, video_client):
# Use files={} with a dummy file to ensure multipart/form-data
dummy_file = BytesIO(b"")
resp = video_client.post(
"/v1/videos/generations",
data={
"prompt": "Mountain sunrise",
"size": "64x64",
"seconds": "1.0",
"fps": "8",
},
files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
)
# The server will parse fields; _dummy is ignored since it's not "input_reference"
assert resp.status_code == 200
assert len(resp.content) > 0
def test_sync_video_generation_multipart_with_reference(self, video_client, tmp_path):
# Create a dummy reference image file
ref_path = tmp_path / "ref.png"
Image.new("RGB", (4, 4), (128, 128, 128)).save(str(ref_path))
with open(ref_path, "rb") as f:
resp = video_client.post(
"/v1/videos/generations",
data={
"prompt": "Animate this image",
"size": "64x64",
"seconds": "1.0",
"fps": "8",
},
files={"input_reference": ("ref.png", f, "image/png")},
)
assert resp.status_code == 200
assert len(resp.content) > 0
def test_sync_video_failure(self, failing_client):
resp = failing_client.post(
"/v1/videos/generations",
json={
"prompt": "Should fail",
"size": "64x64",
"seconds": 1.0,
"fps": 8,
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 400
def test_sync_video_null_output(self, tmp_path):
"""Generator returns MediaOutput with video=None."""
gen = MockVisualGen(video_output=None)
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
resp = client.post(
"/v1/videos/generations",
json={"prompt": "null video", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
assert resp.status_code == 500
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
def test_sync_video_unsupported_content_type(self, video_client):
resp = video_client.post(
"/v1/videos/generations",
content=b"some raw bytes",
headers={"content-type": "text/plain"},
)
assert resp.status_code == 400
def test_sync_video_missing_prompt_json(self, video_client):
"""Missing required prompt → Pydantic ValidationError → 400."""
resp = video_client.post(
"/v1/videos/generations",
json={"size": "64x64"},
headers={"content-type": "application/json"},
)
assert resp.status_code == 400
def test_sync_video_missing_prompt_multipart(self, video_client):
"""Missing prompt in multipart form → ValueError → 400."""
dummy_file = BytesIO(b"")
resp = video_client.post(
"/v1/videos/generations",
data={"size": "64x64"},
files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
)
assert resp.status_code == 400
# =========================================================================
# POST /v1/videos (asynchronous)
# =========================================================================
class TestVideoGenerationAsync:
def test_async_video_returns_202(self, video_client):
resp = video_client.post(
"/v1/videos",
json={
"prompt": "A dancing robot",
"size": "64x64",
"seconds": 1.0,
"fps": 8,
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 202
data = resp.json()
assert data["status"] == "queued"
assert data["object"] == "video"
assert data["prompt"] == "A dancing robot"
assert data["id"].startswith("video_")
def test_async_video_job_metadata_fields(self, video_client):
resp = video_client.post(
"/v1/videos",
json={
"prompt": "Starry night",
"size": "64x64",
"seconds": 2.0,
"fps": 12,
},
headers={"content-type": "application/json"},
)
data = resp.json()
assert "created_at" in data
assert data["duration"] == 2.0
assert data["fps"] == 12
assert data["size"] == "64x64"
def test_async_video_multipart(self, video_client):
"""Multipart encoding requires a file field to trigger the correct content-type."""
dummy_file = BytesIO(b"")
resp = video_client.post(
"/v1/videos",
data={
"prompt": "A sunset",
"size": "64x64",
"seconds": "1.0",
"fps": "8",
},
files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
)
assert resp.status_code == 202
def test_async_video_invalid_seconds(self, video_client):
"""Seconds must be between 1.0 and 16.0. Validation error → 400."""
resp = video_client.post(
"/v1/videos",
json={
"prompt": "Too short",
"seconds": 0.1,
"size": "64x64",
"fps": 8,
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 400
def test_async_video_invalid_fps(self, video_client):
"""Fps must be between 8 and 60. Validation error → 400."""
resp = video_client.post(
"/v1/videos",
json={
"prompt": "Bad fps",
"seconds": 1.0,
"fps": 2,
"size": "64x64",
},
headers={"content-type": "application/json"},
)
assert resp.status_code == 400
# =========================================================================
# GET /v1/videos (list)
# =========================================================================
class TestListVideos:
def test_list_videos_empty(self, video_client):
resp = video_client.get("/v1/videos")
assert resp.status_code == 200
data = resp.json()
assert data["object"] == "list"
assert data["data"] == []
def test_list_videos_after_creation(self, video_client):
# Create two video jobs
video_client.post(
"/v1/videos",
json={"prompt": "First video", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
video_client.post(
"/v1/videos",
json={"prompt": "Second video", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
resp = video_client.get("/v1/videos")
assert resp.status_code == 200
data = resp.json()
assert len(data["data"]) == 2
# =========================================================================
# GET /v1/videos/{video_id} (metadata)
# =========================================================================
class TestGetVideoMetadata:
def test_get_video_metadata_success(self, video_client):
create_resp = video_client.post(
"/v1/videos",
json={"prompt": "Space walk", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
video_id = create_resp.json()["id"]
resp = video_client.get(f"/v1/videos/{video_id}")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == video_id
assert data["object"] == "video"
assert data["prompt"] == "Space walk"
def test_get_video_metadata_not_found(self, video_client):
resp = video_client.get("/v1/videos/video_nonexistent")
assert resp.status_code == 404
# =========================================================================
# GET /v1/videos/{video_id}/content (download)
# =========================================================================
@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads
class TestGetVideoContent:
def _insert_video_job(self, video_id: str, status: str = "queued"):
import time as _time
job = VideoJob(
created_at=int(_time.time()),
id=video_id,
model="test-model",
prompt="test prompt",
status=status,
)
_run_async(VIDEO_STORE.upsert(video_id, job))
def test_get_video_content_success(self, tmp_path):
gen = MockVisualGen(video_output=_make_dummy_video_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
video_id = "video_testcontent"
self._insert_video_job(video_id, status="completed")
# Write a dummy mp4 file so FileResponse can serve it
video_path = tmp_path / f"{video_id}.mp4"
video_path.write_bytes(b"\x00\x00\x00\x1cftyp" + b"\x00" * 16)
resp = client.get(f"/v1/videos/{video_id}/content")
assert resp.status_code == 200
assert "video/mp4" in resp.headers.get("content-type", "")
assert len(resp.content) > 0
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
def test_get_video_content_not_found(self, video_client):
resp = video_client.get("/v1/videos/video_nonexistent/content")
assert resp.status_code == 404
def test_get_video_content_not_ready(self, tmp_path):
"""A queued video should return 400 when its content is requested."""
gen = MockVisualGen(video_output=_make_dummy_video_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
video_id = "video_notready"
self._insert_video_job(video_id, status="queued")
resp = client.get(f"/v1/videos/{video_id}/content")
assert resp.status_code == 400
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
def test_get_video_content_completed_but_file_missing(self, tmp_path):
"""Video marked completed but file deleted from disk → 404."""
gen = MockVisualGen(video_output=_make_dummy_video_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
video_id = "video_nofile"
self._insert_video_job(video_id, status="completed")
# Do NOT write a file
resp = client.get(f"/v1/videos/{video_id}/content")
assert resp.status_code == 404
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
# =========================================================================
# DELETE /v1/videos/{video_id}
# =========================================================================
class TestDeleteVideo:
def test_delete_video_success(self, tmp_path):
gen = MockVisualGen(video_output=_make_dummy_video_tensor())
os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
client = _create_server(gen)
create_resp = client.post(
"/v1/videos",
json={"prompt": "Delete me", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
video_id = create_resp.json()["id"]
# Write a dummy video file
(tmp_path / f"{video_id}.mp4").write_bytes(b"\x00" * 32)
resp = client.delete(f"/v1/videos/{video_id}")
assert resp.status_code == 200
data = resp.json()
assert data["deleted"] is True
# Verify it's gone from the store
resp = client.get(f"/v1/videos/{video_id}")
assert resp.status_code == 404
# Verify file is deleted
assert not (tmp_path / f"{video_id}.mp4").exists()
os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)
def test_delete_video_not_found(self, video_client):
resp = video_client.delete("/v1/videos/video_nonexistent")
assert resp.status_code == 404
def test_delete_video_without_file_on_disk(self, video_client):
"""Delete a video job that exists in the store but has no file on disk."""
create_resp = video_client.post(
"/v1/videos",
json={"prompt": "No file", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
video_id = create_resp.json()["id"]
resp = video_client.delete(f"/v1/videos/{video_id}")
assert resp.status_code == 200
data = resp.json()
assert data["deleted"] is True
def test_delete_video_then_list_empty(self, video_client):
"""After deleting the only video, the list should be empty."""
create_resp = video_client.post(
"/v1/videos",
json={"prompt": "Ephemeral", "size": "64x64", "seconds": 1.0, "fps": 8},
headers={"content-type": "application/json"},
)
video_id = create_resp.json()["id"]
video_client.delete(f"/v1/videos/{video_id}")
resp = video_client.get("/v1/videos")
assert resp.status_code == 200
assert resp.json()["data"] == []

File diff suppressed because it is too large

File diff suppressed because it is too large