mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-15 23:44:02 +08:00
[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
This commit is contained in:
parent
19a3031ecb
commit
26901e4aa0
@ -5,9 +5,6 @@ TensorRT LLM
|
||||
<h4>TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports
|
||||
state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.</h4>
|
||||
|
||||
🌟 TensorRT LLM is experimenting with Image & Video Generation models in the [TensorRT-LLM/feat/visual_gen](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch.
|
||||
This branch is a prototype and not stable for production use. PRs are not accepted.
|
||||
|
||||
[Documentation](https://nvidia.github.io/TensorRT-LLM/)
|
||||
[DeepWiki](https://deepwiki.com/NVIDIA/TensorRT-LLM)
|
||||
[Python 3.12.3](https://www.python.org/downloads/release/python-3123/)
|
||||
|
||||
172
examples/visual_gen/README.md
Normal file
@ -0,0 +1,172 @@
|
||||
# Visual Generation Examples
|
||||
|
||||
Quick reference for running visual generation models (WAN).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
```bash
|
||||
# Install dependencies (from repository root)
|
||||
pip install -r requirements-dev.txt
|
||||
pip install git+https://github.com/huggingface/diffusers.git
|
||||
pip install av
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Set MODEL_ROOT to your model directory (required for examples)
|
||||
export MODEL_ROOT=/llm-models
|
||||
# Optional: PROJECT_ROOT defaults to repo root when run from examples/visual_gen
|
||||
|
||||
# Run all examples (auto-detects GPUs)
|
||||
cd examples/visual_gen
|
||||
./visual_gen_examples.sh
|
||||
```
|
||||
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `PROJECT_ROOT` | Auto-detected | Path to repository root (set when running from `examples/visual_gen`) |
|
||||
| `MODEL_ROOT` | `/llm-models` | Path to model directory |
|
||||
| `TLLM_LOG_LEVEL` | `INFO` | Logging level |
|
||||
|
||||
---
|
||||
|
||||
## WAN (Text-to-Video)
|
||||
|
||||
### Basic Usage
|
||||
|
||||
**Single GPU:**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--output_path output.mp4
|
||||
```
|
||||
|
||||
**With TeaCache:**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--enable_teacache \
|
||||
--output_path output.mp4
|
||||
```
|
||||
|
||||
### Multi-GPU Parallelism
|
||||
|
||||
WAN supports two parallelism modes that can be combined:
|
||||
- **CFG Parallelism**: Split positive/negative prompts across GPUs
|
||||
- **Ulysses Parallelism**: Split sequence across GPUs for longer sequences
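As a rough illustration (not part of the example scripts), the GPU layouts quoted below follow directly from `cfg_size` and `ulysses_size`; WAN 2.1 1.3B has 12 attention heads, so Ulysses splits them evenly across each group:

```python
# Sketch only: reproduces the "GPU Layout" lines shown in the examples below.
WAN_NUM_HEADS = 12  # WAN 2.1 attention heads; ulysses_size must divide this

def gpu_layout(cfg_size: int, ulysses_size: int) -> str:
    heads_per_gpu = WAN_NUM_HEADS // ulysses_size
    groups = []
    for branch in range(cfg_size):
        first = branch * ulysses_size
        last = first + ulysses_size - 1
        # With CFG parallelism, the first group handles the positive prompt and
        # the second the negative prompt; without it, one group does both.
        label = ("positive" if branch == 0 else "negative") if cfg_size > 1 else "both prompts"
        gpus = f"GPU {first}-{last}" if ulysses_size > 1 else f"GPU {first}"
        groups.append(f"{gpus} ({label}, {heads_per_gpu} heads each)")
    return " | ".join(groups)

print(gpu_layout(cfg_size=2, ulysses_size=2))
# GPU 0-1 (positive, 6 heads each) | GPU 2-3 (negative, 6 heads each)
```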
|
||||
|
||||
|
||||
**Ulysses Only (2 GPUs):**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 1 --ulysses_size 2 \
|
||||
--output_path output.mp4
|
||||
```
|
||||
GPU Layout: GPU 0-1 share sequence (6 heads each)
|
||||
|
||||
**CFG Only (2 GPUs):**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 --ulysses_size 1 \
|
||||
--output_path output.mp4
|
||||
```
|
||||
GPU Layout: GPU 0 (positive) | GPU 1 (negative)
|
||||
|
||||
**CFG + Ulysses (4 GPUs):**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 --ulysses_size 2 \
|
||||
--output_path output.mp4
|
||||
```
|
||||
GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses)
|
||||
|
||||
**Large-Scale (8 GPUs):**
|
||||
```bash
|
||||
python visual_gen_wan_t2v.py \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 --width 832 --num_frames 33 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 --ulysses_size 4 \
|
||||
--output_path output.mp4
|
||||
```
|
||||
GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative)
|
||||
|
||||
---
|
||||
|
||||
## Common Arguments
|
||||
|
||||
| Argument | WAN | Default | Description |
|
||||
|----------|-----|---------|-------------|
|
||||
| `--height` | ✓ | 720 | Output height |
|
||||
| `--width` | ✓ | 1280 | Output width |
|
||||
| `--num_frames` | ✓ | 81 | Number of frames |
|
||||
| `--steps` | ✓ | 50 | Denoising steps |
|
||||
| `--guidance_scale` | ✓ | 5.0 | CFG guidance strength |
|
||||
| `--seed` | ✓ | 42 | Random seed |
|
||||
| `--enable_teacache` | ✓ | False | Cache optimization |
|
||||
| `--teacache_thresh` | ✓ | 0.2 | TeaCache similarity threshold |
|
||||
| `--attention_backend` | ✓ | VANILLA | VANILLA or TRTLLM |
|
||||
| `--cfg_size` | ✓ | 1 | CFG parallelism |
|
||||
| `--ulysses_size` | ✓ | 1 | Sequence parallelism |
|
||||
| `--linear_type` | ✓ | default | Quantization type |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Out of Memory:**
|
||||
- Use quantization: `--linear_type trtllm-fp8-blockwise`
|
||||
- Reduce resolution or frames
|
||||
- Enable TeaCache: `--enable_teacache`
|
||||
- Use Ulysses parallelism with more GPUs
|
||||
|
||||
**Slow Inference:**
|
||||
- Enable TeaCache: `--enable_teacache`
|
||||
- Use TRTLLM backend: `--attention_backend TRTLLM`
|
||||
- Use multi-GPU: `--cfg_size 2` or `--ulysses_size 2`
|
||||
|
||||
**Import Errors:**
|
||||
- Run from repository root
|
||||
- Install necessary dependencies, e.g., `pip install -r requirements-dev.txt`
|
||||
|
||||
**Ulysses Errors:**
|
||||
- `ulysses_size` must divide 12 (WAN heads)
|
||||
- Total GPUs = `cfg_size × ulysses_size`
|
||||
- Sequence length must be divisible by `ulysses_size`
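These constraints can be checked up front before launching a multi-GPU run; a minimal sketch (the helper itself is hypothetical, only the rules above come from this example):

```python
# Hypothetical pre-flight check for the constraints listed above.
WAN_NUM_HEADS = 12

def check_wan_parallel(cfg_size: int, ulysses_size: int, seq_len: int, num_gpus: int) -> None:
    if WAN_NUM_HEADS % ulysses_size != 0:
        raise ValueError(f"ulysses_size={ulysses_size} must divide {WAN_NUM_HEADS} (WAN heads)")
    if cfg_size * ulysses_size != num_gpus:
        raise ValueError(f"total GPUs must equal cfg_size x ulysses_size "
                         f"({cfg_size * ulysses_size}), got {num_gpus}")
    if seq_len % ulysses_size != 0:
        raise ValueError(f"sequence length {seq_len} must be divisible by ulysses_size={ulysses_size}")

# Example: the 4-GPU CFG + Ulysses configuration (seq_len is an arbitrary illustrative value).
check_wan_parallel(cfg_size=2, ulysses_size=2, seq_len=1024, num_gpus=4)
```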
|
||||
|
||||
## Output Formats
|
||||
|
||||
- **WAN**: `.mp4` (video), `.gif` (animated), `.png` (single frame)
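The format is selected from the output file extension by `OutputHandler` in `output_handler.py`; a minimal usage sketch (the dummy tensor is illustrative only, imports follow `output_handler.py`):

```python
import torch
from output_handler import OutputHandler  # examples/visual_gen/output_handler.py
from tensorrt_llm.llmapi.visual_gen import MediaOutput

# Illustrative dummy clip in the layout OutputHandler expects: (T, H, W, C) uint8.
frames = torch.zeros(8, 480, 832, 3, dtype=torch.uint8)

# The extension selects the format: .mp4 (video), .gif (animated), .png (middle frame).
OutputHandler.save(MediaOutput(video=frames), "output.gif", frame_rate=24.0)
```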
|
||||
|
||||
## Baseline Validation
|
||||
|
||||
Compare with official HuggingFace Diffusers implementation:
|
||||
|
||||
```bash
|
||||
# Run HuggingFace baselines
|
||||
./hf_examples.sh
|
||||
|
||||
# Or run individual models
|
||||
python hf_wan.py --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers
|
||||
```
|
||||
|
||||
Compare outputs generated with the same seed to verify correctness.
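No comparison script is bundled, so here is one hedged sketch of a frame-wise check between the two GIF outputs (paths are examples; expect small numerical differences rather than bit-exact matches):

```python
# Rough per-frame comparison of two generated GIFs (paths are illustrative).
import numpy as np
from PIL import Image, ImageSequence

def load_frames(path: str) -> np.ndarray:
    with Image.open(path) as im:
        return np.stack([np.asarray(f.convert("RGB")) for f in ImageSequence.Iterator(im)])

ref = load_frames("baseline_outputs/wan_baseline.gif")  # HuggingFace baseline (hf_examples.sh)
out = load_frames("output.gif")                         # TensorRT LLM visual_gen output

assert ref.shape == out.shape, f"shape mismatch: {ref.shape} vs {out.shape}"
mean_abs_diff = np.abs(ref.astype(np.float32) - out.astype(np.float32)).mean()
print(f"mean |diff| per pixel: {mean_abs_diff:.2f}")
```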
|
||||
BIN
examples/visual_gen/cat_piano.png
Normal file
Binary file not shown (445 KiB).
128
examples/visual_gen/hf_examples.sh
Executable file
@ -0,0 +1,128 @@
|
||||
#!/bin/bash
|
||||
# HuggingFace Baseline Tests - Official Diffusers Implementation
|
||||
#
|
||||
# Usage:
|
||||
# export PROJECT_ROOT=/path/to/tekit
|
||||
# export MODEL_ROOT=/path/to/models
|
||||
# ./hf_examples.sh
|
||||
#
|
||||
# Or inline:
|
||||
# PROJECT_ROOT=/workspace/gitlab/tekit-b200 MODEL_ROOT=/llm-models ./hf_examples.sh
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
# Environment variables with defaults
|
||||
PROJECT_ROOT=${PROJECT_ROOT:-"/workspace/gitlab/tekit-b200"}
|
||||
MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
|
||||
|
||||
# Log configuration
|
||||
export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
|
||||
|
||||
echo "============================================"
|
||||
echo "HuggingFace Diffusers Baseline Tests"
|
||||
echo "============================================"
|
||||
echo "PROJECT_ROOT: $PROJECT_ROOT"
|
||||
echo "MODEL_ROOT: $MODEL_ROOT"
|
||||
echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
|
||||
echo ""
|
||||
echo "Purpose: Establish baseline results using"
|
||||
echo " official diffusers implementations"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
|
||||
# Check Python dependencies
|
||||
echo "Checking dependencies..."
|
||||
MISSING_DEPS=""
|
||||
|
||||
if ! python -c "import diffusers" 2>/dev/null; then
|
||||
echo "❌ ERROR: diffusers not found"
|
||||
MISSING_DEPS="$MISSING_DEPS diffusers"
|
||||
fi
|
||||
|
||||
if ! python -c "import torch" 2>/dev/null; then
|
||||
echo "❌ ERROR: torch not found"
|
||||
MISSING_DEPS="$MISSING_DEPS torch"
|
||||
fi
|
||||
|
||||
if [ -n "$MISSING_DEPS" ]; then
|
||||
echo ""
|
||||
echo "❌ Missing required dependencies:$MISSING_DEPS"
|
||||
echo "Install with: pip install$MISSING_DEPS"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ All required dependencies found"
|
||||
echo ""
|
||||
|
||||
# Detect GPU
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
echo "Detected $GPU_COUNT GPU(s)"
|
||||
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
|
||||
echo "GPU: $GPU_NAME"
|
||||
else
|
||||
echo "⚠️ WARNING: nvidia-smi not found"
|
||||
echo " Continuing with CPU (very slow!)"
|
||||
GPU_COUNT=0
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Create output directory (in current directory)
|
||||
OUTPUT_DIR="./baseline_outputs"
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
echo "Output directory: $OUTPUT_DIR ($(pwd)/baseline_outputs)"
|
||||
echo ""
|
||||
|
||||
#############################################
|
||||
# WAN (Wan2.1) Baseline Test
|
||||
#############################################
|
||||
|
||||
echo "============================================"
|
||||
echo "1/1: WAN Baseline Test"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
|
||||
WAN_MODEL="${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/"
|
||||
WAN_OUTPUT="${OUTPUT_DIR}/wan_baseline.gif"
|
||||
|
||||
if [ -d "$WAN_MODEL" ]; then
|
||||
echo "Testing WAN with official diffusers..."
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/hf_wan.py \
|
||||
--model_path "$WAN_MODEL" \
|
||||
--output_path "$WAN_OUTPUT" \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--steps 50 \
|
||||
--guidance_scale 7.0 \
|
||||
--seed 42
|
||||
echo ""
|
||||
echo "✅ WAN baseline test completed"
|
||||
echo " Output: $WAN_OUTPUT"
|
||||
else
|
||||
echo "⚠️ SKIPPED: WAN model not found at $WAN_MODEL"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
#############################################
|
||||
# Summary
|
||||
#############################################
|
||||
|
||||
echo "============================================"
|
||||
echo "Baseline Tests Complete!"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
echo "Output files saved to: $OUTPUT_DIR"
|
||||
echo ""
|
||||
ls -lh "$OUTPUT_DIR" 2>/dev/null || echo "No outputs generated"
|
||||
echo ""
|
||||
echo "Next Steps:"
|
||||
echo " 1. Verify outputs are correct (images/videos generated)"
|
||||
echo " 2. Compare with custom implementation outputs"
|
||||
echo " 3. Use these as reference/baseline for debugging"
|
||||
echo ""
|
||||
echo "Comparison command:"
|
||||
echo " diff -r $OUTPUT_DIR <custom_implementation_outputs>"
|
||||
echo "============================================"
|
||||
141
examples/visual_gen/hf_wan.py
Executable file
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Baseline test for WAN using official diffusers library."""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from output_handler import OutputHandler, postprocess_hf_video_tensor
|
||||
|
||||
from tensorrt_llm._torch.visual_gen import MediaOutput
|
||||
|
||||
|
||||
def test_wan_baseline(
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
prompt: str = "A cute cat playing piano",
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 33,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.0,
|
||||
seed: int = 42,
|
||||
):
|
||||
"""Test WAN video generation with official diffusers."""
|
||||
from diffusers import WanPipeline
|
||||
|
||||
print("=" * 80)
|
||||
print("WAN Baseline Test (Official Diffusers)")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# Load pipeline
|
||||
print(f"Loading WAN pipeline from {model_path}...")
|
||||
pipe = WanPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
print("✅ Pipeline loaded")
|
||||
print()
|
||||
|
||||
# Check model states
|
||||
print("Model Training States:")
|
||||
print(f" text_encoder.training: {pipe.text_encoder.training}")
|
||||
print(f" transformer.training: {pipe.transformer.training}")
|
||||
print(f" vae.training: {pipe.vae.training}")
|
||||
print()
|
||||
|
||||
# Generate video
|
||||
print(f"Generating video: '{prompt}'")
|
||||
print(
|
||||
f"Parameters: {height}x{width}, {num_frames} frames, {num_inference_steps} steps, guidance={guidance_scale}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Set random seed
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
|
||||
result = pipe(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
video = result[0]
|
||||
|
||||
# Post-process video tensor: (B, T, C, H, W) -> (T, H, W, C) uint8
|
||||
video = postprocess_hf_video_tensor(video, remove_batch_dim=True)
|
||||
|
||||
print("=" * 80)
|
||||
print("Generation Complete!")
|
||||
print("=" * 80)
|
||||
print(f"Video shape: {video.shape}")
|
||||
print(f"Video dtype: {video.dtype}")
|
||||
print()
|
||||
|
||||
# Save output
|
||||
print(f"Saving output to {output_path}...")
|
||||
OutputHandler.save(output=MediaOutput(video=video), output_path=output_path, frame_rate=24.0)
|
||||
print(f"✅ Saved to {output_path}")
|
||||
print()
|
||||
|
||||
print("=" * 80)
|
||||
print("WAN BASELINE TEST PASSED ✅")
|
||||
print("=" * 80)
|
||||
return video
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="HuggingFace Baseline - WAN Text-to-Video Generation"
|
||||
)
|
||||
|
||||
# Model & Input
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="/llm-models/Wan2.1-T2V-1.3B-Diffusers/",
|
||||
help="Path to WAN model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="A cute cat playing piano", help="Text prompt for generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="wan_baseline.gif", help="Output file path"
|
||||
)
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--height", type=int, default=480, help="Video height")
|
||||
parser.add_argument("--width", type=int, default=832, help="Video width")
|
||||
parser.add_argument("--num_frames", type=int, default=33, help="Number of frames to generate")
|
||||
parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps")
|
||||
parser.add_argument(
|
||||
"--guidance_scale", type=float, default=7.0, help="Classifier-free guidance scale"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
test_wan_baseline(
|
||||
args.model_path,
|
||||
args.output_path,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
num_inference_steps=args.steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
seed=args.seed,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\n❌ ERROR: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
237
examples/visual_gen/output_handler.py
Normal file
@ -0,0 +1,237 @@
|
||||
"""Unified output handler for diffusion model outputs."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm.llmapi.visual_gen import MediaOutput
|
||||
|
||||
|
||||
def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
|
||||
"""Post-process video tensor from HuggingFace pipeline output to final format.
|
||||
|
||||
HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W) format,
|
||||
which is different from VAE decoder output format.
|
||||
|
||||
Args:
|
||||
video: Video tensor in (B, T, C, H, W) format from HuggingFace pipeline
|
||||
remove_batch_dim: Whether to remove batch dimension. Default True for typical
|
||||
single-batch video generation.
|
||||
|
||||
Returns:
|
||||
Post-processed video tensor:
|
||||
- If remove_batch_dim=True: (T, H, W, C) uint8 tensor
|
||||
- If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor
|
||||
|
||||
Note:
|
||||
Assumes video values are in [-1, 1] range (standard pipeline output).
|
||||
"""
|
||||
# Remove batch dimension first if requested
|
||||
if remove_batch_dim:
|
||||
video = video[0] # (B, T, C, H, W) -> (T, C, H, W)
|
||||
video = video.permute(0, 2, 3, 1) # (T, C, H, W) -> (T, H, W, C)
|
||||
else:
|
||||
video = video.permute(0, 1, 3, 4, 2) # (B, T, C, H, W) -> (B, T, H, W, C)
|
||||
|
||||
# Normalize to [0, 1] range
|
||||
video = (video / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# Convert to uint8
|
||||
video = (video * 255).round().to(torch.uint8)
|
||||
|
||||
return video
|
||||
|
||||
|
||||
def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Post-process image tensor from HuggingFace pipeline output to final format.
|
||||
|
||||
HuggingFace pipelines with output_type="pt" return images in (B, C, H, W) format.
|
||||
|
||||
Args:
|
||||
image: Image tensor in (B, C, H, W) or (C, H, W) format from HuggingFace pipeline
|
||||
|
||||
Returns:
|
||||
Post-processed image tensor in (H, W, C) uint8 format
|
||||
|
||||
Note:
|
||||
Assumes image values are in [-1, 1] range (standard pipeline output).
|
||||
"""
|
||||
# Remove batch dimension if present
|
||||
if image.ndim == 4:
|
||||
image = image[0] # (B, C, H, W) -> (C, H, W)
|
||||
|
||||
# Convert to (H, W, C) format
|
||||
image = image.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
|
||||
# Normalize to [0, 1] range
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# Convert to uint8
|
||||
image = (image * 255).round().to(torch.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class OutputHandler:
|
||||
"""Handle saving of generated outputs in various formats.
|
||||
|
||||
Supports MediaOutput from all models:
|
||||
- Video models (WAN): MediaOutput(video=torch.Tensor)
|
||||
- Image models: MediaOutput(image=torch.Tensor)
|
||||
- Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
|
||||
|
||||
Supported output formats:
|
||||
- .png: Save single image or middle frame
|
||||
- .gif: Save video as animated GIF (no audio)
|
||||
- .mp4: Save video with audio (requires diffusers export_utils)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0):
|
||||
"""Save output based on content type and file extension.
|
||||
|
||||
Args:
|
||||
output: MediaOutput containing model outputs (image/video/audio)
|
||||
output_path: Path to save the output file
|
||||
frame_rate: Frames per second for video output (default: 24.0)
|
||||
"""
|
||||
if not isinstance(output, MediaOutput):
|
||||
raise ValueError(f"Expected output to be MediaOutput, got {type(output)}")
|
||||
|
||||
file_ext = os.path.splitext(output_path)[1].lower()
|
||||
|
||||
# Determine content type
|
||||
if output.image is not None:
|
||||
OutputHandler._save_image(output.image, output_path, file_ext)
|
||||
elif output.video is not None:
|
||||
OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate)
|
||||
else:
|
||||
raise ValueError("Unknown output format. MediaOutput has no image or video data.")
|
||||
|
||||
@staticmethod
|
||||
def _save_image(image: torch.Tensor, output_path: str, file_ext: str):
|
||||
"""Save single image output.
|
||||
|
||||
Args:
|
||||
image: Image as torch tensor (H, W, C) uint8
|
||||
output_path: Path to save the image
|
||||
file_ext: File extension (.png, .jpg, etc.)
|
||||
"""
|
||||
if file_ext not in [".png", ".jpg", ".jpeg"]:
|
||||
logger.warning(f"Image output requested with {file_ext}, defaulting to .png")
|
||||
output_path = output_path.replace(file_ext, ".png")
|
||||
|
||||
# Convert torch.Tensor to PIL Image and save
|
||||
image_np = image.cpu().numpy()
|
||||
Image.fromarray(image_np).save(output_path)
|
||||
logger.info(f"Saved image to {output_path}")
|
||||
|
||||
@staticmethod
|
||||
def _save_video(
|
||||
video: torch.Tensor,
|
||||
audio: Optional[torch.Tensor],
|
||||
output_path: str,
|
||||
file_ext: str,
|
||||
frame_rate: float,
|
||||
):
|
||||
"""Save video output with optional audio.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch tensor (T, H, W, C) with dtype uint8
|
||||
audio: Optional audio as torch tensor
|
||||
output_path: Path to save the video
|
||||
file_ext: File extension (.mp4, .gif, .png)
|
||||
frame_rate: Frames per second
|
||||
"""
|
||||
if file_ext == ".mp4":
|
||||
OutputHandler._save_mp4(video, audio, output_path, frame_rate)
|
||||
elif file_ext == ".gif":
|
||||
OutputHandler._save_gif(video, output_path, frame_rate)
|
||||
elif file_ext == ".png":
|
||||
OutputHandler._save_middle_frame(video, output_path)
|
||||
else:
|
||||
logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png")
|
||||
output_path = output_path.replace(file_ext, ".png")
|
||||
OutputHandler._save_middle_frame(video, output_path)
|
||||
|
||||
@staticmethod
|
||||
def _save_mp4(
|
||||
video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
|
||||
):
|
||||
"""Save video with optional audio as MP4.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch tensor (T, H, W, C) uint8
|
||||
audio: Optional audio as torch tensor (float32)
|
||||
output_path: Output path for MP4
|
||||
frame_rate: Frames per second
|
||||
"""
|
||||
try:
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
|
||||
# Prepare audio if present
|
||||
audio_prepared = audio.float() if audio is not None else None
|
||||
|
||||
# encode_video expects (T, H, W, C) uint8 video and float32 audio
|
||||
encode_video(
|
||||
video,
|
||||
fps=frame_rate,
|
||||
audio=audio_prepared,
|
||||
audio_sample_rate=24000 if audio_prepared is not None else None,
|
||||
output_path=output_path,
|
||||
)
|
||||
logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"diffusers export_utils (encode_video) not available. "
|
||||
"Falling back to saving middle frame as PNG."
|
||||
)
|
||||
png_path = output_path.replace(".mp4", ".png")
|
||||
OutputHandler._save_middle_frame(video, png_path)
|
||||
|
||||
@staticmethod
|
||||
def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float):
|
||||
"""Save video as animated GIF.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch tensor (T, H, W, C) uint8
|
||||
output_path: Output path for GIF
|
||||
frame_rate: Frames per second
|
||||
"""
|
||||
# Convert torch.Tensor to numpy for PIL
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
# Convert to list of PIL Images
|
||||
frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
|
||||
|
||||
# Save as animated GIF
|
||||
duration_ms = int(1000 / frame_rate)
|
||||
frames[0].save(
|
||||
output_path,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=False,
|
||||
duration=duration_ms,
|
||||
loop=0,
|
||||
)
|
||||
logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
|
||||
|
||||
@staticmethod
|
||||
def _save_middle_frame(video: torch.Tensor, output_path: str):
|
||||
"""Save middle frame of video as PNG.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch tensor (T, H, W, C) uint8
|
||||
output_path: Output path for PNG
|
||||
"""
|
||||
# Convert torch.Tensor to numpy for PIL
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
# Extract middle frame
|
||||
frame_idx = video_np.shape[0] // 2
|
||||
Image.fromarray(video_np[frame_idx]).save(output_path)
|
||||
logger.info(f"Saved frame {frame_idx} to {output_path}")
|
||||
322
examples/visual_gen/serve/README.md
Normal file
@ -0,0 +1,322 @@
|
||||
# Visual Generation API Examples
|
||||
|
||||
This directory contains example scripts that demonstrate how to use the TensorRT-LLM Visual Generation API endpoints for image and video generation.
|
||||
|
||||
## Overview
|
||||
|
||||
These examples show how to interact with the visual generation server using both the OpenAI Python SDK and standard HTTP requests. The API provides endpoints for:
|
||||
|
||||
- **Image Generation**: Text-to-image generation (T2I)
|
||||
- **Video Generation**:
|
||||
- Text-to-video generation (T2V) - generate videos from text prompts only
|
||||
- Text+Image-to-video generation (TI2V) - generate videos from text + reference image
|
||||
- Both synchronous and asynchronous modes supported
|
||||
- Multipart/form-data support for file uploads
|
||||
- **Video Management**: Retrieving and deleting generated videos
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running these examples, ensure you have:
|
||||
|
||||
1. **Install dependencies**: Install the required packages before running the examples:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/diffusers.git
|
||||
pip install av
|
||||
```
|
||||
|
||||
2. **Server Running**: The TensorRT-LLM visual generation server must be running
|
||||
```bash
|
||||
trtllm-serve <path to your model> --extra_visual_gen_options <path to config yaml>
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```bash
|
||||
trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml
|
||||
|
||||
# Run the server in the background:
|
||||
trtllm-serve $LLM_MODEL_DIR/Wan2.1-T2V-1.3B-Diffusers --extra_visual_gen_options ./configs/wan.yml > /tmp/serve.log 2>&1 &
|
||||
|
||||
# Check whether the server is up:
|
||||
tail -f /tmp/serve.log
|
||||
|
||||
```
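Once the log shows the server is ready, a quick way to confirm it is responding is to hit the video-listing endpoint from the API summary below (a hedged sketch; the exact response shape may differ):

```python
import requests

# Liveness check against the local server; GET /v1/videos lists video jobs.
resp = requests.get("http://localhost:8000/v1/videos")
print(resp.status_code)  # expect 200 once the server is up
print(resp.json())       # assumed to be a JSON listing of video jobs
```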
|
||||
|
||||
## Examples
|
||||
|
||||
Currently supported and tested models:
|
||||
|
||||
1. WAN T2V/TI2V for video generation (`t2v`, `ti2v`, `delete_video` examples)
|
||||
|
||||
### 1. Synchronous Image Generation (`sync_image_gen.py`)
|
||||
|
||||
Demonstrates synchronous text-to-image generation using the OpenAI SDK.
|
||||
|
||||
**Features:**
|
||||
- Generates images from text prompts
|
||||
- Supports configurable image size and quality
|
||||
- Returns base64-encoded images or URLs
|
||||
- Saves generated images to disk
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
# Use default localhost server
|
||||
python sync_image_gen.py
|
||||
|
||||
# Specify custom server URL
|
||||
python sync_image_gen.py http://your-server:8000/v1
|
||||
```
|
||||
|
||||
**API Endpoint:** `POST /v1/images/generations`
|
||||
|
||||
**Output:** Saves generated image to `output_generation.png` (or numbered files for multiple images)
|
||||
|
||||
---
|
||||
|
||||
### 2. Synchronous Video Generation with T2V and TI2V Modes (`sync_video_gen.py`)
|
||||
|
||||
Demonstrates synchronous video generation using direct HTTP requests. Waits for completion and returns the video file directly.
|
||||
|
||||
**Features:**
|
||||
- **T2V Mode**: Generate videos from text prompts only
|
||||
- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data)
|
||||
- Waits for video generation to complete before returning
|
||||
- Returns video file directly in response
|
||||
- Command-line interface for easy testing
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Text-to-Video (T2V) - No reference image
|
||||
python sync_video_gen.py --mode t2v \
|
||||
--prompt "A cute cat playing with a ball in the park" \
|
||||
--duration 4.0 --fps 24 --size 256x256
|
||||
|
||||
# Text+Image-to-Video (TI2V) - With reference image
|
||||
# Note: longer durations and larger sizes lead to much longer generation times
|
||||
python sync_video_gen.py --mode ti2v \
|
||||
--prompt "She turns around and smiles, then slowly walks out of the frame" \
|
||||
--image ./media/woman_skyline_original_720p.jpeg \
|
||||
--duration 4.0 --fps 24 --size 512x512
|
||||
|
||||
# Custom parameters
|
||||
python sync_video_gen.py --mode t2v \
|
||||
--prompt "A serene sunset over the ocean" \
|
||||
--duration 5.0 --fps 30 --size 512x512 \
|
||||
--output my_video.mp4
|
||||
```
|
||||
|
||||
**Command-Line Arguments:**
|
||||
- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
|
||||
- `--prompt` - Text prompt for video generation (required)
|
||||
- `--image` - Path to reference image (required for ti2v mode)
|
||||
- `--base-url` - API server URL (default: http://localhost:8000/v1)
|
||||
- `--model` - Model name (default: wan)
|
||||
- `--duration` - Video duration in seconds (default: 4.0)
|
||||
- `--fps` - Frames per second (default: 24)
|
||||
- `--size` - Video resolution in WxH format (default: 256x256)
|
||||
- `--output` - Output video file path (default: output_sync.mp4)
|
||||
|
||||
**API Endpoint:** `POST /v1/videos/generations`
|
||||
|
||||
**API Details:**
|
||||
- T2V uses JSON `Content-Type: application/json`
|
||||
- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
|
||||
|
||||
**Output:** Saves generated video to specified output file
|
||||
|
||||
---
|
||||
|
||||
### 3. Async Video Generation with T2V and TI2V Modes (`async_video_gen.py`)
|
||||
|
||||
**NEW**: Enhanced async video generation supporting both Text-to-Video (T2V) and Text+Image-to-Video (TI2V) modes.
|
||||
|
||||
**Features:**
|
||||
- **T2V Mode**: Generate videos from text prompts only (JSON request)
|
||||
- **TI2V Mode**: Generate videos from text + reference image (multipart/form-data with file upload)
|
||||
- Command-line interface for easy testing
|
||||
- Automatic mode detection
|
||||
- Comprehensive parameter control
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Text-to-Video (T2V) - No reference image
|
||||
python async_video_gen.py --mode t2v \
|
||||
--prompt "A cool cat on a motorcycle in the night" \
|
||||
--duration 4.0 --fps 24 --size 256x256
|
||||
|
||||
# Text+Image-to-Video (TI2V) - With reference image
|
||||
python async_video_gen.py --mode ti2v \
|
||||
--prompt "She turns around and smiles, then slowly walks out of the frame" \
|
||||
--image ./media/woman_skyline_original_720p.jpeg \
|
||||
--duration 4.0 --fps 24 --size 512x512
|
||||
|
||||
# Custom parameters
|
||||
python async_video_gen.py --mode t2v \
|
||||
--prompt "A serene sunset over the ocean" \
|
||||
--duration 5.0 --fps 30 --size 512x512 \
|
||||
--output my_video.mp4
|
||||
```
|
||||
|
||||
**Command-Line Arguments:**
|
||||
- `--mode` - Generation mode: `t2v` or `ti2v` (default: t2v)
|
||||
- `--prompt` - Text prompt for video generation (required)
|
||||
- `--image` - Path to reference image (required for ti2v mode)
|
||||
- `--base-url` - API server URL (default: http://localhost:8000/v1)
|
||||
- `--model` - Model name (default: wan)
|
||||
- `--duration` - Video duration in seconds (default: 4.0)
|
||||
- `--fps` - Frames per second (default: 24)
|
||||
- `--size` - Video resolution in WxH format (default: 256x256)
|
||||
- `--output` - Output video file path (default: output_async.mp4)
|
||||
|
||||
**API Details:**
|
||||
- T2V uses JSON `Content-Type: application/json`
|
||||
- TI2V uses multipart/form-data `Content-Type: multipart/form-data` with file upload
|
||||
|
||||
**Output:** Saves generated video to specified output file
|
||||
|
||||
---
|
||||
|
||||
### 4. Video Deletion (`delete_video.py`)
|
||||
|
||||
Demonstrates the complete lifecycle of video generation and deletion.
|
||||
|
||||
**Features:**
|
||||
- Creates a test video generation job
|
||||
- Waits for completion
|
||||
- Deletes the generated video
|
||||
- Verifies deletion by attempting to retrieve the deleted video
|
||||
- Tests error handling for non-existent videos
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
# Use default localhost server
|
||||
python delete_video.py
|
||||
|
||||
# Specify custom server URL
|
||||
python delete_video.py http://your-server:8000/v1
|
||||
```
|
||||
|
||||
**API Endpoints:**
|
||||
- `POST /v1/videos` - Create video job
|
||||
- `GET /v1/videos/{video_id}` - Check video status
|
||||
- `DELETE /v1/videos/{video_id}` - Delete video
|
||||
|
||||
**Test Flow:**
|
||||
1. Create video generation job
|
||||
2. Wait for completion
|
||||
3. Delete the video
|
||||
4. Verify video returns `NotFoundError`
|
||||
5. Test deletion of non-existent video
|
||||
|
||||
---
|
||||
|
||||
## API Configuration
|
||||
|
||||
All examples use the following default configuration:
|
||||
|
||||
- **Base URL**: `http://localhost:8000/v1`
|
||||
- **API Key**: `"tensorrt_llm"` (authentication token)
|
||||
- **Timeout**: 300 seconds for async operations
|
||||
|
||||
You can customize these by:
|
||||
1. Passing the base URL as a command-line argument
|
||||
2. Modifying the default parameters in each script's function
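For reference, the client setup shared by all the SDK-based scripts looks like this (only the base URL changes between environments):

```python
import openai

# Point the OpenAI SDK at the local trtllm-serve instance.
client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # or the URL passed on the command line
    api_key="tensorrt_llm",               # fixed token used by these examples
)
```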
|
||||
|
||||
## Common Parameters
|
||||
|
||||
### Image Generation
|
||||
- `model`: Model identifier (e.g., "wan")
|
||||
- `prompt`: Text description
|
||||
- `n`: Number of images to generate
|
||||
- `size`: Image dimensions (e.g., "512x512", "1024x1024")
|
||||
- `quality`: "standard" or "hd"
|
||||
- `response_format`: "b64_json" or "url"
|
||||
|
||||
### Video Generation
|
||||
- `model`: Model identifier (e.g., "wan")
|
||||
- `prompt`: Text description
|
||||
- `size`: Video resolution (e.g., "256x256", "512x512")
|
||||
- `seconds`: Duration in seconds
|
||||
- `fps`: Frames per second
|
||||
- `input_reference`: Reference image file (for TI2V mode)
|
||||
|
||||
## Quick Reference - curl Examples
|
||||
|
||||
### Text-to-Video (JSON)
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/videos" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cool cat on a motorcycle",
|
||||
"seconds": 4.0,
|
||||
"fps": 24,
|
||||
"size": "256x256"
|
||||
}'
|
||||
```
|
||||
|
||||
### Text+Image-to-Video (Multipart with File Upload)
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/videos" \
|
||||
-F "prompt=She turns around and smiles" \
|
||||
-F "input_reference=@./media/woman_skyline_original_720p.jpeg" \
|
||||
-F "seconds=4.0" \
|
||||
-F "fps=24" \
|
||||
-F "size=256x256" \
|
||||
-F "guidance_scale=5.0"
|
||||
```
|
||||
|
||||
### Check Video Status
|
||||
```bash
|
||||
curl -X GET "http://localhost:8000/v1/videos/{video_id}"
|
||||
```
|
||||
|
||||
### Download Video
|
||||
```bash
|
||||
curl -X GET "http://localhost:8000/v1/videos/{video_id}/content" -o output.mp4
|
||||
```
|
||||
|
||||
### Delete Video
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:8000/v1/videos/{video_id}"
|
||||
```
|
||||
|
||||
## API Endpoints Summary
|
||||
|
||||
| Endpoint | Method | Mode | Content-Type | Purpose |
|
||||
|----------|--------|------|--------------|---------|
|
||||
| `/v1/videos` | POST | Async | JSON or Multipart | Create video job (T2V/TI2V) |
|
||||
| `/v1/videos/generations` | POST | Sync | JSON or Multipart | Generate video sync (T2V/TI2V) |
|
||||
| `/v1/videos/{id}` | GET | - | - | Get video status/metadata |
|
||||
| `/v1/videos/{id}/content` | GET | - | - | Download video file |
|
||||
| `/v1/videos/{id}` | DELETE | - | - | Delete video |
|
||||
| `/v1/videos` | GET | - | - | List all videos |
|
||||
| `/v1/images/generations` | POST | - | JSON | Generate images (T2I) |
|
||||
|
||||
**Note:** Both `/v1/videos` (async) and `/v1/videos/generations` (sync) support:
|
||||
- **JSON**: Standard text-to-video (T2V)
|
||||
- **Multipart/Form-Data**: Text+image-to-video (TI2V) with file upload
|
||||
|
||||
## Error Handling
|
||||
|
||||
All examples include comprehensive error handling:
|
||||
|
||||
- Connection errors (server not running)
|
||||
- API errors (invalid parameters, model not found)
|
||||
- Timeout errors (generation taking too long)
|
||||
- Resource errors (video not found for deletion)
|
||||
|
||||
Errors are displayed with full stack traces for debugging.
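The pattern used throughout the scripts is a broad `try`/`except` around each API call, with `openai.NotFoundError` handled separately where a missing video is expected (see `delete_video.py`); roughly:

```python
import traceback

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")

try:
    job = client.videos.retrieve("nonexistent_video_id")
except openai.NotFoundError as e:
    # Expected for deleted or unknown videos.
    print(f"Not found: {e.message}")
except Exception as e:
    # Connection errors, invalid parameters, timeouts, etc.
    print(f"Error: {e}")
    traceback.print_exc()
```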
|
||||
|
||||
## Output Files
|
||||
|
||||
Generated files are saved to the current working directory:
|
||||
|
||||
- `output_generation.png` - Synchronous image generation (`sync_image_gen.py`)
|
||||
- `output_sync.mp4` - Synchronous video generation (`sync_video_gen.py`)
|
||||
- `output_async.mp4` - Asynchronous video generation (`async_video_gen.py`)
|
||||
- `output_multipart.mp4` - Multipart example output (`multipart_example.py`)
|
||||
|
||||
**Note:** You can customize output filenames using the `--output` parameter in all scripts.
|
||||
238
examples/visual_gen/serve/async_video_gen.py
Executable file
@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script for asynchronous video generation endpoint.
|
||||
|
||||
Tests POST /v1/videos endpoint which returns immediately with a job ID.
|
||||
The video is generated in the background and can be retrieved later.
|
||||
|
||||
Supports two modes:
|
||||
- Text-to-Video (T2V): Generate video from text prompt only
|
||||
- Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
|
||||
|
||||
Examples:
|
||||
# Text-to-Video (T2V)
|
||||
python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
|
||||
|
||||
# Text+Image-to-Video (TI2V)
|
||||
python async_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
def test_async_video_generation(
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
model: str = "wan",
|
||||
prompt: str = "A video of a cool cat on a motorcycle in the night",
|
||||
input_reference: str = None,
|
||||
duration: float = 4.0,
|
||||
fps: int = 24,
|
||||
size: str = "256x256",
|
||||
output_file: str = "output_async.mp4",
|
||||
):
|
||||
"""Test asynchronous video generation with OpenAI SDK.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of the API server
|
||||
model: Model name to use
|
||||
prompt: Text prompt for generation
|
||||
input_reference: Path to reference image (optional, for TI2V mode)
|
||||
duration: Video duration in seconds
|
||||
fps: Frames per second
|
||||
size: Video resolution (WxH format)
|
||||
output_file: Output video file path
|
||||
"""
|
||||
mode = "TI2V" if input_reference else "T2V"
|
||||
print("=" * 80)
|
||||
print(f"Testing Async Video Generation API - {mode} Mode")
|
||||
print("=" * 80)
|
||||
|
||||
# Initialize client
|
||||
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
|
||||
|
||||
print("\n1. Creating video generation job...")
|
||||
print(f" Mode: {mode}")
|
||||
print(f" Prompt: {prompt}")
|
||||
if input_reference:
|
||||
print(f" Input Reference: {input_reference}")
|
||||
print(f" Duration: {duration}s")
|
||||
print(f" FPS: {fps}")
|
||||
print(f" Size: {size}")
|
||||
|
||||
try:
|
||||
# Prepare request parameters
|
||||
create_params = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"seconds": duration,
|
||||
"extra_body": {
|
||||
"fps": fps,
|
||||
},
|
||||
}
|
||||
|
||||
# Add input reference if provided (TI2V mode)
|
||||
if input_reference:
|
||||
if not Path(input_reference).exists():
|
||||
print(f"\n❌ Error: Input reference image not found: {input_reference}")
|
||||
return False
|
||||
create_params["input_reference"] = open(input_reference, "rb")
|
||||
|
||||
# Create video generation job
|
||||
job = client.videos.create(**create_params)
|
||||
|
||||
print("Video generation started: \n", job.model_dump_json(indent=2))
|
||||
|
||||
video_id = job.id
|
||||
print("\n✓ Job created successfully!")
|
||||
print(f" Video ID: {video_id}")
|
||||
print(f" Status: {job.status}")
|
||||
|
||||
# Poll for completion
|
||||
print("\n2. Polling for completion...")
|
||||
max_attempts = 300 # 5 minutes with 1s intervals
|
||||
attempt = 0
|
||||
|
||||
while attempt < max_attempts:
|
||||
attempt += 1
|
||||
|
||||
# Get job status using SDK's get method
|
||||
job = client.videos.retrieve(video_id)
|
||||
status = job.status
|
||||
|
||||
print(f" [{attempt:3d}] Status: {status}", end="\r")
|
||||
|
||||
if status == "completed":
|
||||
print("\n\n✓ Video generation completed!")
|
||||
print(f" Completion time: {job.completed_at}")
|
||||
break
|
||||
elif status == "failed":
|
||||
print("\n\n❌ Video generation failed!")
|
||||
print(f" Error: {job.error}")
|
||||
return False
|
||||
|
||||
time.sleep(1)
|
||||
else:
|
||||
print(f"\n\n❌ Timeout waiting for completion (>{max_attempts}s)")
|
||||
return False
|
||||
|
||||
# Download video
|
||||
print("\n3. Downloading video...")
|
||||
# Download the binary video content via the SDK helper
|
||||
content = client.videos.download_content(video_id, variant="video")
|
||||
content.write_to_file(output_file)
|
||||
print(f" ✓ Saved to: {output_file}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✓ Async video generation test completed successfully!")
|
||||
print("=" * 80)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test async video generation API with T2V and TI2V modes",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Text-to-Video (T2V)
|
||||
python async_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
|
||||
|
||||
# Text+Image-to-Video (TI2V)
|
||||
python async_video_gen.py --mode ti2v \\
|
||||
--prompt "She turns around and smiles, then slowly walks out of the frame" \\
|
||||
--image ./media/woman_skyline_original_720p.jpeg
|
||||
|
||||
# Custom parameters
|
||||
python async_video_gen.py --mode t2v \\
|
||||
--prompt "A serene sunset over the ocean" \\
|
||||
--duration 5.0 --fps 30 --size 512x512 \\
|
||||
--output my_video.mp4
|
||||
""",
|
||||
)
|
||||
|
||||
# Mode selection
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["t2v", "ti2v"],
|
||||
default="t2v",
|
||||
help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
|
||||
)
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="A video of a cool cat on a motorcycle in the night",
|
||||
help="Text prompt for video generation",
|
||||
)
|
||||
|
||||
# TI2V mode parameters
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"--input-reference",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to reference image (required for ti2v mode)",
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL of the API server",
|
||||
)
|
||||
parser.add_argument("--model", type=str, default="wan", help="Model name to use")
|
||||
parser.add_argument(
|
||||
"--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
|
||||
)
|
||||
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=str,
|
||||
default="256x256",
|
||||
help="Video resolution in WxH format (e.g., 1280x720)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, default="output_async.mp4", help="Output video file path"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate ti2v mode requirements
|
||||
if args.mode == "ti2v" and not args.image:
|
||||
parser.error("--image is required when using --mode ti2v")
|
||||
|
||||
# Display configuration
|
||||
print("\n" + "=" * 80)
|
||||
print("OpenAI SDK - Async Video Generation Test")
|
||||
print("=" * 80)
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Mode: {args.mode.upper()}")
|
||||
print()
|
||||
|
||||
# Test async video generation
|
||||
success = test_async_video_generation(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
prompt=args.prompt,
|
||||
input_reference=args.image,
|
||||
duration=args.duration,
|
||||
fps=args.fps,
|
||||
size=args.size,
|
||||
output_file=args.output,
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
8
examples/visual_gen/serve/configs/wan.yml
Normal file
@ -0,0 +1,8 @@
linear:
  type: default
teacache:
  enable_teacache: true
  teacache_thresh: 0.2
parallel:
  dit_cfg_size: 1
  dit_ulysses_size: 1
151
examples/visual_gen/serve/delete_video.py
Executable file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script for DELETE /v1/videos/{video_id} endpoint.
|
||||
|
||||
Tests the video deletion functionality by:
|
||||
1. Creating a video generation job
|
||||
2. Waiting for completion
|
||||
3. Deleting the video
|
||||
4. Verifying the deletion
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
def test_delete_video(
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
model: str = "wan",
|
||||
prompt: str = "A simple test video for deletion",
|
||||
duration: float = 2.0,
|
||||
fps: int = 8,
|
||||
size: str = "256x256",
|
||||
):
|
||||
"""Test video deletion endpoint using OpenAI SDK."""
|
||||
print("=" * 80)
|
||||
print("Testing DELETE /v1/videos/{video_id} Endpoint")
|
||||
print("=" * 80)
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
|
||||
|
||||
video_id = None
|
||||
|
||||
try:
|
||||
# Step 1: Create a video generation job
|
||||
print("\n1. Creating video generation job...")
|
||||
print(f" Prompt: {prompt}")
|
||||
print(f" Duration: {duration}s")
|
||||
print(f" FPS: {fps}")
|
||||
print(f" Size: {size}")
|
||||
|
||||
job = client.videos.create(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
size=size,
|
||||
seconds=duration,
|
||||
extra_body={
|
||||
"fps": fps,
|
||||
},
|
||||
)
|
||||
|
||||
video_id = job.id
|
||||
print(f" ✓ Video job created with ID: {video_id}")
|
||||
print(f" Status: {job.status}")
|
||||
|
||||
# Step 2: Wait for video completion
|
||||
print("\n2. Waiting for video generation to complete...")
|
||||
max_attempts = 60 # attempts with 1s intervals
|
||||
attempt = 0
|
||||
|
||||
while attempt < max_attempts:
|
||||
attempt += 1
|
||||
|
||||
# Get job status using SDK's retrieve method
|
||||
job = client.videos.retrieve(video_id)
|
||||
status = job.status
|
||||
|
||||
print(f" [{attempt:3d}] Status: {status}", end="\r")
|
||||
|
||||
if status == "completed":
|
||||
print(" ✓ Video generation completed!")
|
||||
break
|
||||
elif status == "failed":
|
||||
print(" ❌ Video generation failed!")
|
||||
return False
|
||||
|
||||
time.sleep(1)
|
||||
else:
|
||||
print(" ⚠ Timeout waiting for video completion")
|
||||
# Continue with deletion anyway
|
||||
|
||||
# Step 3: Delete the video
|
||||
print(f"\n3. Deleting video {video_id}...")
|
||||
|
||||
delete_result = client.videos.delete(video_id)
|
||||
|
||||
print(f" Response: {delete_result.model_dump_json(indent=2)}")
|
||||
|
||||
if delete_result.deleted:
|
||||
print(" ✓ Video deleted successfully!")
|
||||
else:
|
||||
print(" ❌ Video deletion returned False")
|
||||
return False
|
||||
|
||||
# Step 4: Verify the video is gone
|
||||
print("\n4. Verifying video deletion...")
|
||||
|
||||
try:
|
||||
verify_job = client.videos.retrieve(video_id)
|
||||
print(f" ⚠ Video still exists after deletion: {verify_job.status}")
|
||||
return False
|
||||
except openai.NotFoundError as e:
|
||||
print(" ✓ Video correctly returns NotFoundError")
|
||||
print(f" Error message: {e.message}")
|
||||
except Exception as e:
|
||||
print(f" ⚠ Unexpected error: {type(e).__name__}: {e}")
|
||||
|
||||
# Step 5: Test deleting non-existent video
|
||||
print("\n5. Testing deletion of non-existent video...")
|
||||
|
||||
fake_id = "nonexistent_video_id"
|
||||
|
||||
try:
|
||||
fake_delete_result = client.videos.delete(fake_id)
|
||||
print(" ⚠ Deletion of non-existent video did not raise error")
|
||||
print(f" Response: {fake_delete_result.model_dump_json(indent=2)}")
|
||||
except openai.NotFoundError as e:
|
||||
print(" ✓ Correctly raises NotFoundError for non-existent video")
|
||||
print(f" Error message: {e.message}")
|
||||
except Exception as e:
|
||||
print(f" ⚠ Unexpected error: {type(e).__name__}: {e}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✓ Video deletion test completed successfully!")
|
||||
print("=" * 80)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse command line arguments
|
||||
base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("OpenAI SDK - Video Deletion Test")
|
||||
print("=" * 80)
|
||||
print(f"Base URL: {base_url}")
|
||||
print()
|
||||
|
||||
# Test video deletion
|
||||
success = test_delete_video(base_url=base_url)
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
BIN
examples/visual_gen/serve/media/woman_skyline_original_720p.jpeg
Normal file
Binary file not shown (174 KiB).
91
examples/visual_gen/serve/sync_image_gen.py
Executable file
@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script for image generation endpoints.
|
||||
|
||||
Tests:
- POST /v1/images/generations - Generate images from text
|
||||
"""
|
||||
|
||||
import base64
|
||||
import sys
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
def test_image_generation(
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
model: str = "flux2",
|
||||
prompt: str = "A lovely cat lying on a sofa",
|
||||
n: int = 1,
|
||||
size: str = "512x512",
|
||||
quality: str = "standard",
|
||||
response_format: str = "b64_json",
|
||||
output_file: str = "output_generation.png",
|
||||
):
|
||||
"""Test image generation endpoint."""
|
||||
print("=" * 80)
|
||||
print("Testing Image Generation API (POST /v1/images/generations)")
|
||||
print("=" * 80)
|
||||
|
||||
# Initialize client
|
||||
client = openai.OpenAI(base_url=base_url, api_key="tensorrt_llm")
|
||||
|
||||
print("\n1. Generating image...")
|
||||
print(f" Prompt: {prompt}")
|
||||
print(f" Size: {size}")
|
||||
print(f" Quality: {quality}")
|
||||
print(f" Number of images: {n}")
|
||||
|
||||
try:
|
||||
# Use OpenAI SDK's images.generate() method
|
||||
response = client.images.generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size=size,
|
||||
quality=quality,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
print("\n✓ Image generated successfully!")
|
||||
print(f" Number of images: {len(response.data)}")
|
||||
|
||||
# Save images
|
||||
for i, image in enumerate(response.data):
|
||||
if response_format == "b64_json":
|
||||
# Decode base64 image
|
||||
image_data = base64.b64decode(image.b64_json)
|
||||
output = f"{output_file.rsplit('.', 1)[0]}_{i}.png" if n > 1 else output_file
|
||||
|
||||
with open(output, "wb") as f:
|
||||
f.write(image_data)
|
||||
|
||||
print(f" ✓ Saved image {i + 1} to: {output} ({len(image_data)} bytes)")
|
||||
else:
|
||||
print(f" Image {i + 1} URL: {image.url}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✓ Image generation test completed successfully!")
|
||||
print("=" * 80)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse command line arguments
|
||||
base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000/v1"
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("OpenAI SDK - Image Generation Tests")
|
||||
print("=" * 80)
|
||||
print(f"Base URL: {base_url}")
|
||||
print()
|
||||
|
||||
# Test image generation
|
||||
test_image_generation(base_url=base_url)
|
||||
224
examples/visual_gen/serve/sync_video_gen.py
Executable file
@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script for synchronous video generation endpoint.
|
||||
|
||||
Tests POST /v1/videos/generations endpoint which waits for completion and returns video data.
|
||||
The video is generated synchronously and the response contains the video file.
|
||||
|
||||
Supports two modes:
|
||||
- Text-to-Video (T2V): Generate video from text prompt only
|
||||
- Text+Image-to-Video (TI2V): Generate video from text prompt + reference image
|
||||
|
||||
Examples:
|
||||
# Text-to-Video (T2V)
|
||||
python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
|
||||
|
||||
# Text+Image-to-Video (TI2V)
|
||||
python sync_video_gen.py --mode ti2v --prompt "She turns and smiles" --image ./media/woman.jpg
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_sync_video_generation(
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
model: str = "wan",
|
||||
prompt: str = "A video of a cute cat playing with a ball in the park",
|
||||
input_reference: str = None,
|
||||
duration: float = 4.0,
|
||||
fps: int = 24,
|
||||
size: str = "256x256",
|
||||
output_file: str = "output_sync.mp4",
|
||||
):
|
||||
"""Test synchronous video generation with direct HTTP requests.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of the API server
|
||||
model: Model name to use
|
||||
prompt: Text prompt for generation
|
||||
input_reference: Path to reference image (optional, for TI2V mode)
|
||||
duration: Video duration in seconds
|
||||
fps: Frames per second
|
||||
size: Video resolution (WxH format)
|
||||
output_file: Output video file path
|
||||
"""
|
||||
mode = "TI2V" if input_reference else "T2V"
|
||||
print("=" * 80)
|
||||
print(f"Testing Sync Video Generation API - {mode} Mode")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n1. Generating video (waiting for completion)...")
|
||||
print(f" Mode: {mode}")
|
||||
print(f" Prompt: {prompt}")
|
||||
if input_reference:
|
||||
print(f" Input Reference: {input_reference}")
|
||||
print(f" Duration: {duration}s")
|
||||
print(f" FPS: {fps}")
|
||||
print(f" Size: {size}")
|
||||
|
||||
try:
|
||||
endpoint = f"{base_url}/videos/generations"
|
||||
|
||||
if input_reference:
|
||||
# TI2V mode - Use multipart/form-data with file upload
|
||||
if not Path(input_reference).exists():
|
||||
print(f"\n❌ Error: Input reference image not found: {input_reference}")
|
||||
return False
|
||||
|
||||
# Prepare form data (all values as strings for multipart)
|
||||
form_data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"seconds": str(duration),
|
||||
"fps": str(fps),
|
||||
}
|
||||
|
||||
# Add the file
|
||||
## Note: The content-type must be multipart/form-data.
|
||||
files = {
|
||||
"input_reference": (
|
||||
Path(input_reference).name,
|
||||
open(input_reference, "rb"),
|
||||
"multipart/form-data",
|
||||
)
|
||||
}
|
||||
|
||||
print("\n Uploading reference image and generating video...")
|
||||
response_video = requests.post(endpoint, data=form_data, files=files)
|
||||
else:
|
||||
# T2V mode - Use JSON
|
||||
response_video = requests.post(
|
||||
endpoint,
|
||||
json={
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"seconds": duration,
|
||||
"fps": fps,
|
||||
},
|
||||
)
|
||||
|
||||
print(f"\nStatus code: {response_video.status_code}")
|
||||
|
||||
if response_video.status_code == 200:
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response_video.content)
|
||||
print(f"✓ Video saved to: {output_file}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✓ Sync video generation test completed successfully!")
|
||||
print("=" * 80)
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ Error: Server returned status {response_video.status_code}")
|
||||
print(f"Response: {response_video.text}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test synchronous video generation API with T2V and TI2V modes",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Text-to-Video (T2V)
|
||||
python sync_video_gen.py --mode t2v --prompt "A cool cat on a motorcycle"
|
||||
|
||||
# Text+Image-to-Video (TI2V)
|
||||
python sync_video_gen.py --mode ti2v \\
|
||||
--prompt "She turns around and smiles, then slowly walks out of the frame" \\
|
||||
--image ./media/woman_skyline_original_720p.jpeg
|
||||
|
||||
# Custom parameters
|
||||
python sync_video_gen.py --mode t2v \\
|
||||
--prompt "A serene sunset over the ocean" \\
|
||||
--duration 5.0 --fps 30 --size 512x512 \\
|
||||
--output my_video.mp4
|
||||
""",
|
||||
)
|
||||
|
||||
# Mode selection
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["t2v", "ti2v"],
|
||||
default="t2v",
|
||||
help="Generation mode: t2v (Text-to-Video) or ti2v (Text+Image-to-Video)",
|
||||
)
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="A video of a cute cat playing with a ball in the park",
|
||||
help="Text prompt for video generation",
|
||||
)
|
||||
|
||||
# TI2V mode parameters
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"--input-reference",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to reference image (required for ti2v mode)",
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL of the API server",
|
||||
)
|
||||
parser.add_argument("--model", type=str, default="wan", help="Model name to use")
|
||||
parser.add_argument(
|
||||
"--duration", "--seconds", type=float, default=4.0, help="Video duration in seconds"
|
||||
)
|
||||
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=str,
|
||||
default="256x256",
|
||||
help="Video resolution in WxH format (e.g., 1280x720)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, default="output_sync.mp4", help="Output video file path"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate ti2v mode requirements
|
||||
if args.mode == "ti2v" and not args.image:
|
||||
parser.error("--image is required when using --mode ti2v")
|
||||
|
||||
# Display configuration
|
||||
print("\n" + "=" * 80)
|
||||
print("Synchronous Video Generation Test")
|
||||
print("=" * 80)
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Mode: {args.mode.upper()}")
|
||||
print()
|
||||
|
||||
# Test sync video generation
|
||||
success = test_sync_video_generation(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
prompt=args.prompt,
|
||||
input_reference=args.image,
|
||||
duration=args.duration,
|
||||
fps=args.fps,
|
||||
size=args.size,
|
||||
output_file=args.output,
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
238
examples/visual_gen/visual_gen_examples.sh
Executable file
@ -0,0 +1,238 @@
|
||||
#!/bin/bash
|
||||
# Visual Generation Examples - Test different models and configurations
|
||||
#
|
||||
# This script runs a comprehensive suite of visual generation examples including:
|
||||
# - WAN T2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
|
||||
# - WAN I2V: Wan 2.1 single-stage and Wan 2.2 two-stage examples with FP8, TRT-LLM attention, and TeaCache
|
||||
#
|
||||
# The script automatically detects GPU count and runs appropriate examples:
|
||||
# - 1 GPU: Single-GPU examples only
|
||||
# - 2 GPUs: + CFG parallelism, Ulysses parallelism
|
||||
# - 4 GPUs: + CFG + Ulysses combined
|
||||
# - 8 GPUs: + CFG + Ulysses at larger scale
|
||||
#
|
||||
# Usage:
|
||||
# export MODEL_ROOT=/path/to/models # required
|
||||
# # Optional: PROJECT_ROOT auto-detected when run from examples/visual_gen
|
||||
# cd examples/visual_gen && ./visual_gen_examples.sh
|
||||
#
|
||||
# Or inline:
|
||||
# MODEL_ROOT=/llm-models ./visual_gen_examples.sh
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
# Environment variables with defaults
|
||||
# PROJECT_ROOT: auto-detect repo root when run from examples/visual_gen
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT=${PROJECT_ROOT:-"$(cd "${SCRIPT_DIR}/../.." && pwd)"}
|
||||
MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
|
||||
|
||||
# Log configuration
|
||||
export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
|
||||
|
||||
echo "============================================"
|
||||
echo "Visual Generation Examples"
|
||||
echo "============================================"
|
||||
echo "PROJECT_ROOT: $PROJECT_ROOT"
|
||||
echo "MODEL_ROOT: $MODEL_ROOT"
|
||||
echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
|
||||
|
||||
# Detect GPU count
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
echo "Detected $GPU_COUNT GPU(s)"
|
||||
if [ "$GPU_COUNT" -lt 2 ]; then
|
||||
echo "Note: Multi-GPU examples will be skipped"
|
||||
SKIP_MULTI_GPU=1
|
||||
elif [ "$GPU_COUNT" -ge 8 ]; then
|
||||
echo "Note: Will run all examples including 8-GPU configurations"
|
||||
elif [ "$GPU_COUNT" -ge 4 ]; then
|
||||
echo "Note: Will run examples up to 4-GPU configurations"
|
||||
else
|
||||
echo "Note: Will run 2-GPU examples only"
|
||||
fi
|
||||
else
|
||||
echo "WARNING: nvidia-smi not found. Assuming single GPU."
|
||||
GPU_COUNT=1
|
||||
SKIP_MULTI_GPU=1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
#############################################
|
||||
# WAN (Wan2.1) Text-to-Video Examples
|
||||
#############################################
|
||||
# Demonstrates:
|
||||
# - Single GPU: Baseline and TeaCache
|
||||
# - 2 GPUs: CFG only, Ulysses only
|
||||
# - 4 GPUs: CFG + Ulysses combined
|
||||
# - 8 GPUs: Large-scale parallelism
|
||||
#############################################
|
||||
|
||||
echo "=== WAN Example 1: Baseline (no optimization) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_cat_piano.png
|
||||
|
||||
echo ""
|
||||
echo "=== WAN Example 2: With TeaCache ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_cat_piano_teacache.png \
|
||||
--enable_teacache
|
||||
|
||||
if [ -z "$SKIP_MULTI_GPU" ]; then
|
||||
echo ""
|
||||
echo "=== WAN Example 3: CFG Only (2 GPUs) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_cfg_2gpu.mp4 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 \
|
||||
--ulysses_size 1
|
||||
else
|
||||
echo ""
|
||||
echo "=== WAN Example 3: Skipped (requires 2 GPUs) ==="
|
||||
fi
|
||||
|
||||
if [ -z "$SKIP_MULTI_GPU" ]; then
|
||||
echo ""
|
||||
echo "=== WAN Example 4: Ulysses Only (2 GPUs) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_ulysses_2gpu.mp4 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 1 \
|
||||
--ulysses_size 2
|
||||
else
|
||||
echo ""
|
||||
echo "=== WAN Example 4: Skipped (requires 2 GPUs) ==="
|
||||
fi
|
||||
|
||||
if [ "$GPU_COUNT" -ge 4 ]; then
|
||||
echo ""
|
||||
echo "=== WAN Example 5: CFG + Ulysses (4 GPUs) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_cfg_ulysses_4gpu.mp4 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 \
|
||||
--ulysses_size 2
|
||||
else
|
||||
echo ""
|
||||
echo "=== WAN Example 5: Skipped (requires 4 GPUs) ==="
|
||||
fi
|
||||
|
||||
if [ "$GPU_COUNT" -ge 8 ]; then
|
||||
echo ""
|
||||
echo "=== WAN Example 6: Large-Scale (8 GPUs) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan_cfg_ulysses_8gpu.mp4 \
|
||||
--attention_backend TRTLLM \
|
||||
--cfg_size 2 \
|
||||
--ulysses_size 4
|
||||
else
|
||||
echo ""
|
||||
echo "=== WAN Example 6: Skipped (requires 8 GPUs) ==="
|
||||
fi
|
||||
|
||||
#############################################
|
||||
# WAN 2.2 (Two-Stage) Text-to-Video Examples
|
||||
#############################################
|
||||
|
||||
echo ""
|
||||
echo "=== WAN 2.2 T2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \
|
||||
--height 720 \
|
||||
--width 1280 \
|
||||
--num_frames 81 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.2-T2V-A14B-Diffusers \
|
||||
--prompt "A cute cat playing piano" \
|
||||
--output_path wan22_t2v_cat_piano_optimized.gif \
|
||||
--linear_type trtllm-fp8-blockwise \
|
||||
--attention_backend TRTLLM \
|
||||
--enable_teacache \
|
||||
--teacache_thresh 0.2 \
|
||||
--guidance_scale 3.0 \
|
||||
--guidance_scale_2 2.5 \
|
||||
--boundary_ratio 0.85
|
||||
|
||||
#############################################
|
||||
# WAN 2.1 Image-to-Video Examples
|
||||
#############################################
|
||||
|
||||
echo ""
|
||||
echo "=== WAN 2.1 I2V Example: Single-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 33 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.1-I2V-14B-480P-Diffusers \
|
||||
--image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
|
||||
--prompt "It snows as the cat plays piano, lots of snow \
|
||||
appearing all over the screen, snowflakes, blizzard,
|
||||
gradually more snow" \
|
||||
--negative_prompt "blurry, low quality" \
|
||||
--output_path wan21_i2v_cat_piano_optimized.gif \
|
||||
--linear_type trtllm-fp8-per-tensor \
|
||||
--attention_backend TRTLLM \
|
||||
--enable_teacache \
|
||||
--teacache_thresh 0.2 \
|
||||
--guidance_scale 6.0
|
||||
|
||||
#############################################
|
||||
# WAN 2.2 (Two-Stage) Image-to-Video Examples
|
||||
#############################################
|
||||
|
||||
echo ""
|
||||
echo "=== WAN 2.2 I2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ==="
|
||||
python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 81 \
|
||||
--model_path ${MODEL_ROOT}/Wan2.2-I2V-A14B-Diffusers \
|
||||
--image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \
|
||||
--prompt "It snows as the cat plays piano, lots of snow \
|
||||
appearing all over the screen, snowflakes, blizzard,
|
||||
gradually more snow" \
|
||||
--negative_prompt "blurry, low quality" \
|
||||
--output_path wan22_i2v_cat_piano_optimized.gif \
|
||||
--linear_type trtllm-fp8-blockwise \
|
||||
--attention_backend TRTLLM \
|
||||
--enable_teacache \
|
||||
--teacache_thresh 0.2 \
|
||||
--guidance_scale 6.0 \
|
||||
--guidance_scale_2 5.0 \
|
||||
--boundary_ratio 0.85
|
||||
|
||||
echo ""
|
||||
echo "============================================"
|
||||
echo "All examples completed successfully!"
|
||||
echo "============================================"
|
||||
226
examples/visual_gen/visual_gen_wan_i2v.py
Normal file
@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
"""WAN Image-to-Video generation using TensorRT-LLM Visual Generation."""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from output_handler import OutputHandler
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
|
||||
|
||||
# Set logger level to ensure timing logs are printed
|
||||
logger.set_level("info")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TRTLLM VisualGen - Wan Image-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
|
||||
)
|
||||
|
||||
# Model & Input
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to Wan I2V Diffusers model directory (2.1 or 2.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to input image for I2V conditioning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--last_image_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional path to last frame image for interpolation (Wan 2.1 only)",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
|
||||
parser.add_argument(
|
||||
"--negative_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt. Default is model-specific.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="output.png",
|
||||
help="Path to save the output image/video frame",
|
||||
)
|
||||
|
||||
# Generation Params
|
||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale_2",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--boundary_ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
|
||||
# TeaCache Arguments
|
||||
parser.add_argument(
|
||||
"--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacache_thresh",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="TeaCache similarity threshold (rel_l1_thresh)",
|
||||
)
|
||||
|
||||
# Quantization
|
||||
parser.add_argument(
|
||||
"--linear_type",
|
||||
type=str,
|
||||
default="default",
|
||||
choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
|
||||
help="Linear layer quantization type",
|
||||
)
|
||||
|
||||
# Attention Backend
|
||||
parser.add_argument(
|
||||
"--attention_backend",
|
||||
type=str,
|
||||
default="VANILLA",
|
||||
choices=["VANILLA", "TRTLLM"],
|
||||
help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
|
||||
"Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
|
||||
)
|
||||
|
||||
# Parallelism
|
||||
parser.add_argument(
|
||||
"--cfg_size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2],
|
||||
help="CFG parallel size (1 or 2). Set to 2 for CFG Parallelism.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ulysses_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Ulysses (sequence) parallel size within each CFG group.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# world_size = cfg_size * ulysses_size
|
||||
# Example: cfg_size=2, ulysses_size=4 -> 8 GPUs
|
||||
# GPU 0-3: CFG group 0 (positive prompt), internal Ulysses parallel
|
||||
# GPU 4-7: CFG group 1 (negative prompt), internal Ulysses parallel
|
||||
n_workers = args.cfg_size * args.ulysses_size
|
||||
|
||||
# Convert linear_type to quant_config
|
||||
quant_config = None
|
||||
if args.linear_type == "trtllm-fp8-per-tensor":
|
||||
quant_config = {"quant_algo": "FP8", "dynamic": True}
|
||||
elif args.linear_type == "trtllm-fp8-blockwise":
|
||||
quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
|
||||
elif args.linear_type == "svd-nvfp4":
|
||||
quant_config = {"quant_algo": "NVFP4", "dynamic": True}
|
||||
|
||||
# 1. Setup Configuration
|
||||
diffusion_config = {
|
||||
"model_type": "wan2",
|
||||
"attention": {
|
||||
"backend": args.attention_backend,
|
||||
},
|
||||
"teacache": {
|
||||
"enable_teacache": args.enable_teacache,
|
||||
"teacache_thresh": args.teacache_thresh,
|
||||
},
|
||||
"parallel": {
|
||||
"dit_cfg_size": args.cfg_size,
|
||||
"dit_ulysses_size": args.ulysses_size,
|
||||
},
|
||||
}
|
||||
|
||||
# Add quant_config if specified
|
||||
if quant_config is not None:
|
||||
diffusion_config["quant_config"] = quant_config
|
||||
|
||||
# 2. Initialize VisualGen
|
||||
logger.info(
|
||||
f"Initializing VisualGen: world_size={n_workers} "
|
||||
f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
|
||||
)
|
||||
visual_gen = VisualGen(
|
||||
model_path=args.model_path,
|
||||
n_workers=n_workers,
|
||||
diffusion_config=diffusion_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 2. Run Inference
|
||||
logger.info(f"Generating video for prompt: '{args.prompt}'")
|
||||
logger.info(f"Negative prompt: '{args.negative_prompt}'")
|
||||
logger.info(f"Input image: {args.image_path}")
|
||||
if args.last_image_path:
|
||||
logger.info(f"Last frame image: {args.last_image_path}")
|
||||
logger.info(
|
||||
f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Build parameters with explicit I2V and Wan 2.2 support
|
||||
output = visual_gen.generate(
|
||||
inputs={
|
||||
"prompt": args.prompt,
|
||||
"negative_prompt": args.negative_prompt,
|
||||
},
|
||||
params=VisualGenParams(
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_inference_steps=args.steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
seed=args.seed,
|
||||
num_frames=args.num_frames,
|
||||
input_reference=args.image_path,
|
||||
last_image=args.last_image_path if args.last_image_path else None,
|
||||
guidance_scale_2=args.guidance_scale_2,
|
||||
boundary_ratio=args.boundary_ratio,
|
||||
),
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"Generation completed in {end_time - start_time:.2f}s")
|
||||
|
||||
# 3. Save Output
|
||||
OutputHandler.save(output, args.output_path, frame_rate=16.0)
|
||||
|
||||
finally:
|
||||
# 4. Shutdown
|
||||
visual_gen.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
228
examples/visual_gen/visual_gen_wan_t2v.py
Executable file
@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
"""WAN Text-to-Video generation using TensorRT-LLM Visual Generation."""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from output_handler import OutputHandler
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams
|
||||
|
||||
# Set logger level to ensure timing logs are printed
|
||||
logger.set_level("info")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TRTLLM VisualGen - Wan Text-to-Video Inference Example (supports Wan 2.1 and Wan 2.2)"
|
||||
)
|
||||
|
||||
# Model & Input
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Local path or HuggingFace Hub model ID (e.g., Wan-AI/Wan2.1-T2V-1.3B-Diffusers)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="HuggingFace Hub revision (branch, tag, or commit SHA)",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
|
||||
parser.add_argument(
|
||||
"--negative_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt. Default is model-specific.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="output.png",
|
||||
help="Path to save the output image/video frame",
|
||||
)
|
||||
|
||||
# Generation Params
|
||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of denoising steps (default: auto-detect, 50 for Wan2.1, 40 for Wan2.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Guidance scale (default: auto-detect, 5.0 for Wan2.1, 4.0 for Wan2.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale_2",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Second-stage guidance scale for Wan2.2 two-stage denoising (default: 3.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--boundary_ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Custom boundary ratio for two-stage denoising (default: auto-detect)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
|
||||
# TeaCache Arguments
|
||||
parser.add_argument(
|
||||
"--enable_teacache", action="store_true", help="Enable TeaCache acceleration"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacache_thresh",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="TeaCache similarity threshold (rel_l1_thresh)",
|
||||
)
|
||||
|
||||
# Quantization
|
||||
parser.add_argument(
|
||||
"--linear_type",
|
||||
type=str,
|
||||
default="default",
|
||||
choices=["default", "trtllm-fp8-per-tensor", "trtllm-fp8-blockwise", "svd-nvfp4"],
|
||||
help="Linear layer quantization type",
|
||||
)
|
||||
|
||||
# Attention Backend
|
||||
parser.add_argument(
|
||||
"--attention_backend",
|
||||
type=str,
|
||||
default="VANILLA",
|
||||
choices=["VANILLA", "TRTLLM"],
|
||||
help="Attention backend (VANILLA: PyTorch SDPA, TRTLLM: optimized kernels). "
|
||||
"Note: TRTLLM automatically falls back to VANILLA for cross-attention.",
|
||||
)
|
||||
|
||||
# Parallelism
|
||||
parser.add_argument(
|
||||
"--cfg_size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2],
|
||||
help="CFG parallel size (1 or 2). "
|
||||
"Distributes positive/negative prompts across GPUs. "
|
||||
"Example: cfg_size=2 on 4 GPUs -> 2 GPUs per prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ulysses_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Ulysses sequence parallel size within each CFG group. "
|
||||
"Distributes sequence across GPUs for longer sequences. "
|
||||
"Requirements: num_heads (12) and sequence length must both be divisible by ulysses_size. "
|
||||
"Example: ulysses_size=2 on 4 GPUs with cfg_size=2 -> "
|
||||
"2 CFG groups × 2 Ulysses ranks = 4 GPUs total.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Total workers: cfg_size × ulysses_size
|
||||
# See ParallelConfig in config.py for detailed parallelism strategy and examples
|
||||
n_workers = args.cfg_size * args.ulysses_size
|
||||
|
||||
# Log Ulysses configuration (validation happens in setup_sequence_parallelism)
|
||||
if args.ulysses_size > 1:
|
||||
num_heads = 12 # WAN has 12 attention heads
|
||||
logger.info(
|
||||
f"Using Ulysses sequence parallelism: "
|
||||
f"{num_heads} heads / {args.ulysses_size} ranks = "
|
||||
f"{num_heads // args.ulysses_size} heads per GPU"
|
||||
)
|
||||
|
||||
# Convert linear_type to quant_config
|
||||
quant_config = None
|
||||
if args.linear_type == "trtllm-fp8-per-tensor":
|
||||
quant_config = {"quant_algo": "FP8", "dynamic": True}
|
||||
elif args.linear_type == "trtllm-fp8-blockwise":
|
||||
quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}
|
||||
elif args.linear_type == "svd-nvfp4":
|
||||
quant_config = {"quant_algo": "NVFP4", "dynamic": True}
|
||||
|
||||
# 1. Setup Configuration
|
||||
diffusion_config = {
|
||||
"model_type": "wan2",
|
||||
"revision": args.revision,
|
||||
"attention": {
|
||||
"backend": args.attention_backend,
|
||||
},
|
||||
"teacache": {
|
||||
"enable_teacache": args.enable_teacache,
|
||||
"teacache_thresh": args.teacache_thresh,
|
||||
},
|
||||
"parallel": {
|
||||
"dit_cfg_size": args.cfg_size,
|
||||
"dit_ulysses_size": args.ulysses_size,
|
||||
},
|
||||
}
|
||||
|
||||
# Add quant_config if specified
|
||||
if quant_config is not None:
|
||||
diffusion_config["quant_config"] = quant_config
|
||||
|
||||
# 2. Initialize VisualGen
|
||||
logger.info(
|
||||
f"Initializing VisualGen: world_size={n_workers} "
|
||||
f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})"
|
||||
)
|
||||
visual_gen = VisualGen(
|
||||
model_path=args.model_path,
|
||||
n_workers=n_workers,
|
||||
diffusion_config=diffusion_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 2. Run Inference
|
||||
logger.info(f"Generating video for prompt: '{args.prompt}'")
|
||||
logger.info(f"Negative prompt: '{args.negative_prompt}'")
|
||||
logger.info(
|
||||
f"Resolution: {args.height}x{args.width}, Frames: {args.num_frames}, Steps: {args.steps}"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
output = visual_gen.generate(
|
||||
inputs={
|
||||
"prompt": args.prompt,
|
||||
"negative_prompt": args.negative_prompt,
|
||||
},
|
||||
params=VisualGenParams(
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_inference_steps=args.steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
seed=args.seed,
|
||||
num_frames=args.num_frames,
|
||||
guidance_scale_2=args.guidance_scale_2,
|
||||
boundary_ratio=args.boundary_ratio,
|
||||
),
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"Generation completed in {end_time - start_time:.2f}s")
|
||||
|
||||
# 3. Save Output
|
||||
OutputHandler.save(output, args.output_path, frame_rate=16.0)
|
||||
|
||||
finally:
|
||||
# 4. Shutdown
|
||||
visual_gen.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -83,3 +83,4 @@ llist
|
||||
cuda-tile>=1.0.1
|
||||
nvidia-cuda-tileiras>=13.1
|
||||
etcd-sdk-python==0.0.7
|
||||
python-multipart
|
||||
|
||||
@ -4,10 +4,11 @@ from .communicator import Distributed, MPIDist, TorchDist
|
||||
from .moe_alltoall import MoeAlltoAll
|
||||
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
|
||||
HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams,
|
||||
allgather, alltoall_helix, cp_allgather, reducescatter,
|
||||
userbuffers_allreduce_finalize)
|
||||
all_to_all_4d, allgather, alltoall_helix, cp_allgather,
|
||||
reducescatter, userbuffers_allreduce_finalize)
|
||||
|
||||
__all__ = [
|
||||
"all_to_all_4d",
|
||||
"allgather",
|
||||
"alltoall_helix",
|
||||
"cp_allgather",
|
||||
|
||||
@ -959,3 +959,126 @@ class MoEAllReduce(nn.Module):
|
||||
nranks=self.mapping.tp_size,
|
||||
eps=all_reduce_params.eps,
|
||||
)
|
||||
|
||||
|
||||
def all_to_all_4d(
|
||||
input: torch.Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
All-to-all for 4D tensors (batch, seq, heads, head_dim).
|
||||
|
||||
Redistributes a 4D tensor along two dimensions using all-to-all communication.
|
||||
This is used for Ulysses-style sequence parallelism to transform between:
|
||||
- Sequence sharding [B, S/P, H, D] → Head sharding [B, S, H/P, D]
|
||||
- Head sharding [B, S, H/P, D] → Sequence sharding [B, S/P, H, D]
|
||||
|
||||
Args:
|
||||
input: Input tensor with shape [batch, seq, heads, head_dim]
|
||||
scatter_dim: Dimension to split and scatter (1 for seq, 2 for heads)
|
||||
gather_dim: Dimension to gather (1 for seq, 2 for heads)
|
||||
process_group: PyTorch distributed process group. If None, uses default process group.
|
||||
|
||||
Returns:
|
||||
Redistributed tensor with same shape as input
|
||||
|
||||
Example:
|
||||
# Transform from sequence sharding to head sharding
|
||||
# Input: [B, S/P, H, D] (each rank has S/P sequence)
|
||||
output = all_to_all_4d(input, scatter_dim=2, gather_dim=1, process_group=pg)
|
||||
# Output: [B, S, H/P, D] (each rank has H/P heads)
|
||||
|
||||
# Transform back from head sharding to sequence sharding
|
||||
output = all_to_all_4d(input, scatter_dim=1, gather_dim=2, process_group=pg)
|
||||
"""
|
||||
# Only support PyTorch distributed mode (not MPI mode)
|
||||
if not mpi_disabled():
|
||||
raise NotImplementedError(
|
||||
"all_to_all_4d currently only supports PyTorch distributed mode. "
|
||||
"MPI mode is not supported.")
|
||||
|
||||
# Get world size from process group
|
||||
world_size = torch.distributed.get_world_size(group=process_group)
|
||||
|
||||
# If world_size is 1, no communication needed
|
||||
if world_size == 1:
|
||||
return input
|
||||
|
||||
# Validate dimensions
|
||||
assert scatter_dim in [1, 2], "scatter_dim must be 1 (seq) or 2 (heads)"
|
||||
assert gather_dim in [1, 2], "gather_dim must be 1 (seq) or 2 (heads)"
|
||||
assert scatter_dim != gather_dim, "scatter_dim and gather_dim must be different"
|
||||
|
||||
batch, seq, heads, head_dim = input.shape
|
||||
|
||||
# Validate that the scatter dimension is divisible by world_size
|
||||
scatter_size = input.shape[scatter_dim]
|
||||
assert scatter_size % world_size == 0, \
|
||||
f"Dimension {scatter_dim} size {scatter_size} must be divisible by world_size {world_size}"
|
||||
|
||||
# For all-to-all, we need to:
|
||||
# 1. Split input along scatter_dim into world_size chunks
|
||||
# 2. Send chunk i to rank i
|
||||
# 3. Receive chunk from each rank and concatenate along gather_dim
|
||||
|
||||
# Reshape for all-to-all: move scatter_dim chunks to a new dimension
|
||||
if scatter_dim == 1: # Scatter along seq dimension
|
||||
# [B, S, H, D] -> [B, P, S/P, H, D] where P = world_size
|
||||
input_reshaped = input.view(batch, world_size, seq // world_size, heads,
|
||||
head_dim)
|
||||
# Transpose to group by destination rank: [B, P, S/P, H, D] -> [P, B, S/P, H, D]
|
||||
input_transposed = input_reshaped.permute(1, 0, 2, 3, 4).contiguous()
|
||||
else: # scatter_dim == 2, scatter along heads dimension
|
||||
# [B, S, H, D] -> [B, S, P, H/P, D] where P = world_size
|
||||
input_reshaped = input.view(batch, seq, world_size, heads // world_size,
|
||||
head_dim)
|
||||
# Transpose to group by destination rank: [B, S, P, H/P, D] -> [P, B, S, H/P, D]
|
||||
input_transposed = input_reshaped.permute(2, 0, 1, 3, 4).contiguous()
|
||||
|
||||
# Flatten to [P * ...] for all-to-all communication
|
||||
# Shape: [P, B, ...] -> [P * B * ...]
|
||||
input_flat = input_transposed.flatten()
|
||||
output_flat = torch.empty_like(input_flat)
|
||||
|
||||
# Perform all-to-all communication using PyTorch distributed
|
||||
# all_to_all_single splits input into world_size chunks and exchanges them
|
||||
torch.distributed.all_to_all_single(output_flat,
|
||||
input_flat,
|
||||
group=process_group)
|
||||
|
||||
# Reshape output back to [P, B, ...] form
|
||||
output_transposed = output_flat.view_as(input_transposed)
|
||||
|
||||
# Transpose back and reshape to final form
|
||||
if gather_dim == 1: # Gather along seq dimension
|
||||
# [P, B, S/P, H, D] -> [B, P, S/P, H, D]
|
||||
output_reshaped = output_transposed.permute(1, 0, 2, 3, 4).contiguous()
|
||||
# [B, P, S/P, H, D] -> [B, S, H, D] where S = P * (S/P)
|
||||
# When scattering heads and gathering seq: seq needs to be multiplied, heads needs to be divided
|
||||
if scatter_dim == 2:
|
||||
# Scattered heads, so we have H/P heads and need to gather S/P -> S sequence
|
||||
gathered_seq = seq * world_size
|
||||
sharded_heads = heads // world_size
|
||||
output = output_reshaped.view(batch, gathered_seq, sharded_heads,
|
||||
head_dim)
|
||||
else:
|
||||
# Scattered seq (should be impossible if gather_dim == 1), keep as is
|
||||
output = output_reshaped.view(batch, seq, heads, head_dim)
|
||||
else: # gather_dim == 2, gather along heads dimension
|
||||
# [P, B, S, H/P, D] -> [B, S, P, H/P, D]
|
||||
output_reshaped = output_transposed.permute(1, 2, 0, 3, 4).contiguous()
|
||||
# [B, S, P, H/P, D] -> [B, S, H, D] where H = P * (H/P)
|
||||
# When scattering seq and gathering heads: heads needs to be multiplied, seq needs to be divided
|
||||
if scatter_dim == 1:
|
||||
# Scattered seq, so we have S/P seq and need to gather H/P -> H heads
|
||||
gathered_heads = heads * world_size
|
||||
sharded_seq = seq // world_size
|
||||
output = output_reshaped.view(batch, sharded_seq, gathered_heads,
|
||||
head_dim)
|
||||
else:
|
||||
# Scattered heads (should be impossible if gather_dim == 2), keep as is
|
||||
output = output_reshaped.view(batch, seq, heads, head_dim)
|
||||
|
||||
return output
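
For intuition, here is a minimal single-process sketch of the data movement that all_to_all_4d(scatter_dim=2, gather_dim=1) performs across ranks; the shapes are arbitrary placeholders and the per-rank exchange is simulated with plain indexing rather than torch.distributed:

import torch

# Single-process simulation of all_to_all_4d(scatter_dim=2, gather_dim=1).
# Assumptions (illustrative only): P ranks, NHD layout [B, S, H, D].
B, S, H, D, P = 1, 8, 4, 2, 2
full = torch.arange(B * S * H * D, dtype=torch.float32).view(B, S, H, D)

# Before: each rank holds a sequence shard [B, S/P, H, D].
seq_shards = list(full.chunk(P, dim=1))

# After: rank 0 should hold the full sequence but only its head slice [B, S, H/P, D].
expected_rank0 = full[:, :, : H // P, :]

# The all-to-all sends the r-th head chunk of every rank to rank r, which then
# concatenates the received pieces along the sequence dimension, in rank order.
recv_rank0 = torch.cat([shard[:, :, : H // P, :] for shard in seq_shards], dim=1)
assert torch.equal(recv_rank0, expected_rank0)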
|
||||
|
||||
@ -556,6 +556,13 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
|
||||
|
||||
def apply(self, module: Linear, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
|
||||
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
|
||||
# GEMM ops require 2D matrices
|
||||
original_shape = input.shape
|
||||
if input.dim() > 2:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
|
||||
cur_input_scale = module.input_scale
|
||||
if input.dtype != torch.float8_e4m3fn:
|
||||
if module.input_scale is not None and not module.force_dynamic_quantization:
|
||||
@ -591,6 +598,11 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
|
||||
bias=None,
|
||||
out_dtype=module.dtype or input.dtype,
|
||||
)
|
||||
|
||||
# Reshape output back to original shape (with out_features as last dim)
|
||||
if len(original_shape) > 2:
|
||||
output = output.reshape(*original_shape[:-1], output.shape[-1])
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
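
The reshape round-trip shared by these quantized linear methods can be sketched in isolation; the shapes below are hypothetical and a plain matmul stands in for the actual quantized GEMM kernel:

import torch

# >2D activations are flattened to 2D for the GEMM and restored afterwards.
x = torch.randn(2, 128, 4096)                 # [batch, seq, hidden]
original_shape = x.shape
x2d = x.reshape(-1, x.shape[-1])              # [batch * seq, hidden]
out2d = x2d @ torch.randn(4096, 1024)         # stand-in for the quantized GEMM
out = out2d.reshape(*original_shape[:-1], out2d.shape[-1])
assert out.shape == (2, 128, 1024)            # [batch, seq, out_features]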
|
||||
@ -975,6 +987,12 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
|
||||
|
||||
def apply(self, module: Linear, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden)
|
||||
# GEMM ops require 2D matrices
|
||||
original_shape = input.shape
|
||||
if input.dim() > 2:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
|
||||
if input.dtype == torch.float8_e4m3fn:
|
||||
input = input.to(torch.bfloat16) * module.input_scale
|
||||
assert input.dtype == torch.bfloat16
|
||||
@ -1003,6 +1021,10 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
|
||||
output = torch.ops.trtllm.fp8_block_scaling_gemm(
|
||||
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
|
||||
|
||||
# Reshape output back to original shape (with out_features as last dim)
|
||||
if len(original_shape) > 2:
|
||||
output = output.reshape(*original_shape[:-1], output.shape[-1])
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
@ -1212,6 +1234,15 @@ class NVFP4LinearMethod(LinearMethodBase):
|
||||
|
||||
def apply(self, module: Linear, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
# Handle multi-dimensional inputs (e.g., 3D: batch, seq, hidden).
|
||||
# GEMM requires 2D. Only plain tensors are supported for now; skip for
|
||||
# tuple and Fp4QuantizedTensor.
|
||||
original_shape = None
|
||||
if not isinstance(input,
|
||||
(tuple, Fp4QuantizedTensor)) and input.dim() > 2:
|
||||
original_shape = input.shape
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
|
||||
act_fp4, act_sf = self._input_prepare(module, input)
|
||||
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
|
||||
# Convert list to comma-separated string for torch.compile compatibility
|
||||
@ -1229,6 +1260,9 @@ class NVFP4LinearMethod(LinearMethodBase):
|
||||
if output.shape[-1] > module.out_features:
|
||||
output = output[..., :module.out_features].contiguous()
|
||||
|
||||
if original_shape is not None:
|
||||
output = output.reshape(*original_shape[:-1], output.shape[-1])
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
45
tensorrt_llm/_torch/visual_gen/__init__.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Visual generation module for diffusion models."""
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.executor import (
|
||||
DiffusionExecutor,
|
||||
DiffusionRequest,
|
||||
DiffusionResponse,
|
||||
)
|
||||
from tensorrt_llm._torch.visual_gen.output import MediaOutput
|
||||
|
||||
# Checkpoint loading
|
||||
from .checkpoints import WeightLoader
|
||||
from .config import (
|
||||
AttentionConfig,
|
||||
DiffusionArgs,
|
||||
DiffusionModelConfig,
|
||||
ParallelConfig,
|
||||
PipelineComponent,
|
||||
PipelineConfig,
|
||||
TeaCacheConfig,
|
||||
discover_pipeline_components,
|
||||
)
|
||||
from .models import AutoPipeline, BasePipeline, WanPipeline
|
||||
from .pipeline_loader import PipelineLoader
|
||||
|
||||
__all__ = [
|
||||
# Config classes
|
||||
"DiffusionArgs",
|
||||
"DiffusionModelConfig",
|
||||
"ParallelConfig",
|
||||
"PipelineComponent",
|
||||
"TeaCacheConfig",
|
||||
# Checkpoint loading
|
||||
"WeightLoader",
|
||||
# Model loading
|
||||
"PipelineLoader",
|
||||
# Execution
|
||||
"DiffusionExecutor",
|
||||
"DiffusionRequest",
|
||||
"DiffusionResponse",
|
||||
"MediaOutput",
|
||||
# Pipelines
|
||||
"AutoPipeline",
|
||||
"BasePipeline",
|
||||
"WanPipeline",
|
||||
]
|
||||
37
tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
Normal file
@ -0,0 +1,37 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Visual Generation Attention Backend
|
||||
|
||||
This module provides attention backend infrastructure for visual generation (diffusion) models.
|
||||
It reuses existing TRT-LLM attention backends (TrtllmAttention, VanillaAttention) with
|
||||
simplified metadata that doesn't require KV caching.
|
||||
"""
|
||||
|
||||
from .interface import AttentionTensorLayout
|
||||
from .parallel import UlyssesAttention
|
||||
from .trtllm import TrtllmAttention, TrtllmAttentionMetadata
|
||||
from .utils import create_attention, get_visual_gen_attention_backend
|
||||
from .vanilla import VanillaAttention
|
||||
|
||||
__all__ = [
|
||||
"AttentionTensorLayout",
|
||||
"get_visual_gen_attention_backend",
|
||||
"create_attention",
|
||||
"TrtllmAttention",
|
||||
"TrtllmAttentionMetadata",
|
||||
"UlyssesAttention",
|
||||
"VanillaAttention",
|
||||
]
|
||||
@ -0,0 +1,33 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Visual Generation Attention Backend Interface
|
||||
|
||||
Defines shared types and enums for attention backends.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AttentionTensorLayout(str, Enum):
|
||||
"""
|
||||
Tensor layout for attention backend input/output.
|
||||
|
||||
Backends declare their preferred layout so the attention module
|
||||
can reshape tensors optimally before calling the backend.
|
||||
"""
|
||||
|
||||
NHD = "NHD" # [B, S, H, D] - batch, seq, heads, dim
|
||||
HND = "HND" # [B, H, S, D] - batch, heads, seq, dim
|
||||
162
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
Normal file
@ -0,0 +1,162 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Ulysses Sequence Parallelism Wrapper
|
||||
|
||||
Wraps any attention backend with sequence parallelism via all-to-all
|
||||
communication. Not a standalone backend — compose around a real backend
|
||||
(VANILLA/TRTLLM).
|
||||
|
||||
Architecture:
|
||||
Input: [B, S/P, H, D] (sequence sharded across P processes)
|
||||
Step 1: All-to-All → [B, S, H/P, D] (gather sequence, shard heads)
|
||||
Step 2: Compute attention with wrapped backend (VANILLA or TRTLLM)
|
||||
Step 3: All-to-All → [B, S/P, H, D] (restore sequence sharding)
|
||||
Output: [B, S/P, H, D] (sequence sharded)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from tensorrt_llm._torch.distributed import all_to_all_4d
|
||||
|
||||
from .interface import AttentionTensorLayout
|
||||
|
||||
|
||||
class UlyssesAttention(nn.Module):
|
||||
"""
|
||||
Ulysses Sequence Parallelism wrapper.
|
||||
|
||||
Wraps any attention backend with sequence parallelism via all-to-all.
|
||||
Not a standalone backend — compose around a real backend (VANILLA/TRTLLM).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_backend: nn.Module,
|
||||
process_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_backend = inner_backend
|
||||
self.process_group = process_group
|
||||
self._preferred_layout = AttentionTensorLayout.NHD
|
||||
|
||||
# Derive head info from inner backend
|
||||
self.head_dim = inner_backend.head_dim
|
||||
self.sharded_num_heads = inner_backend.num_heads
|
||||
self.sharded_num_kv_heads = getattr(inner_backend, "num_kv_heads", self.sharded_num_heads)
|
||||
|
||||
# Get world size from process group
|
||||
try:
|
||||
self.world_size = torch.distributed.get_world_size(group=process_group)
|
||||
except (RuntimeError, ValueError):
|
||||
self.world_size = 1
|
||||
|
||||
# Full (unsharded) head counts for external interface
|
||||
self.num_heads = self.sharded_num_heads * self.world_size
|
||||
self.num_kv_heads = self.sharded_num_kv_heads * self.world_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass with Ulysses sequence parallelism.
|
||||
|
||||
Input/Output: [B, S/P, H, D] (sequence sharded)
|
||||
|
||||
Args:
|
||||
q: Query tensor [B, S/P, H, D]
|
||||
k: Key tensor [B, S/P, H, D]
|
||||
v: Value tensor [B, S/P, H, D]
|
||||
batch_size: Batch size
|
||||
attention_mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
Output tensor [B, S/P, H, D] (sequence sharded)
|
||||
|
||||
Note:
|
||||
seq_len is computed from tensor shape after all-to-all, not passed as parameter.
|
||||
"""
|
||||
# Step 1: All-to-All to gather full sequence, shard heads
|
||||
# [B, S/P, H, D] -> [B, S, H/P, D]
|
||||
if self.world_size > 1:
|
||||
q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group)
|
||||
k = all_to_all_4d(k, scatter_dim=2, gather_dim=1, process_group=self.process_group)
|
||||
v = all_to_all_4d(v, scatter_dim=2, gather_dim=1, process_group=self.process_group)
|
||||
|
||||
seq_len_full = q.shape[1]
|
||||
inner_layout = self.inner_backend.preferred_layout
|
||||
|
||||
# Step 2: Call wrapped backend for attention
|
||||
# Transpose only if inner backend expects HND layout
|
||||
if inner_layout == AttentionTensorLayout.HND:
|
||||
# VANILLA expects [B, H/P, S, D]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
# NHD backends (TRTLLM) keep [B, S, H/P, D] as-is
|
||||
|
||||
inner_kwargs = dict(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len_full,
|
||||
)
|
||||
if attention_mask is not None:
|
||||
inner_kwargs["attention_mask"] = attention_mask
|
||||
output = self.inner_backend.forward(**inner_kwargs)
|
||||
|
||||
# Convert output back to [B, S, H/P, D] for the reverse all-to-all
|
||||
if inner_layout == AttentionTensorLayout.HND:
|
||||
# VANILLA returns [B, H/P, S, D] -> transpose to [B, S, H/P, D]
|
||||
output = output.transpose(1, 2).contiguous()
|
||||
else:
|
||||
# TRTLLM returns [B, S, (H/P)*D] (3D) -> reshape to [B, S, H/P, D]
|
||||
if output.dim() == 3:
|
||||
output = output.view(
|
||||
batch_size, seq_len_full, self.sharded_num_heads, self.head_dim
|
||||
)
|
||||
output = output.contiguous()
|
||||
|
||||
# Step 3: All-to-All to restore sequence sharding
|
||||
# [B, S, H/P, D] -> [B, S/P, H, D]
|
||||
if self.world_size > 1:
|
||||
output = all_to_all_4d(
|
||||
output,
|
||||
scatter_dim=1,
|
||||
gather_dim=2,
|
||||
process_group=self.process_group,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def preferred_layout(self) -> AttentionTensorLayout:
|
||||
"""Preferred tensor layout: [B, S, H, D]"""
|
||||
return self._preferred_layout
|
||||
|
||||
@classmethod
|
||||
def support_fused_qkv(cls) -> bool:
|
||||
"""This backend does not support fused QKV."""
|
||||
return False
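
A minimal usage sketch of the wrapper, assuming a single process (world_size falls back to 1, so the all-to-alls are skipped) and a toy inner backend that exposes only the attributes the wrapper reads; real callers would wrap VanillaAttention or the diffusion TrtllmAttention instead:

import torch
import torch.nn as nn
import torch.nn.functional as F

from tensorrt_llm._torch.visual_gen.attention_backend import (
    AttentionTensorLayout, UlyssesAttention)


class ToySDPABackend(nn.Module):
    """Hypothetical inner backend used only for this sketch."""

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Declare that this backend consumes [B, H, S, D] tensors.
        self.preferred_layout = AttentionTensorLayout.HND

    def forward(self, q, k, v, batch_size, seq_len, **kwargs):
        return F.scaled_dot_product_attention(q, k, v)


ulysses = UlyssesAttention(inner_backend=ToySDPABackend(num_heads=4, head_dim=8))
q = k = v = torch.randn(2, 16, 4, 8)   # [B, S/P, H, D] with P == 1
out = ulysses(q, k, v, batch_size=2)
assert out.shape == (2, 16, 4, 8)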
|
||||
244
tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
Normal file
@ -0,0 +1,244 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Diffusion TRTLLM Attention Backend
|
||||
|
||||
Wraps TrtllmAttention with simplified metadata for visual generation (diffusion) models.
|
||||
Handles the specifics of no-KV-cache operation and fused QKV requirements.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
from ...attention_backend.interface import AttentionRuntimeFeatures, PredefinedAttentionMask
|
||||
from ...attention_backend.trtllm import TrtllmAttention as BaseTrtllmAttention
|
||||
from ...attention_backend.trtllm import TrtllmAttentionMetadata as BaseTrtllmAttentionMetadata
|
||||
from .interface import AttentionTensorLayout
|
||||
|
||||
|
||||
class TrtllmAttentionMetadata:
|
||||
"""
|
||||
Simplified metadata adapter for diffusion models using TRTLLM backend.
|
||||
|
||||
Lazy initialization with auto-growing capacity:
|
||||
- Metadata created only when capacity needs increase
|
||||
- prepare() called only when seq_lens actually change
|
||||
- Automatically reallocates when batch_size or seq_len exceeds current capacity
|
||||
|
||||
Args:
|
||||
max_batch_size: Initial batch size hint. Will grow automatically if exceeded.
|
||||
max_seq_len: Initial sequence length hint. Will grow automatically if exceeded.
|
||||
device: Target device for tensors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: int = 4096,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
# These are initial hints, not hard limits - capacity grows as needed
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.device = device or torch.device("cuda")
|
||||
|
||||
# Lazily created BaseTrtllmAttentionMetadata
|
||||
self._metadata: Optional[BaseTrtllmAttentionMetadata] = None
|
||||
|
||||
# Track allocated capacity
|
||||
self._allocated_batch_size = 0
|
||||
self._allocated_max_seq_len = 0
|
||||
|
||||
# Track prepared state
|
||||
self._cached_seq_lens: Optional[torch.Tensor] = None
|
||||
self._prepared = False
|
||||
|
||||
def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool:
|
||||
"""Check if we need to create new metadata (capacity change)."""
|
||||
return (
|
||||
self._metadata is None
|
||||
or batch_size > self._allocated_batch_size
|
||||
or max_seq_len > self._allocated_max_seq_len
|
||||
)
|
||||
|
||||
def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool:
|
||||
"""Check if we need to call prepare() (seq_lens changed)."""
|
||||
if not self._prepared:
|
||||
return True
|
||||
if self._cached_seq_lens is None:
|
||||
return True
|
||||
if self._cached_seq_lens.shape[0] != batch_size:
|
||||
return True
|
||||
return not torch.equal(self._cached_seq_lens[:batch_size], seq_lens)
|
||||
|
||||
def _create_metadata(self, batch_size: int, max_seq_len: int) -> None:
|
||||
"""Create new metadata with given capacity."""
|
||||
# Grow capacity monotonically (never shrink) to avoid frequent reallocation
|
||||
alloc_batch = max(batch_size, self._allocated_batch_size)
|
||||
alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len)
|
||||
|
||||
self._metadata = BaseTrtllmAttentionMetadata(
|
||||
max_num_requests=alloc_batch,
|
||||
max_num_tokens=alloc_batch * alloc_seq_len,
|
||||
max_num_sequences=alloc_batch,
|
||||
kv_cache_manager=None, # No KV cache for diffusion
|
||||
mapping=Mapping(),
|
||||
runtime_features=AttentionRuntimeFeatures(),
|
||||
)
|
||||
|
||||
self._allocated_batch_size = alloc_batch
|
||||
self._allocated_max_seq_len = alloc_seq_len
|
||||
self._prepared = False # Reset prepare state on new metadata
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_lens: Union[int, torch.Tensor],
|
||||
) -> BaseTrtllmAttentionMetadata:
|
||||
"""
|
||||
Prepare metadata for a forward pass.
|
||||
|
||||
Lazy behavior:
|
||||
- Creates metadata only when capacity needs increase
|
||||
- Calls prepare() only when seq_lens actually change
|
||||
"""
|
||||
if isinstance(seq_lens, int):
|
||||
seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32)
|
||||
else:
|
||||
seq_lens_tensor = seq_lens.to(dtype=torch.int32)
|
||||
|
||||
max_seq_len = seq_lens_tensor.max().item()
|
||||
|
||||
if self._needs_new_metadata(batch_size, max_seq_len):
|
||||
self._create_metadata(batch_size, max_seq_len)
|
||||
|
||||
if self._needs_prepare(batch_size, seq_lens_tensor):
|
||||
self._metadata.seq_lens = seq_lens_tensor
|
||||
self._metadata.num_contexts = batch_size
|
||||
self._metadata.max_seq_len = max_seq_len
|
||||
self._metadata.request_ids = list(range(batch_size))
|
||||
self._metadata.prepare()
|
||||
|
||||
# Cache for next comparison
|
||||
if self._cached_seq_lens is None or self._cached_seq_lens.shape[0] < batch_size:
|
||||
self._cached_seq_lens = seq_lens_tensor.clone()
|
||||
else:
|
||||
self._cached_seq_lens[:batch_size].copy_(seq_lens_tensor)
|
||||
self._prepared = True
|
||||
|
||||
return self._metadata
|
||||
|
||||
|
||||
class TrtllmAttention(BaseTrtllmAttention):
|
||||
"""
|
||||
TRTLLM Attention wrapper for diffusion models.
|
||||
|
||||
Handles:
|
||||
- Fused QKV requirement for TRTLLM kernel
|
||||
- Metadata creation and preparation
|
||||
- No KV cache operation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int = 0,
|
||||
num_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
quant_config: Optional[QuantConfig] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: int = 4096,
|
||||
):
|
||||
num_kv_heads = num_kv_heads or num_heads
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
quant_config=quant_config,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# TRTLLM expects flat [B*S, H*D] format
|
||||
self._preferred_layout = AttentionTensorLayout.NHD
|
||||
|
||||
self.metadata = TrtllmAttentionMetadata(
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
|
||||
seq_len_kv: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass with automatic metadata handling.
|
||||
|
||||
        For diffusion models, this wrapper expects separate Q, K, V tensors; they are
        flattened and concatenated into a single fused QKV tensor internally, because
        the underlying TRTLLM kernel consumes fused QKV.

        Args:
            q: Query tensor [num_tokens, q_hidden]
            k: Key tensor [num_kv_tokens, kv_hidden]
            v: Value tensor [num_kv_tokens, kv_hidden]
|
||||
batch_size: Batch size (required if not inferable)
|
||||
seq_len: Sequence length for Q (required if not inferable)
|
||||
attention_mask: Attention mask type
|
||||
seq_len_kv: Sequence length for K/V (for cross-attention, defaults to seq_len)
|
||||
|
||||
Returns:
|
||||
Output tensor [num_tokens, q_hidden]
|
||||
"""
|
||||
        # Resolve the K/V sequence length. The fused-QKV concatenation below requires the
        # same number of token rows for Q, K, and V, so this backend effectively supports
        # self-attention only (kv_seq_len must equal seq_len); use the VANILLA backend for
        # true cross-attention.
        kv_seq_len = seq_len_kv if seq_len_kv is not None else seq_len

        # Separate Q, K, V provided - flatten to [num_tokens, hidden] and fuse them
        q = q.view(batch_size * seq_len, -1)
        k = k.view(batch_size * kv_seq_len, -1)
        v = v.view(batch_size * kv_seq_len, -1)
        qkv = torch.cat([q, k, v], dim=-1)

        prepared_metadata = self.metadata.prepare(batch_size, seq_len)
|
||||
output = super().forward(
|
||||
q=qkv,
|
||||
k=None,
|
||||
v=None,
|
||||
metadata=prepared_metadata,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
output = output.view(batch_size, seq_len, -1)
|
||||
return output
|
||||
|
||||
@property
|
||||
def preferred_layout(self) -> AttentionTensorLayout:
|
||||
"""Return the preferred tensor layout for this backend."""
|
||||
return self._preferred_layout
|
||||
|
||||
@classmethod
|
||||
def support_fused_qkv(cls) -> bool:
|
||||
return True
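    # Usage sketch (illustrative only; assumes a CUDA device and the flat
    # [B*S, H*D] self-attention layout documented in forward()):
    #
    #   attn = TrtllmAttention(num_heads=16, head_dim=64, max_batch_size=2, max_seq_len=1024)
    #   tokens, hidden = 2 * 256, 16 * 64
    #   q = torch.randn(tokens, hidden, device="cuda", dtype=torch.bfloat16)
    #   k = torch.randn(tokens, hidden, device="cuda", dtype=torch.bfloat16)
    #   v = torch.randn(tokens, hidden, device="cuda", dtype=torch.bfloat16)
    #   out = attn.forward(q, k, v, batch_size=2, seq_len=256)   # -> [2, 256, hidden]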
|
||||
118
tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
Normal file
@ -0,0 +1,118 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Visual Generation Attention Backend Utilities
|
||||
|
||||
Factory functions for creating attention backends for visual generation models.
|
||||
Uses diffusion-specific wrappers (TrtllmAttention, VanillaAttention)
|
||||
that handle metadata preparation internally for simplified usage.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
# Lazy imports to avoid circular dependency
|
||||
if TYPE_CHECKING:
|
||||
from .trtllm import TrtllmAttention
|
||||
from .vanilla import VanillaAttention
|
||||
|
||||
# Type alias for diffusion attention backends
|
||||
DiffusionAttentionBackend = Union[TrtllmAttention, VanillaAttention]
|
||||
|
||||
|
||||
def get_visual_gen_attention_backend(
|
||||
backend_name: str,
|
||||
) -> Type["DiffusionAttentionBackend"]:
|
||||
"""
|
||||
Get diffusion attention backend class by name.
|
||||
|
||||
Args:
|
||||
backend_name: Backend identifier ("VANILLA", "TRTLLM")
|
||||
|
||||
Returns:
|
||||
Diffusion attention backend class
|
||||
|
||||
Backend Selection Guide:
|
||||
- "VANILLA": Full support for cross-attention (different Q/KV seq lengths)
|
||||
Uses torch SDPA backend
|
||||
- "TRTLLM": Optimized for self-attention (requires same Q/KV seq lengths)
|
||||
Better performance but requires fused QKV
|
||||
"""
|
||||
# Lazy imports to avoid circular dependency
|
||||
from .trtllm import TrtllmAttention
|
||||
from .vanilla import VanillaAttention
|
||||
|
||||
backend_name = backend_name.upper()
|
||||
|
||||
if backend_name == "VANILLA":
|
||||
return VanillaAttention
|
||||
elif backend_name == "TRTLLM":
|
||||
return TrtllmAttention
|
||||
else:
|
||||
# Default to VANILLA for maximum compatibility
|
||||
return VanillaAttention
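# Selection sketch (illustrative only): lookups are case-insensitive and unknown
# names fall back to the SDPA-based backend.
#
#   get_visual_gen_attention_backend("TRTLLM")   # -> TrtllmAttention
#   get_visual_gen_attention_backend("vanilla")  # -> VanillaAttention
#   get_visual_gen_attention_backend("unknown")  # -> VanillaAttention (default)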
|
||||
|
||||
|
||||
def create_attention(
|
||||
backend: str,
|
||||
layer_idx: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
quant_config: Optional[QuantConfig] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: int = 4096,
|
||||
**kwargs,
|
||||
) -> "DiffusionAttentionBackend":
|
||||
"""
|
||||
Factory function to create attention backend instance for visual generation.
|
||||
|
||||
Creates diffusion-specific attention backends that handle metadata preparation
|
||||
internally, simplifying the forward() call.
|
||||
|
||||
Args:
|
||||
backend: Backend identifier ("VANILLA", "TRTLLM")
|
||||
layer_idx: Layer index in the model
|
||||
num_heads: Number of attention heads
|
||||
head_dim: Dimension per head
|
||||
num_kv_heads: Number of KV heads (for GQA/MQA, defaults to num_heads)
|
||||
quant_config: Optional quantization configuration
|
||||
dtype: Data type for the attention
|
||||
max_batch_size: Initial batch size for metadata pre-allocation. The backend
|
||||
will automatically reallocate if larger batches are encountered.
|
||||
max_seq_len: Initial sequence length for metadata pre-allocation. The backend
|
||||
will automatically reallocate if longer sequences are encountered.
|
||||
**kwargs: Additional backend-specific arguments
|
||||
|
||||
Returns:
|
||||
Diffusion attention backend instance (TrtllmAttention or VanillaAttention)
|
||||
"""
|
||||
attn_cls = get_visual_gen_attention_backend(backend)
|
||||
|
||||
return attn_cls(
|
||||
layer_idx=layer_idx,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
num_kv_heads=num_kv_heads,
|
||||
quant_config=quant_config,
|
||||
dtype=dtype,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
**kwargs,
|
||||
)
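# Factory usage sketch (illustrative only; the head counts and sequence length below
# are arbitrary example values, not a specific model configuration):
#
#   attn = create_attention(
#       backend="TRTLLM",
#       layer_idx=0,
#       num_heads=12,
#       head_dim=128,
#       dtype=torch.bfloat16,
#       max_batch_size=2,
#       max_seq_len=32760,
#   )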
|
||||
126
tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py
Normal file
@ -0,0 +1,126 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Diffusion Vanilla Attention Backend
|
||||
|
||||
Simple attention implementation for visual generation (diffusion) models using
|
||||
torch.nn.functional.scaled_dot_product_attention (SDPA).
|
||||
|
||||
Supports both self-attention and cross-attention (different Q/KV sequence lengths).
|
||||
No KV cache - full recompute each diffusion step.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...attention_backend.interface import PredefinedAttentionMask
|
||||
from .interface import AttentionTensorLayout
|
||||
|
||||
|
||||
class VanillaAttention(nn.Module):
|
||||
"""
|
||||
Vanilla Attention for diffusion models using torch SDPA.
|
||||
|
||||
Uses torch.nn.functional.scaled_dot_product_attention which:
|
||||
- Properly handles cross-attention (different Q/KV sequence lengths)
|
||||
- Uses Flash Attention 2 when available (via SDPA backend selection)
|
||||
- No KV cache needed for diffusion models
|
||||
|
||||
This is simpler than the LLM VanillaAttention which has complex
|
||||
KV cache handling and uses flash_attn_varlen_func.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int = 0,
|
||||
num_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_kv_heads = num_kv_heads or num_heads
|
||||
self.dtype = dtype
|
||||
self.scale = 1.0 / math.sqrt(head_dim)
|
||||
|
||||
# SDPA expects [B, H, S, D] format
|
||||
self._preferred_layout = AttentionTensorLayout.HND
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
seq_len_kv: Optional[int] = None,
|
||||
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass using torch SDPA.
|
||||
|
||||
Args:
|
||||
q: Query tensor [num_tokens, num_heads * head_dim]
|
||||
k: Key tensor [num_kv_tokens, num_kv_heads * head_dim]
|
||||
v: Value tensor [num_kv_tokens, num_kv_heads * head_dim]
|
||||
batch_size: Batch size
|
||||
seq_len: Query sequence length
|
||||
seq_len_kv: KV sequence length (for cross-attention)
|
||||
attention_mask: Attention mask type (CAUSAL or FULL)
|
||||
|
||||
Returns:
|
||||
Output tensor [num_tokens, num_heads * head_dim]
|
||||
"""
|
||||
is_causal = attention_mask == PredefinedAttentionMask.CAUSAL
|
||||
|
||||
# Validate tensor shapes - flexible for Ulysses head sharding
|
||||
# Expected: [batch_size, num_heads, seq_len, head_dim]
|
||||
# Note: num_heads may be sharded (num_heads // ulysses_size) when using Ulysses
|
||||
assert (
|
||||
q.dim() == 4
|
||||
and q.shape[0] == batch_size
|
||||
and q.shape[2] == seq_len
|
||||
and q.shape[3] == self.head_dim
|
||||
), (
|
||||
f"Invalid q shape: expected [B={batch_size}, H, S={seq_len}, D={self.head_dim}], got {q.shape}"
|
||||
)
|
||||
assert k.dim() == 4 and k.shape[0] == batch_size and k.shape[3] == self.head_dim, (
|
||||
f"Invalid k shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {k.shape}"
|
||||
)
|
||||
assert v.dim() == 4 and v.shape[0] == batch_size and v.shape[3] == self.head_dim, (
|
||||
f"Invalid v shape: expected [B={batch_size}, H_kv, S_kv, D={self.head_dim}], got {v.shape}"
|
||||
)
|
||||
|
||||
# TODO: Maybe we need to enforce cuDNN backend here
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=self.scale)
|
||||
|
||||
@property
|
||||
def preferred_layout(self) -> AttentionTensorLayout:
|
||||
"""Return the preferred tensor layout for this backend."""
|
||||
return self._preferred_layout
|
||||
|
||||
@classmethod
|
||||
def support_fused_qkv(cls) -> bool:
|
||||
return False
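    # Usage sketch (illustrative only; SDPA consumes the [B, H, S, D] layout advertised
    # by preferred_layout, and cross-attention may use a different KV length):
    #
    #   attn = VanillaAttention(num_heads=16, head_dim=64)
    #   q = torch.randn(2, 16, 256, 64, device="cuda", dtype=torch.bfloat16)
    #   k = torch.randn(2, 16, 512, 64, device="cuda", dtype=torch.bfloat16)  # S_kv != S_q
    #   v = torch.randn(2, 16, 512, 64, device="cuda", dtype=torch.bfloat16)
    #   out = attn.forward(q, k, v, batch_size=2, seq_len=256, seq_len_kv=512)
    #   # out: [2, 16, 256, 64]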
|
||||
7
tensorrt_llm/_torch/visual_gen/checkpoints/__init__.py
Normal file
@ -0,0 +1,7 @@
"""Diffusion model checkpoint loading utilities."""

from .weight_loader import WeightLoader

__all__ = [
    "WeightLoader",
]
152
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""Weight loader for diffusion models."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
|
||||
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
|
||||
class WeightLoader(BaseWeightLoader):
|
||||
"""
|
||||
Weight loader for diffusion models.
|
||||
|
||||
Loads weights from safetensors/bin files, similar to HfWeightLoader
|
||||
but simpler (no parallel loading optimization for now).
|
||||
|
||||
Supports loading multiple components (e.g., transformer and transformer_2):
|
||||
loader = WeightLoader(components=["transformer", "transformer_2"])
|
||||
weights = loader.load_weights(ckpt_dir, mapping)
|
||||
# Returns: {"transformer": {...}, "transformer_2": {...}}
|
||||
"""
|
||||
|
||||
def __init__(self, components: Union[str, List[str]] = PipelineComponent.TRANSFORMER):
|
||||
"""
|
||||
Args:
|
||||
components: Component(s) to load weights for. Can be:
|
||||
- Single string: "transformer" (returns flat dict)
|
||||
- List of strings: ["transformer", "transformer_2"] (returns nested dict)
|
||||
"""
|
||||
if isinstance(components, str):
|
||||
self.components = [components]
|
||||
self.single_component = True
|
||||
else:
|
||||
self.components = components
|
||||
self.single_component = False
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
checkpoint_dir: str,
|
||||
mapping: Mapping,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load weights from checkpoint directory.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to checkpoint (pipeline root or component dir)
|
||||
mapping: Distributed mapping (for future TP/PP support)
|
||||
|
||||
Returns:
|
||||
- If single component: Dict mapping weight names to tensors
|
||||
- If multiple components: Dict mapping component names to weight dicts
|
||||
Example: {"transformer": {...}, "transformer_2": {...}}
|
||||
"""
|
||||
checkpoint_path = Path(checkpoint_dir)
|
||||
|
||||
# Check if this is a pipeline (has model_index.json)
|
||||
model_index = checkpoint_path / "model_index.json"
|
||||
is_pipeline = model_index.exists()
|
||||
|
||||
# Load weights for each component
|
||||
all_weights = {}
|
||||
for component in self.components:
|
||||
if is_pipeline:
|
||||
# Pipeline format: load from component subdirectory
|
||||
component_dir = checkpoint_path / component
|
||||
if not component_dir.exists():
|
||||
raise ValueError(f"Component '{component}' not found in {checkpoint_dir}")
|
||||
weight_dir = component_dir
|
||||
else:
|
||||
# Standalone model (only valid for single component)
|
||||
if len(self.components) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple components specified but {checkpoint_dir} is not a pipeline "
|
||||
"(no model_index.json found)"
|
||||
)
|
||||
weight_dir = checkpoint_path
|
||||
|
||||
# Find weight files
|
||||
weight_files = self._find_weight_files(weight_dir)
|
||||
if not weight_files:
|
||||
raise ValueError(f"No weight files found in {weight_dir}")
|
||||
|
||||
# Load all weights with progress bar
|
||||
component_weights = {}
|
||||
desc = f"Loading {component}" if is_pipeline else "Loading checkpoint"
|
||||
for wf in tqdm.tqdm(weight_files, desc=desc):
|
||||
component_weights.update(self._load_file(wf))
|
||||
|
||||
all_weights[component] = component_weights
|
||||
|
||||
# Return flat dict for single component (backward compatibility)
|
||||
if self.single_component:
|
||||
return all_weights[self.components[0]]
|
||||
|
||||
# Return nested dict for multiple components
|
||||
return all_weights
|
||||
|
||||
def _find_weight_files(self, weight_dir) -> List[str]:
|
||||
"""Find safetensors or bin weight files.
|
||||
|
||||
Handles:
|
||||
- Single safetensors file
|
||||
- Sharded safetensors with index.json
|
||||
- PyTorch bin/pth files
|
||||
"""
|
||||
weight_dir = Path(weight_dir)
|
||||
|
||||
# Check for sharded safetensors index
|
||||
index_file = weight_dir / "diffusion_pytorch_model.safetensors.index.json"
|
||||
if not index_file.exists():
|
||||
index_file = weight_dir / "model.safetensors.index.json"
|
||||
|
||||
if index_file.exists():
|
||||
# Sharded safetensors: read index to get all shard files
|
||||
with open(index_file) as f:
|
||||
index = json.load(f)
|
||||
shard_files = set(index.get("weight_map", {}).values())
|
||||
return sorted([str(weight_dir / f) for f in shard_files])
|
||||
|
||||
# Single safetensors file
|
||||
files = list(weight_dir.glob("*.safetensors"))
|
||||
if files:
|
||||
# Filter out consolidated if multiple files exist
|
||||
if len(files) > 1:
|
||||
files = [f for f in files if "consolidated" not in f.name]
|
||||
return sorted([str(f) for f in files])
|
||||
|
||||
# Fallback to bin
|
||||
files = list(weight_dir.glob("*.bin"))
|
||||
if files:
|
||||
return sorted([str(f) for f in files])
|
||||
|
||||
# Fallback to pth
|
||||
files = list(weight_dir.glob("*.pth"))
|
||||
return sorted([str(f) for f in files])
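    # Sharded-checkpoint sketch (illustrative only): a diffusers-style index file maps
    # each parameter name to the shard that stores it, e.g.
    #
    #   diffusion_pytorch_model.safetensors.index.json
    #   {
    #     "metadata": {"total_size": 28000000000},
    #     "weight_map": {
    #       "blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
    #       "blocks.0.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors"
    #     }
    #   }
    #
    # _find_weight_files() returns the de-duplicated, sorted set of shard paths.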
|
||||
|
||||
def _load_file(self, filepath: str) -> Dict[str, Any]:
|
||||
"""Load weights from a single file."""
|
||||
logger.debug(f"Loading {filepath}")
|
||||
if filepath.endswith(".safetensors"):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
return load_file(filepath)
|
||||
else:
|
||||
return torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
565
tensorrt_llm/_torch/visual_gen/config.py
Normal file
@ -0,0 +1,565 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from pydantic import Field as PydanticField
|
||||
|
||||
from tensorrt_llm.functional import AllReduceStrategy
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline component identifiers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PipelineComponent(str, Enum):
|
||||
"""Identifiers for pipeline components that can be loaded or skipped.
|
||||
|
||||
Inherits from str so values compare equal to plain strings,
|
||||
e.g. ``PipelineComponent.VAE == "vae"`` is ``True``.
|
||||
"""
|
||||
|
||||
TRANSFORMER = "transformer"
|
||||
VAE = "vae"
|
||||
TEXT_ENCODER = "text_encoder"
|
||||
TOKENIZER = "tokenizer"
|
||||
SCHEDULER = "scheduler"
|
||||
IMAGE_ENCODER = "image_encoder"
|
||||
IMAGE_PROCESSOR = "image_processor"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sub-configuration classes for DiffusionArgs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AttentionConfig(BaseModel):
|
||||
"""Configuration for Attention layers."""
|
||||
|
||||
backend: Literal["VANILLA", "TRTLLM"] = PydanticField(
|
||||
"VANILLA", description="Attention backend: VANILLA (PyTorch SDPA), TRTLLM"
|
||||
)
|
||||
|
||||
|
||||
class ParallelConfig(BaseModel):
|
||||
"""Configuration for distributed parallelism.
|
||||
|
||||
Currently Supported:
|
||||
- dit_cfg_size: CFG (Classifier-Free Guidance) parallelism
|
||||
- dit_ulysses_size: Ulysses sequence parallelism
|
||||
|
||||
Not Yet Supported:
|
||||
- dit_tp_size: Tensor parallelism (not implemented)
|
||||
- dit_ring_size: Ring attention (not implemented)
|
||||
- dit_cp_size, dit_dp_size, dit_fsdp_size: Other parallelism types
|
||||
|
||||
Total world_size = dit_cfg_size × dit_ulysses_size
|
||||
|
||||
Parallelism Strategy:
|
||||
- CFG Parallelism: Distributes positive/negative prompts across GPUs
|
||||
- Ulysses Parallelism: Distributes sequence within each CFG group
|
||||
|
||||
Example Configurations:
|
||||
1. cfg_size=1, ulysses_size=2 -> 2 GPUs (Ulysses only)
|
||||
GPU 0-1: Single prompt, sequence parallelism across 2 GPUs
|
||||
|
||||
2. cfg_size=2, ulysses_size=1 -> 2 GPUs (CFG only)
|
||||
GPU 0: Positive prompt
|
||||
GPU 1: Negative prompt
|
||||
|
||||
3. cfg_size=2, ulysses_size=2 -> 4 GPUs (CFG + Ulysses)
|
||||
GPU 0-1: CFG group 0 (positive), Ulysses parallel
|
||||
GPU 2-3: CFG group 1 (negative), Ulysses parallel
|
||||
|
||||
4. cfg_size=2, ulysses_size=4 -> 8 GPUs (CFG + Ulysses)
|
||||
GPU 0-3: CFG group 0 (positive), Ulysses parallel
|
||||
GPU 4-7: CFG group 1 (negative), Ulysses parallel
|
||||
"""
|
||||
|
||||
disable_parallel_vae: bool = False
|
||||
parallel_vae_split_dim: Literal["width", "height"] = "width"
|
||||
|
||||
# DiT Parallelism
|
||||
dit_dp_size: int = PydanticField(1, ge=1)
|
||||
dit_tp_size: int = PydanticField(1, ge=1) # Not yet supported
|
||||
dit_ulysses_size: int = PydanticField(1, ge=1) # Supported
|
||||
dit_ring_size: int = PydanticField(1, ge=1) # Not yet supported
|
||||
dit_cp_size: int = PydanticField(1, ge=1)
|
||||
dit_cfg_size: int = PydanticField(1, ge=1) # Supported
|
||||
dit_fsdp_size: int = PydanticField(1, ge=1)
|
||||
|
||||
# Refiner Parallelism (Optional)
|
||||
refiner_dit_dp_size: int = 1
|
||||
refiner_dit_tp_size: int = 1
|
||||
refiner_dit_ulysses_size: int = 1
|
||||
refiner_dit_ring_size: int = 1
|
||||
refiner_dit_cp_size: int = 1
|
||||
refiner_dit_cfg_size: int = 1
|
||||
refiner_dit_fsdp_size: int = 1
|
||||
|
||||
t5_fsdp_size: int = 1
|
||||
|
||||
def to_mapping(self) -> Mapping:
|
||||
"""Convert to TRT-LLM Mapping."""
|
||||
world_size = self.dit_tp_size * self.dit_cp_size
|
||||
return Mapping(
|
||||
world_size=world_size,
|
||||
tp_size=self.dit_tp_size,
|
||||
pp_size=1,
|
||||
cp_size=self.dit_cp_size,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parallel_sizes(self) -> "ParallelConfig":
|
||||
"""Validate configuration against current environment."""
|
||||
if torch.cuda.is_available():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
total_parallel = (
|
||||
self.dit_tp_size
|
||||
* self.dit_ulysses_size
|
||||
* self.dit_ring_size
|
||||
* self.dit_cp_size
|
||||
* self.dit_dp_size
|
||||
* self.dit_cfg_size
|
||||
)
|
||||
if total_parallel > world_size:
|
||||
raise ValueError(
|
||||
f"Total DiT parallel size ({total_parallel}) exceeds WORLD_SIZE ({world_size})"
|
||||
)
|
||||
return self
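    # Configuration sketch (illustrative only): 4 GPUs split into two CFG groups of two
    # Ulysses ranks each, matching example 3 in the class docstring.
    #
    #   parallel = ParallelConfig(dit_cfg_size=2, dit_ulysses_size=2)
    #   mapping = parallel.to_mapping()
    #   # to_mapping() only folds in TP/CP, so this returns a trivial Mapping (world_size=1)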
|
||||
|
||||
|
||||
class TeaCacheConfig(BaseModel):
|
||||
"""Configuration for TeaCache runtime optimization.
|
||||
|
||||
TeaCache speeds up diffusion by caching transformer outputs when timestep
|
||||
embeddings change slowly. It monitors embedding distances and reuses cached
|
||||
residuals when changes are below a threshold.
|
||||
|
||||
Attributes:
|
||||
enable_teacache: Enable TeaCache optimization
|
||||
teacache_thresh: Distance threshold for cache decisions (lower = more caching)
|
||||
use_ret_steps: Use aggressive warmup mode (5 steps) vs minimal (1 step)
|
||||
coefficients: Polynomial coefficients for rescaling embedding distances
|
||||
Applied as: rescaled_distance = poly(raw_distance)
|
||||
ret_steps: Number of warmup steps (always compute, initialized at runtime)
|
||||
cutoff_steps: Step to stop caching (always compute after, initialized at runtime)
|
||||
num_steps: Total inference steps (set at runtime)
|
||||
_cnt: Internal step counter (reset per generation)
|
||||
"""
|
||||
|
||||
enable_teacache: bool = False
|
||||
teacache_thresh: float = PydanticField(0.2, gt=0.0)
|
||||
use_ret_steps: bool = True
|
||||
|
||||
coefficients: List[float] = PydanticField(default_factory=lambda: [1.0, 0.0])
|
||||
|
||||
# Runtime state fields (initialized by TeaCacheBackend.refresh)
|
||||
ret_steps: Optional[int] = None
|
||||
cutoff_steps: Optional[int] = None
|
||||
num_steps: Optional[int] = None
|
||||
|
||||
# State tracking (reset per generation)
|
||||
_cnt: int = 0
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_teacache(self) -> "TeaCacheConfig":
|
||||
"""Validate TeaCache configuration."""
|
||||
# Validate coefficients
|
||||
if len(self.coefficients) == 0:
|
||||
raise ValueError("TeaCache coefficients list cannot be empty")
|
||||
|
||||
# Validate ret_steps if set
|
||||
if self.ret_steps is not None and self.ret_steps < 0:
|
||||
raise ValueError(f"ret_steps must be non-negative, got {self.ret_steps}")
|
||||
|
||||
# Validate cutoff_steps vs num_steps if both set
|
||||
if self.cutoff_steps is not None and self.num_steps is not None:
|
||||
if self.cutoff_steps > self.num_steps:
|
||||
raise ValueError(
|
||||
f"cutoff_steps ({self.cutoff_steps}) cannot exceed num_steps ({self.num_steps})"
|
||||
)
|
||||
|
||||
return self
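    # Configuration sketch (illustrative only). The rescale polynomial is assumed to be
    # evaluated numpy.polyval-style (highest power first); the coefficients below are the
    # Wan 1.3B "standard" values defined in pipeline_wan.py.
    #
    #   cache_cfg = TeaCacheConfig(
    #       enable_teacache=True,
    #       teacache_thresh=0.2,
    #       use_ret_steps=True,
    #       coefficients=[2.39676752e03, -1.31110545e03, 2.01331979e02,
    #                     -8.29855975e00, 1.37887774e-01],
    #   )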
|
||||
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""General pipeline configuration."""
|
||||
|
||||
enable_torch_compile: bool = True
|
||||
torch_compile_models: str = PipelineComponent.TRANSFORMER
|
||||
torch_compile_mode: str = "default"
|
||||
fuse_qkv: bool = True
|
||||
|
||||
# Offloading Config
|
||||
enable_offloading: bool = False
|
||||
offload_device: Literal["cpu", "cuda"] = "cpu"
|
||||
offload_param_pin_memory: bool = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DiffusionArgs - User-facing configuration (CLI / YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DiffusionArgs(BaseModel):
|
||||
"""User-facing configuration for diffusion model loading and inference.
|
||||
|
||||
This is the main config class used in CLI args and YAML config files.
|
||||
PipelineLoader converts this to DiffusionModelConfig internally.
|
||||
|
||||
Example:
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/path/to/model",
|
||||
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
|
||||
parallel=ParallelConfig(dit_tp_size=2),
|
||||
)
|
||||
loader = PipelineLoader()
|
||||
pipeline = loader.load(args)
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# Required: Path to checkpoint or HuggingFace Hub model ID
|
||||
checkpoint_path: str = PydanticField(
|
||||
"",
|
||||
description=(
|
||||
"Local directory path or HuggingFace Hub model ID "
|
||||
"(e.g., 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'). "
|
||||
"Hub models are downloaded and cached automatically."
|
||||
),
|
||||
)
|
||||
|
||||
# HuggingFace Hub options
|
||||
revision: Optional[str] = PydanticField(
|
||||
None,
|
||||
description="HuggingFace Hub revision (branch, tag, or commit SHA) to download.",
|
||||
)
|
||||
|
||||
# Device/dtype options
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
|
||||
# Component loading options (use PipelineComponent enum values or plain strings)
|
||||
skip_components: List[PipelineComponent] = PydanticField(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Components to skip loading. "
|
||||
"Accepts PipelineComponent enum values or equivalent strings "
|
||||
"(e.g., [PipelineComponent.TEXT_ENCODER, PipelineComponent.VAE])"
|
||||
),
|
||||
)
|
||||
|
||||
# Sub-configs (dict input for quant_config is coerced to QuantConfig in model_validator)
|
||||
quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
|
||||
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
|
||||
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
|
||||
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
|
||||
teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
|
||||
|
||||
# Set by model_validator when quant_config is provided as a dict (ModelOpt format)
|
||||
dynamic_weight_quant: bool = False
|
||||
force_dynamic_quantization: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _parse_quant_config_dict(cls, data: Any) -> Any:
|
||||
"""Parse user-facing DiffusionArgs.quant_config (dict or None) into QuantConfig and dynamic flags.
|
||||
|
||||
User input is a ModelOpt-format dict (e.g. {"quant_algo": "FP8", "dynamic": True}).
|
||||
We coerce it to QuantConfig + dynamic_weight_quant + force_dynamic_quantization so that
|
||||
from_pretrained() can copy them into DiffusionModelConfig (internal) without parsing again.
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
raw = data.get("quant_config")
|
||||
if raw is None:
|
||||
data = {**data, "quant_config": QuantConfig()}
|
||||
return data
|
||||
if not isinstance(raw, dict):
|
||||
return data
|
||||
qc, _, dwq, daq = DiffusionModelConfig.load_diffusion_quant_config(raw)
|
||||
data = {
|
||||
**data,
|
||||
"quant_config": qc,
|
||||
"dynamic_weight_quant": dwq,
|
||||
"force_dynamic_quantization": daq,
|
||||
}
|
||||
return data
|
||||
|
||||
def to_mapping(self) -> Mapping:
|
||||
"""Derive Mapping from ParallelConfig."""
|
||||
return self.parallel.to_mapping()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return self.model_dump()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "DiffusionArgs":
|
||||
"""Create from dictionary with automatic nested config parsing.
|
||||
|
||||
Pydantic automatically handles nested configs, but we keep this method
|
||||
for backward compatibility and to filter unknown fields.
|
||||
"""
|
||||
# Get valid field names for DiffusionArgs
|
||||
valid_fields = set(cls.model_fields.keys())
|
||||
|
||||
# Filter to only include valid fields (ignore unknown fields)
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
|
||||
|
||||
# Pydantic automatically converts nested dicts to their respective config classes
|
||||
return cls(**filtered_dict)
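    # Parsing sketch (illustrative only): nested dicts are coerced into the sub-config
    # classes and unknown keys are silently dropped.
    #
    #   args = DiffusionArgs.from_dict({
    #       "checkpoint_path": "/path/to/Wan2.1-T2V-1.3B-Diffusers",
    #       "attention": {"backend": "TRTLLM"},
    #       "parallel": {"dit_cfg_size": 2, "dit_ulysses_size": 2},
    #       "teacache": {"enable_teacache": True},
    #       "not_a_real_field": 123,   # ignored by the field filter
    #   })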
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utilities
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]:
|
||||
"""
|
||||
Discover components from diffusers pipeline's model_index.json.
|
||||
|
||||
Returns dict mapping component name to config.json path.
|
||||
"""
|
||||
model_index_path = checkpoint_path / "model_index.json"
|
||||
if not model_index_path.exists():
|
||||
return {}
|
||||
|
||||
with open(model_index_path) as f:
|
||||
model_index = json.load(f)
|
||||
|
||||
components = {}
|
||||
for key, value in model_index.items():
|
||||
if key.startswith("_") or value is None:
|
||||
continue
|
||||
config_path = checkpoint_path / key / "config.json"
|
||||
if config_path.exists():
|
||||
components[key] = config_path
|
||||
|
||||
return components
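# Layout sketch (illustrative only): a diffusers pipeline checkpoint keeps one
# sub-directory per component, indexed by model_index.json, e.g.
#
#   model_index.json  -> {"_class_name": "WanPipeline",
#                         "transformer": ["diffusers", "WanTransformer3DModel"], ...}
#   transformer/config.json, vae/config.json, text_encoder/config.json, ...
#
# Only keys whose <key>/config.json exists are returned, so components that use a
# differently named config file (e.g. scheduler/scheduler_config.json) are skipped.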
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DiffusionModelConfig - Internal configuration (merged/parsed)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DiffusionModelConfig(BaseModel):
|
||||
"""Internal ModelConfig for diffusion models.
|
||||
|
||||
This is created by PipelineLoader from DiffusionArgs + checkpoint.
|
||||
Contains merged/parsed config from:
|
||||
- pretrained_config: From checkpoint/config.json
|
||||
- quant_config: From checkpoint or user quant config
|
||||
- Sub-configs: From DiffusionArgs (pipeline, attention, parallel, teacache)
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
pretrained_config: Optional[Any] = None
|
||||
mapping: Mapping = PydanticField(default_factory=Mapping)
|
||||
skip_create_weights_in_init: bool = False
|
||||
force_dynamic_quantization: bool = False
|
||||
allreduce_strategy: AllReduceStrategy = PydanticField(default=AllReduceStrategy.AUTO)
|
||||
extra_attrs: Dict = PydanticField(default_factory=dict)
|
||||
|
||||
# Distributed process groups
|
||||
ulysses_process_group: Optional[torch.distributed.ProcessGroup] = None
|
||||
|
||||
dynamic_weight_quant: bool = False
|
||||
|
||||
# Sub-configs from DiffusionArgs (merged during from_pretrained)
|
||||
quant_config: QuantConfig = PydanticField(default_factory=QuantConfig)
|
||||
# Per-layer quant (from load_diffusion_quant_config layer_quant_config; None until mixed-precision parsing exists)
|
||||
quant_config_dict: Optional[Dict[str, QuantConfig]] = None
|
||||
pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
|
||||
attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
|
||||
parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
|
||||
teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)
|
||||
|
||||
@property
|
||||
def torch_dtype(self) -> "torch.dtype":
|
||||
"""Get the torch dtype of the model (default: bfloat16)."""
|
||||
return torch.bfloat16
|
||||
|
||||
def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
|
||||
"""Get quantization config for a layer or global. Resembles LLM ModelConfig.get_quant_config."""
|
||||
if name is None or self.quant_config_dict is None:
|
||||
return self.quant_config
|
||||
if name in self.quant_config_dict:
|
||||
return self.quant_config_dict[name]
|
||||
return self.quant_config
|
||||
|
||||
@staticmethod
|
||||
def load_diffusion_quant_config(
|
||||
quant_config_dict: dict,
|
||||
) -> Tuple[QuantConfig, Optional[Dict], bool, bool]:
|
||||
"""
|
||||
Parse quantization config in ModelOpt format.
|
||||
|
||||
Returns: (quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant)
|
||||
- quant_config: Global QuantConfig
|
||||
- layer_quant_config: Per-layer config dict (None if not using mixed precision)
|
||||
- dynamic_weight_quant: Whether to quantize weights at load time
|
||||
- dynamic_activation_quant: Whether to quantize activations dynamically
|
||||
"""
|
||||
quant_algo_str = quant_config_dict.get("quant_algo")
|
||||
quant_algo = None
|
||||
if quant_algo_str:
|
||||
algo_map = {
|
||||
"FP8": QuantAlgo.FP8,
|
||||
"FP8_BLOCK_SCALES": QuantAlgo.FP8_BLOCK_SCALES,
|
||||
"NVFP4": QuantAlgo.NVFP4,
|
||||
"W4A16_AWQ": QuantAlgo.W4A16_AWQ,
|
||||
"W4A8_AWQ": QuantAlgo.W4A8_AWQ,
|
||||
"W8A8_SQ_PER_CHANNEL": QuantAlgo.W8A8_SQ_PER_CHANNEL,
|
||||
}
|
||||
quant_algo = algo_map.get(quant_algo_str)
|
||||
if quant_algo is None:
|
||||
raise ValueError(f"Unknown quant_algo: {quant_algo_str}")
|
||||
|
||||
# Parse group_size and dynamic flags from config_groups
|
||||
group_size = None
|
||||
dynamic_weight_quant = False
|
||||
dynamic_activation_quant = False
|
||||
for group_config in quant_config_dict.get("config_groups", {}).values():
|
||||
weights_config = group_config.get("weights", {})
|
||||
activations_config = group_config.get("input_activations", {})
|
||||
dynamic_weight_quant = weights_config.get("dynamic", False)
|
||||
dynamic_activation_quant = activations_config.get("dynamic", False)
|
||||
# Extract group_size from weights config (e.g., NVFP4: group_size=16)
|
||||
if group_size is None:
|
||||
group_size = weights_config.get("group_size")
|
||||
break
|
||||
|
||||
# Set defaults based on quant_algo if group_size not specified
|
||||
if group_size is None:
|
||||
if quant_algo in (QuantAlgo.NVFP4,):
|
||||
group_size = 16 # NVFP4 default
|
||||
elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
|
||||
group_size = 128 # FP8 blockwise default
|
||||
|
||||
# Auto-enable dynamic weight quantization if quant_algo is specified
|
||||
# but no explicit config_groups setting is present.
|
||||
# This allows simple configs like {"quant_algo": "FP8"} to work.
|
||||
if quant_algo is not None and not quant_config_dict.get("config_groups"):
|
||||
dynamic_weight_quant = quant_config_dict.get("dynamic", True)
|
||||
|
||||
quant_config = QuantConfig(
|
||||
quant_algo=quant_algo,
|
||||
group_size=group_size,
|
||||
exclude_modules=quant_config_dict.get("ignore"),
|
||||
)
|
||||
|
||||
# TODO: Per-layer config (None for now - future: parse mixed precision settings)
|
||||
layer_quant_config = None
|
||||
|
||||
return quant_config, layer_quant_config, dynamic_weight_quant, dynamic_activation_quant
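    # Parsing sketch (illustrative only): a minimal ModelOpt-style config without
    # config_groups enables dynamic weight quantization by default.
    #
    #   qc, layer_qc, dyn_w, dyn_a = DiffusionModelConfig.load_diffusion_quant_config(
    #       {"quant_algo": "FP8_BLOCK_SCALES"}
    #   )
    #   # qc.quant_algo == QuantAlgo.FP8_BLOCK_SCALES, qc.group_size == 128,
    #   # layer_qc is None, dyn_w is True, dyn_a is False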
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
checkpoint_dir: str,
|
||||
args: Optional["DiffusionArgs"] = None,
|
||||
**kwargs,
|
||||
) -> "DiffusionModelConfig":
|
||||
"""
|
||||
Load config from pretrained checkpoint.
|
||||
|
||||
Called by PipelineLoader with DiffusionArgs:
|
||||
config = DiffusionModelConfig.from_pretrained(
|
||||
checkpoint_dir=args.checkpoint_path,
|
||||
args=args,
|
||||
)
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to checkpoint
|
||||
args: DiffusionArgs containing user config (quant, pipeline, attention, parallel, teacache)
|
||||
**kwargs: Additional config options (e.g., mapping)
|
||||
"""
|
||||
kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Extract sub-configs from args or use defaults
|
||||
pipeline_cfg = args.pipeline if args else PipelineConfig()
|
||||
attention_cfg = args.attention if args else AttentionConfig()
|
||||
parallel_cfg = args.parallel if args else ParallelConfig()
|
||||
teacache_cfg = args.teacache if args else TeaCacheConfig()
|
||||
|
||||
component = PipelineComponent.TRANSFORMER
|
||||
checkpoint_path = Path(checkpoint_dir)
|
||||
|
||||
# Discover pipeline components
|
||||
components = discover_pipeline_components(checkpoint_path)
|
||||
|
||||
# Determine config path
|
||||
if components:
|
||||
if component not in components:
|
||||
raise ValueError(
|
||||
f"Component '{component}' not found. Available: {list(components.keys())}"
|
||||
)
|
||||
config_path = components[component]
|
||||
else:
|
||||
config_path = checkpoint_path / "config.json"
|
||||
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config not found at {config_path}")
|
||||
|
||||
# Load pretrained_config from checkpoint
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
pretrained_config = SimpleNamespace(**config_dict)
|
||||
|
||||
model_index_path = checkpoint_path / "model_index.json"
|
||||
if model_index_path.exists():
|
||||
with open(model_index_path) as f:
|
||||
model_index = json.load(f)
|
||||
if "boundary_ratio" in model_index and "transformer_2" in model_index:
|
||||
transformer_2_spec = model_index.get("transformer_2")
|
||||
if transformer_2_spec and transformer_2_spec[0] is not None:
|
||||
pretrained_config.boundary_ratio = model_index["boundary_ratio"]
|
||||
|
||||
# Resolve quant config: use args if user set quant (QuantConfig from dict), else checkpoint
|
||||
if args and args.quant_config.quant_algo is not None:
|
||||
quant_config = args.quant_config
|
||||
quant_config_dict = (
|
||||
None # DiffusionArgs has no per-layer dict; only from checkpoint parse
|
||||
)
|
||||
dynamic_weight_quant = args.dynamic_weight_quant
|
||||
dynamic_activation_quant = args.force_dynamic_quantization
|
||||
else:
|
||||
quant_config = QuantConfig()
|
||||
quant_config_dict = None
|
||||
dynamic_weight_quant = False
|
||||
dynamic_activation_quant = False
|
||||
quant_dict = getattr(pretrained_config, "quantization_config", None)
|
||||
if isinstance(quant_dict, dict):
|
||||
quant_config, quant_config_dict, dynamic_weight_quant, dynamic_activation_quant = (
|
||||
cls.load_diffusion_quant_config(quant_dict)
|
||||
)
|
||||
|
||||
return cls(
|
||||
pretrained_config=pretrained_config,
|
||||
quant_config=quant_config,
|
||||
quant_config_dict=quant_config_dict,
|
||||
dynamic_weight_quant=dynamic_weight_quant,
|
||||
force_dynamic_quantization=dynamic_activation_quant,
|
||||
# Sub-configs from DiffusionArgs
|
||||
pipeline=pipeline_cfg,
|
||||
attention=attention_cfg,
|
||||
parallel=parallel_cfg,
|
||||
teacache=teacache_cfg,
|
||||
# Delay weight creation after apply_quant_config_exclude_modules() in __post_init__
|
||||
skip_create_weights_in_init=True,
|
||||
**kwargs,
|
||||
)
|
||||
246
tensorrt_llm/_torch/visual_gen/executor.py
Normal file
@ -0,0 +1,246 @@
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs
|
||||
from tensorrt_llm._torch.visual_gen.output import MediaOutput
|
||||
from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
|
||||
from tensorrt_llm.executor.ipc import ZeroMqQueue
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionRequest:
|
||||
"""Request for diffusion inference with explicit model-specific parameters."""
|
||||
|
||||
request_id: int
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
height: int = 720
|
||||
width: int = 1280
|
||||
num_inference_steps: int = 50
|
||||
guidance_scale: float = 5.0
|
||||
max_sequence_length: int = 512
|
||||
seed: int = 42
|
||||
|
||||
# Video-specific parameters
|
||||
num_frames: int = 81
|
||||
frame_rate: float = 24.0
|
||||
|
||||
# Image-specific parameters
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
# Advanced parameters
|
||||
guidance_rescale: float = 0.0
|
||||
output_type: str = "pt"
|
||||
|
||||
# Wan-specific parameters
|
||||
image: Optional[Union[str, List[str]]] = None
|
||||
guidance_scale_2: Optional[float] = None
|
||||
boundary_ratio: Optional[float] = None
|
||||
last_image: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionResponse:
|
||||
"""Response with model-specific output.
|
||||
|
||||
Attributes:
|
||||
request_id: Unique identifier for the request
|
||||
output: Generated media as MediaOutput with model-specific fields populated
|
||||
error_msg: Error message if generation failed
|
||||
"""
|
||||
|
||||
request_id: int
|
||||
output: Optional[MediaOutput] = None
|
||||
error_msg: Optional[str] = None
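    # Request sketch (illustrative only): a minimal text-to-video request relying on
    # the dataclass defaults for resolution, frame count, and guidance.
    #
    #   req = DiffusionRequest(
    #       request_id=0,
    #       prompt="A red fox running across a snowy meadow at sunrise",
    #       num_inference_steps=40,
    #       seed=1234,
    #   )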
|
||||
|
||||
|
||||
class DiffusionExecutor:
|
||||
"""Execution engine for diffusion models running in worker processes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
request_queue_addr: str,
|
||||
response_queue_addr: str,
|
||||
device_id: int,
|
||||
diffusion_config: Optional[dict] = None,
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.request_queue_addr = request_queue_addr
|
||||
self.response_queue_addr = response_queue_addr
|
||||
self.device_id = device_id
|
||||
self.diffusion_config = diffusion_config
|
||||
|
||||
self.requests_ipc = None
|
||||
self.rank = dist.get_rank()
|
||||
self.response_queue = queue.Queue()
|
||||
self.sender_thread = None
|
||||
|
||||
# Only rank 0 handles IPC
|
||||
if self.rank == 0:
|
||||
logger.info(f"Worker {device_id}: Connecting to request queue")
|
||||
self.requests_ipc = ZeroMqQueue(
|
||||
(request_queue_addr, None),
|
||||
is_server=False,
|
||||
socket_type=zmq.PULL,
|
||||
use_hmac_encryption=False,
|
||||
)
|
||||
self.sender_thread = threading.Thread(target=self._sender_loop, daemon=True)
|
||||
self.sender_thread.start()
|
||||
|
||||
self._load_pipeline()
|
||||
|
||||
def _sender_loop(self):
|
||||
"""Background thread for sending responses."""
|
||||
logger.info(f"Worker {self.device_id}: Connecting to response queue")
|
||||
responses_ipc = ZeroMqQueue(
|
||||
(self.response_queue_addr, None),
|
||||
is_server=False,
|
||||
socket_type=zmq.PUSH,
|
||||
use_hmac_encryption=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
resp = self.response_queue.get()
|
||||
if resp is None:
|
||||
break
|
||||
responses_ipc.put(resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Worker {self.device_id}: Sender error: {e}")
|
||||
|
||||
if responses_ipc.socket:
|
||||
responses_ipc.socket.setsockopt(zmq.LINGER, 0)
|
||||
responses_ipc.close()
|
||||
|
||||
def _load_pipeline(self):
|
||||
"""
|
||||
Load pipeline using proper flow:
|
||||
DiffusionArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline
|
||||
"""
|
||||
logger.info(f"Worker {self.device_id}: Loading pipeline")
|
||||
|
||||
try:
|
||||
# Convert diffusion_config dict to DiffusionArgs
|
||||
config_dict = self.diffusion_config.copy()
|
||||
config_dict["checkpoint_path"] = self.model_path
|
||||
config_dict["device"] = f"cuda:{self.device_id}"
|
||||
|
||||
# Create DiffusionArgs from dict (handles nested configs)
|
||||
args = DiffusionArgs.from_dict(config_dict)
|
||||
|
||||
# Use PipelineLoader for proper pipeline creation flow:
|
||||
# PipelineLoader.load() internally:
|
||||
# 1. Creates DiffusionModelConfig.from_pretrained()
|
||||
# 2. Creates pipeline via AutoPipeline.from_config()
|
||||
# 3. Loads weights with quantization support
|
||||
# 4. Calls post_load_weights()
|
||||
loader = PipelineLoader(args)
|
||||
self.pipeline = loader.load()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker {self.device_id}: Failed to load pipeline: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Worker {self.device_id}: Pipeline ready")
|
||||
|
||||
# Sync all workers
|
||||
dist.barrier()
|
||||
|
||||
# Send READY signal
|
||||
if self.rank == 0:
|
||||
logger.info(f"Worker {self.device_id}: Sending READY")
|
||||
self.response_queue.put(DiffusionResponse(request_id=-1, output="READY"))
|
||||
|
||||
def serve_forever(self):
|
||||
"""Main execution loop."""
|
||||
while True:
|
||||
req = None
|
||||
if self.rank == 0:
|
||||
req = self.requests_ipc.get()
|
||||
logger.info(f"Worker {self.device_id}: Request available")
|
||||
|
||||
# Broadcast to all ranks
|
||||
obj_list = [req]
|
||||
dist.broadcast_object_list(obj_list, src=0)
|
||||
req = obj_list[0]
|
||||
|
||||
if req is None:
|
||||
logger.info(f"Worker {self.device_id}: Shutdown signal received")
|
||||
if self.rank == 0 and self.sender_thread:
|
||||
self.response_queue.put(None)
|
||||
self.sender_thread.join()
|
||||
break
|
||||
|
||||
logger.info(f"Worker {self.device_id}: Processing request {req.request_id}")
|
||||
self.process_request(req)
|
||||
|
||||
def process_request(self, req: DiffusionRequest):
|
||||
"""Process a single request."""
|
||||
try:
|
||||
output = self.pipeline.infer(req)
|
||||
if self.rank == 0:
|
||||
self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
|
||||
except Exception as e:
|
||||
logger.error(f"Worker {self.device_id}: Error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
if self.rank == 0:
|
||||
self.response_queue.put(
|
||||
DiffusionResponse(request_id=req.request_id, error_msg=str(e))
|
||||
)
|
||||
|
||||
|
||||
def run_diffusion_worker(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
master_addr: str,
|
||||
master_port: int,
|
||||
model_path: str,
|
||||
request_queue_addr: str,
|
||||
response_queue_addr: str,
|
||||
diffusion_config: Optional[dict] = None,
|
||||
):
|
||||
"""Entry point for worker process."""
|
||||
try:
|
||||
# Setup distributed env — use PyTorch distributed, not MPI
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Calculate device_id before init_process_group
|
||||
device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(device_id)
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
init_method="env://",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
device_id=torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else None,
|
||||
)
|
||||
|
||||
executor = DiffusionExecutor(
|
||||
model_path=model_path,
|
||||
request_queue_addr=request_queue_addr,
|
||||
response_queue_addr=response_queue_addr,
|
||||
device_id=device_id,
|
||||
diffusion_config=diffusion_config,
|
||||
)
|
||||
executor.serve_forever()
|
||||
dist.destroy_process_group()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker failed: {e}")
|
||||
traceback.print_exc()
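# Launch sketch (illustrative only; the orchestrator that creates the ZeroMQ queues
# and spawns the workers lives outside this file):
#
#   import torch.multiprocessing as mp
#   mp.spawn(
#       run_diffusion_worker,
#       args=(world_size, "127.0.0.1", 29500, model_path, req_addr, resp_addr, cfg),
#       nprocs=world_size,
#   )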
|
||||
30
tensorrt_llm/_torch/visual_gen/models/__init__.py
Normal file
@ -0,0 +1,30 @@
"""
Visual generation model pipelines.

Each model subdirectory contains:
- pipeline_*.py: Main pipeline implementation inheriting from BasePipeline
- __init__.py: Exports the pipeline class

TeaCache extractors are registered inline in each pipeline's load() method
using register_extractor_from_config().

Pipelines are registered in pipeline_registry.py's PipelineRegistry._REGISTRY dict.

Example structure:
    models/
        my_model/
            pipeline_my_model.py  # Pipeline class with inline extractor registration
            __init__.py           # Exports: __all__ = ["MyModelPipeline"]
"""

from ..pipeline import BasePipeline
from ..pipeline_registry import AutoPipeline, register_pipeline
from .wan import WanImageToVideoPipeline, WanPipeline

__all__ = [
    "AutoPipeline",
    "BasePipeline",
    "WanPipeline",
    "WanImageToVideoPipeline",
    "register_pipeline",
]
5
tensorrt_llm/_torch/visual_gen/models/wan/__init__.py
Normal file
@ -0,0 +1,5 @@
from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .transformer_wan import WanTransformer3DModel

__all__ = ["WanPipeline", "WanImageToVideoPipeline", "WanTransformer3DModel"]
521
tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
Normal file
@ -0,0 +1,521 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
|
||||
from tensorrt_llm._torch.visual_gen.output import MediaOutput
|
||||
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
|
||||
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
|
||||
from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
|
||||
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
|
||||
# Supported Wan T2V models:
|
||||
# - Wan2.1-T2V-14B: Single-stage text-to-video (14B parameters)
|
||||
# - Wan2.1-T2V-1.3B: Single-stage text-to-video (1.3B parameters)
|
||||
# - Wan2.2-T2V-A14B: Two-stage text-to-video (14B, boundary_ratio for high/low-noise stages; supports 480P & 720P)
|
||||
|
||||
WAN_TEACACHE_COEFFICIENTS = {
|
||||
"1.3B": {
|
||||
"ret_steps": [
|
||||
-5.21862437e04,
|
||||
9.23041404e03,
|
||||
-5.28275948e02,
|
||||
1.36987616e01,
|
||||
-4.99875664e-02,
|
||||
],
|
||||
"standard": [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01],
|
||||
},
|
||||
"14B": {
|
||||
"ret_steps": [
|
||||
-3.03318725e05,
|
||||
4.90537029e04,
|
||||
-2.65530556e03,
|
||||
5.87365115e01,
|
||||
-3.15583525e-01,
|
||||
],
|
||||
"standard": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Default negative prompt for Wan T2V models
|
||||
WAN_DEFAULT_NEGATIVE_PROMPT = (
|
||||
"Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
|
||||
"still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
|
||||
"incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
|
||||
"fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
|
||||
)
|
||||
|
||||
|
||||
@register_pipeline("WanPipeline")
|
||||
class WanPipeline(BasePipeline):
|
||||
def __init__(self, model_config):
|
||||
# Wan2.2 two-stage denoising parameters
|
||||
self.transformer_2 = None
|
||||
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
|
||||
self.is_wan22 = self.boundary_ratio is not None
|
||||
|
||||
super().__init__(model_config)
|
||||
|
||||
@staticmethod
|
||||
def _compute_wan_timestep_embedding(module, timestep, guidance=None):
|
||||
"""Compute timestep embedding for WAN transformer.
|
||||
|
||||
WAN uses a condition_embedder with timesteps_proj and time_embedder layers.
|
||||
Handles dtype casting to match the embedder's dtype.
|
||||
|
||||
Args:
|
||||
module: WanTransformer3DModel instance
|
||||
timestep: Timestep tensor (shape: [batch_size])
|
||||
guidance: Unused for WAN (no guidance embedding)
|
||||
|
||||
Returns:
|
||||
Timestep embedding tensor used by TeaCache for distance calculation
|
||||
"""
|
||||
ce = module.condition_embedder
|
||||
t_freq = ce.timesteps_proj(timestep)
|
||||
|
||||
# Cast to embedder's dtype (avoid int8 quantized layers)
|
||||
te_dtype = next(iter(ce.time_embedder.parameters())).dtype
|
||||
if t_freq.dtype != te_dtype and te_dtype != torch.int8:
|
||||
t_freq = t_freq.to(te_dtype)
|
||||
|
||||
return ce.time_embedder(t_freq)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.model_config.torch_dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
@property
|
||||
def transformer_components(self) -> list:
|
||||
"""Return list of transformer components this pipeline needs."""
|
||||
if self.transformer_2 is not None:
|
||||
return ["transformer", "transformer_2"]
|
||||
return ["transformer"]
|
||||
|
||||
def _init_transformer(self) -> None:
|
||||
logger.info("Creating WAN transformer with quantization support...")
|
||||
self.transformer = WanTransformer3DModel(model_config=self.model_config)
|
||||
|
||||
# Wan2.2: create second transformer for two-stage denoising
|
||||
if self.boundary_ratio is not None:
|
||||
logger.info("Creating second transformer for Wan2.2 two-stage denoising...")
|
||||
self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
|
||||
|
||||
def load_standard_components(
|
||||
self,
|
||||
checkpoint_dir: str,
|
||||
device: torch.device,
|
||||
skip_components: Optional[list] = None,
|
||||
) -> None:
|
||||
"""Load VAE, text encoder, tokenizer, and scheduler from checkpoint."""
|
||||
skip_components = skip_components or []
|
||||
|
||||
if self.transformer_2 is not None and self.boundary_ratio is None:
|
||||
raise RuntimeError(
|
||||
"transformer_2 exists but boundary_ratio is not set. "
|
||||
"This indicates an inconsistent pipeline configuration."
|
||||
)
|
||||
|
||||
# Detect model version
|
||||
if self.is_wan22:
|
||||
logger.info("Detected Wan 2.2 T2V (two-stage denoising)")
|
||||
else:
|
||||
logger.info("Detected Wan 2.1 T2V (single-stage denoising)")
|
||||
|
||||
# Set default VAE scale factors (will be overridden if VAE is loaded)
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
if PipelineComponent.TOKENIZER not in skip_components:
|
||||
logger.info("Loading tokenizer...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.TOKENIZER,
|
||||
)
|
||||
|
||||
if PipelineComponent.TEXT_ENCODER not in skip_components:
|
||||
logger.info("Loading text encoder...")
|
||||
self.text_encoder = UMT5EncoderModel.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.TEXT_ENCODER,
|
||||
torch_dtype=self.model_config.torch_dtype,
|
||||
).to(device)
|
||||
|
||||
if PipelineComponent.VAE not in skip_components:
|
||||
logger.info("Loading VAE...")
|
||||
self.vae = AutoencoderKLWan.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.VAE,
|
||||
torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
|
||||
).to(device)
|
||||
|
||||
self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
|
||||
self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
|
||||
|
||||
if PipelineComponent.SCHEDULER not in skip_components:
|
||||
logger.info("Loading scheduler...")
|
||||
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.SCHEDULER,
|
||||
)
|
||||
if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
|
||||
self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
|
||||
self.scheduler.config,
|
||||
shift=5.0,
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def load_weights(self, weights: dict) -> None:
|
||||
# Store weights for later use (in case transformer_2 is created after this call)
|
||||
self._weights_dict = weights
|
||||
|
||||
has_separate_weights = "transformer" in weights and "transformer_2" in weights
|
||||
|
||||
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
|
||||
logger.info("Loading transformer weights...")
|
||||
transformer_weights = weights.get("transformer", weights)
|
||||
self.transformer.load_weights(transformer_weights)
|
||||
logger.info("Transformer weights loaded successfully.")
|
||||
|
||||
# Wan2.2: Load weights for second transformer if it exists
|
||||
if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
|
||||
logger.info("Loading transformer_2 weights for Wan2.2...")
|
||||
if not has_separate_weights:
|
||||
raise ValueError(
|
||||
"Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
|
||||
f"but only found: {list(weights.keys())}. "
|
||||
"Two-stage denoising requires distinct weights for high-noise and low-noise transformers."
|
||||
)
|
||||
transformer_2_weights = weights["transformer_2"]
|
||||
self.transformer_2.load_weights(transformer_2_weights)
|
||||
logger.info("Transformer_2 weights loaded successfully.")
|
||||
|
||||
# Cache the target dtype from model config (default: bfloat16)
|
||||
self._target_dtype = self.model_config.torch_dtype
|
||||
|
||||
# Set model to eval mode
|
||||
if self.transformer is not None:
|
||||
self.transformer.eval()
|
||||
if self.transformer_2 is not None:
|
||||
self.transformer_2.eval()
|
||||
|
||||
def post_load_weights(self) -> None:
|
||||
super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
|
||||
if self.transformer is not None:
|
||||
# Register TeaCache extractor for this model type
|
||||
# Tells TeaCache how to compute timestep embeddings for Wan
|
||||
register_extractor_from_config(
|
||||
ExtractorConfig(
|
||||
model_class_name="WanTransformer3DModel",
|
||||
timestep_embed_fn=self._compute_wan_timestep_embedding,
|
||||
return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
|
||||
)
|
||||
)
|
||||
|
||||
# Enable TeaCache optimization with WAN-specific coefficients
|
||||
self._setup_teacache(self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS)
|
||||
# Save transformer backend before it gets overwritten
|
||||
self.transformer_cache_backend = self.cache_backend
|
||||
|
||||
# Wan2.2: Setup TeaCache for second transformer (low-noise stage)
|
||||
if self.transformer_2 is not None:
|
||||
if hasattr(self.transformer_2, "post_load_weights"):
|
||||
self.transformer_2.post_load_weights()
|
||||
|
||||
# Enable TeaCache for low-noise stage with same coefficients
|
||||
self._setup_teacache(self.transformer_2, coefficients=WAN_TEACACHE_COEFFICIENTS)
|
||||
# Save transformer_2 backend
|
||||
self.transformer_2_cache_backend = self.cache_backend
|
||||
|
||||
def infer(self, req):
|
||||
"""Run inference with request parameters."""
|
||||
return self.forward(
|
||||
prompt=req.prompt,
|
||||
negative_prompt=req.negative_prompt,
|
||||
height=req.height,
|
||||
width=req.width,
|
||||
num_frames=req.num_frames,
|
||||
num_inference_steps=req.num_inference_steps,
|
||||
guidance_scale=req.guidance_scale,
|
||||
guidance_scale_2=req.guidance_scale_2,
|
||||
boundary_ratio=req.boundary_ratio,
|
||||
seed=req.seed,
|
||||
max_sequence_length=req.max_sequence_length,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
seed: int = 42,
|
||||
max_sequence_length: int = 226,
|
||||
):
|
||||
pipeline_start = time.time()
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed)
|
||||
|
||||
# Use user-provided boundary_ratio if given, otherwise fall back to model config
|
||||
boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
|
||||
|
||||
# Validate that Wan 2.2 models have boundary_ratio set
|
||||
if self.transformer_2 is not None and boundary_ratio is None:
|
||||
raise ValueError(
|
||||
"Wan 2.2 models require boundary_ratio to be set. "
|
||||
"boundary_ratio was not found in model config. "
|
||||
"Please pass boundary_ratio as a parameter."
|
||||
)
|
||||
|
||||
# Set default negative prompt if not provided
|
||||
if negative_prompt is None:
|
||||
negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
|
||||
|
||||
# Set model-specific defaults based on Wan version
|
||||
logger.info(
|
||||
f"Running {'Wan 2.2' if self.is_wan22 else 'Wan 2.1'} T2V inference"
|
||||
f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
|
||||
)
|
||||
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 40 if self.is_wan22 else 50
|
||||
|
||||
if guidance_scale is None:
|
||||
guidance_scale = 4.0 if self.is_wan22 else 5.0
|
||||
|
||||
if self.is_wan22 and guidance_scale_2 is None:
|
||||
guidance_scale_2 = 3.0
|
||||
|
||||
# Validate two-stage denoising configuration
|
||||
if guidance_scale_2 is not None and boundary_ratio is None:
|
||||
logger.warning(
|
||||
"guidance_scale_2 is specified but boundary_ratio is None. "
|
||||
"guidance_scale_2 will be ignored."
|
||||
"Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
|
||||
)
|
||||
guidance_scale_2 = None
|
||||
|
||||
# Encode Prompt
|
||||
logger.info("Encoding prompts...")
|
||||
encode_start = time.time()
|
||||
prompt_embeds, neg_prompt_embeds = self._encode_prompt(
|
||||
prompt, negative_prompt, max_sequence_length
|
||||
)
|
||||
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
|
||||
|
||||
# Prepare Latents
|
||||
latents = self._prepare_latents(height, width, num_frames, generator)
|
||||
logger.info(f"Latents shape: {latents.shape}")
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
|
||||
# Wan2.2: Calculate boundary timestep for two-stage denoising
|
||||
boundary_timestep = None
|
||||
if boundary_ratio is not None and self.transformer_2 is not None:
|
||||
boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
logger.info(
|
||||
f"Wan2.2 two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
|
||||
f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
|
||||
)
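# Worked example (illustrative numbers): with boundary_ratio = 0.875 and
# num_train_timesteps = 1000, boundary_timestep = 875. Denoising steps with
# timestep >= 875 (the early, high-noise steps) run `transformer`; steps below
# 875 run `transformer_2`. The real values come from the checkpoint config.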
|
||||
|
||||
# Denoising with two-stage support
|
||||
# Track which model was used in last step (for logging model transitions)
|
||||
last_model_used = [None]
|
||||
|
||||
def forward_fn(
|
||||
latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
|
||||
):
|
||||
"""Forward function for Wan transformer with two-stage support.
|
||||
|
||||
extra_stream_latents and extra_tensors are unused for WAN (single stream, no additional embeddings).
|
||||
"""
|
||||
# Select model based on timestep (if two-stage denoising is enabled)
|
||||
if boundary_timestep is not None and self.transformer_2 is not None:
|
||||
# Extract scalar timestep for comparison
|
||||
current_t = timestep if timestep.dim() == 0 else timestep[0]
|
||||
if current_t >= boundary_timestep:
|
||||
current_model = self.transformer
|
||||
model_name = "transformer (high-noise)"
|
||||
else:
|
||||
current_model = self.transformer_2
|
||||
model_name = "transformer_2 (low-noise)"
|
||||
|
||||
# Log when switching models
|
||||
if last_model_used[0] != model_name:
|
||||
if self.rank == 0:
|
||||
logger.info(f"Switched to {model_name} at timestep {current_t:.1f}")
|
||||
last_model_used[0] = model_name
|
||||
else:
|
||||
current_model = self.transformer
|
||||
|
||||
return current_model(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
|
||||
latents = self.denoise(
|
||||
latents=latents,
|
||||
scheduler=self.scheduler,
|
||||
prompt_embeds=prompt_embeds,
|
||||
neg_prompt_embeds=neg_prompt_embeds,
|
||||
guidance_scale=guidance_scale,
|
||||
forward_fn=forward_fn,
|
||||
guidance_scale_2=guidance_scale_2,
|
||||
boundary_timestep=boundary_timestep,
|
||||
)
|
||||
|
||||
# Log TeaCache statistics - show stats for each transformer separately
|
||||
if self.rank == 0 and self.model_config.teacache.enable_teacache:
|
||||
logger.info("=" * 80)
|
||||
logger.info("TeaCache Statistics:")
|
||||
|
||||
# Stats for transformer (high-noise)
|
||||
if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
|
||||
stats = self.transformer_cache_backend.get_stats()
|
||||
total_steps = stats.get("total_steps", 0)
|
||||
cache_hits = stats.get("cached_steps", 0)
|
||||
cache_misses = stats.get("compute_steps", 0)
|
||||
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
|
||||
|
||||
logger.info(" Transformer (High-Noise):")
|
||||
logger.info(f" Total steps: {total_steps}")
|
||||
logger.info(f" Cache hits: {cache_hits}")
|
||||
logger.info(f" Cache misses: {cache_misses}")
|
||||
logger.info(f" Hit rate: {hit_rate:.1f}%")
|
||||
|
||||
# Stats for transformer_2 (low-noise)
|
||||
if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
|
||||
stats = self.transformer_2_cache_backend.get_stats()
|
||||
total_steps = stats.get("total_steps", 0)
|
||||
cache_hits = stats.get("cached_steps", 0)
|
||||
cache_misses = stats.get("compute_steps", 0)
|
||||
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
|
||||
|
||||
logger.info(" Transformer_2 (Low-Noise):")
|
||||
logger.info(f" Total steps: {total_steps}")
|
||||
logger.info(f" Cache hits: {cache_hits}")
|
||||
logger.info(f" Cache misses: {cache_misses}")
|
||||
logger.info(f" Hit rate: {hit_rate:.1f}%")
|
||||
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Decode
|
||||
logger.info("Decoding video...")
|
||||
decode_start = time.time()
|
||||
video = self.decode_latents(latents, self._decode_latents)
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
|
||||
logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
|
||||
|
||||
return MediaOutput(video=video)
|
||||
|
||||
def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
def get_embeds(texts):
|
||||
text_inputs = self.tokenizer(
|
||||
texts,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = text_inputs.input_ids.to(self.device)
|
||||
attention_mask = text_inputs.attention_mask.to(self.device)
|
||||
|
||||
embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
|
||||
embeds = embeds.to(self.dtype)
|
||||
|
||||
# Zero-out padded tokens based on mask
|
||||
seq_lens = attention_mask.gt(0).sum(dim=1).long()
|
||||
cleaned_embeds = []
|
||||
for u, v in zip(embeds, seq_lens):
|
||||
real_content = u[:v]
|
||||
pad_len = max_sequence_length - real_content.size(0)
|
||||
if pad_len > 0:
|
||||
padded = torch.cat(
|
||||
[real_content, real_content.new_zeros(pad_len, real_content.size(1))]
|
||||
)
|
||||
else:
|
||||
padded = real_content
|
||||
cleaned_embeds.append(padded)
|
||||
|
||||
return torch.stack(cleaned_embeds, dim=0)
|
||||
|
||||
prompt_embeds = get_embeds(prompt)
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if len(neg_prompt) == 1 and len(prompt) > 1:
|
||||
neg_prompt = neg_prompt * len(prompt)
|
||||
|
||||
neg_embeds = get_embeds(neg_prompt)
|
||||
|
||||
return prompt_embeds, neg_embeds
|
||||
|
||||
def _prepare_latents(self, height, width, num_frames, generator):
|
||||
num_channels_latents = 16
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
shape = (
|
||||
1,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
return randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
|
||||
|
||||
def _decode_latents(self, latents):
|
||||
latents = latents.to(self.vae.dtype)
|
||||
|
||||
# Denormalization
|
||||
if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
|
||||
if not hasattr(self, "_latents_mean"):
|
||||
self._latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(self.device, self.vae.dtype)
|
||||
)
|
||||
self._latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(self.device, self.vae.dtype)
|
||||
)
|
||||
latents = (latents * self._latents_std) + self._latents_mean
|
||||
else:
|
||||
scaling_factor = self.vae.config.get("scaling_factor", 1.0)
|
||||
latents = latents / scaling_factor
|
||||
|
||||
# VAE decode: returns (B, C, T, H, W)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
|
||||
video = postprocess_video_tensor(video, remove_batch_dim=True)
|
||||
|
||||
return video
|
||||
736
tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
Normal file
@ -0,0 +1,736 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
|
||||
from tensorrt_llm._torch.visual_gen.output import MediaOutput
|
||||
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
|
||||
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
|
||||
from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
|
||||
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# Supported Wan I2V 14B models:
|
||||
# - Wan2.1-I2V-14B-480P: Single-stage image-to-video
|
||||
# - Wan2.1-I2V-14B-720P: Single-stage image-to-video
|
||||
# - Wan2.2-I2V-14B: Two-stage image-to-video (no CLIP, boundary_ratio for two-stage denoising)
|
||||
# Note: Wan2.2-I2V-5B (expand_timesteps mode) is NOT supported by this pipeline
|
||||
# Import shared coefficients from T2V pipeline
|
||||
from .pipeline_wan import WAN_TEACACHE_COEFFICIENTS
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
|
||||
# Use same coefficients
|
||||
WAN_I2V_TEACACHE_COEFFICIENTS = WAN_TEACACHE_COEFFICIENTS
|
||||
|
||||
# Default negative prompt for Wan I2V models
|
||||
WAN_DEFAULT_NEGATIVE_PROMPT = (
|
||||
"Vibrant colors, overexposed, static, blurry details, subtitles, style, artwork, painting, image, "
|
||||
"still image, overall grayish tone, worst quality, low quality, JPEG compression artifacts, ugly, "
|
||||
"incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, "
|
||||
"fused fingers, motionless image, cluttered background, three legs, many people in the background, walking backward"
|
||||
)
|
||||
|
||||
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
sample_mode: str = "argmax",
|
||||
):
|
||||
"""Extract latents from VAE encoder output.
|
||||
|
||||
For I2V, we use argmax mode to get deterministic encoding of the input image.
|
||||
"""
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
@register_pipeline("WanImageToVideoPipeline")
|
||||
class WanImageToVideoPipeline(BasePipeline):
|
||||
def __init__(self, model_config):
|
||||
# Wan2.2 14B two-stage denoising parameters
|
||||
self.transformer_2 = None
|
||||
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
|
||||
self.is_wan22 = self.boundary_ratio is not None
|
||||
|
||||
super().__init__(model_config)
|
||||
|
||||
@staticmethod
|
||||
def _compute_wan_timestep_embedding(module, timestep, guidance=None):
|
||||
"""Compute timestep embedding for Wan I2V transformer."""
|
||||
ce = module.condition_embedder
|
||||
t_freq = ce.timesteps_proj(timestep)
|
||||
|
||||
# Cast to embedder's dtype (avoid int8 quantized layers)
|
||||
te_dtype = next(iter(ce.time_embedder.parameters())).dtype
|
||||
if t_freq.dtype != te_dtype and te_dtype != torch.int8:
|
||||
t_freq = t_freq.to(te_dtype)
|
||||
|
||||
return ce.time_embedder(t_freq)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.model_config.torch_dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
@property
|
||||
def transformer_components(self) -> list:
|
||||
if self.transformer_2 is not None:
|
||||
return ["transformer", "transformer_2"]
|
||||
return ["transformer"]
|
||||
|
||||
def _init_transformer(self) -> None:
|
||||
logger.info("Creating WAN I2V transformer with quantization support...")
|
||||
self.transformer = WanTransformer3DModel(model_config=self.model_config)
|
||||
|
||||
# Wan2.2: Optionally create second transformer for two-stage denoising
|
||||
if self.boundary_ratio is not None:
|
||||
logger.info("Creating second transformer for Wan2.2 I2V two-stage denoising...")
|
||||
self.transformer_2 = WanTransformer3DModel(model_config=self.model_config)
|
||||
|
||||
def load_standard_components(
|
||||
self,
|
||||
checkpoint_dir: str,
|
||||
device: torch.device,
|
||||
skip_components: Optional[list] = None,
|
||||
) -> None:
|
||||
"""Load VAE, text encoder, tokenizer, scheduler, and I2V-specific components from checkpoint."""
|
||||
skip_components = skip_components or []
|
||||
|
||||
# Load boundary_ratio and transformer_2 info from model_index.json (pipeline-level config)
|
||||
# Wan 2.2 has both transformer_2 and boundary_ratio, Wan 2.1 doesn't
|
||||
model_index_path = os.path.join(checkpoint_dir, "model_index.json")
|
||||
has_transformer_2 = False
|
||||
if os.path.exists(model_index_path):
|
||||
with open(model_index_path) as f:
|
||||
model_index = json.load(f)
|
||||
# Check for boundary_ratio in model_index
|
||||
if "boundary_ratio" in model_index:
|
||||
self.boundary_ratio = model_index["boundary_ratio"]
|
||||
logger.info(f"Found boundary_ratio in model_index.json: {self.boundary_ratio}")
|
||||
else:
|
||||
logger.info("No boundary_ratio found in model_index.json")
|
||||
# Check for transformer_2 component
|
||||
transformer_2_spec = model_index.get("transformer_2", None)
|
||||
has_transformer_2 = (
|
||||
transformer_2_spec is not None and transformer_2_spec[0] is not None
|
||||
)
|
||||
logger.info(f"transformer_2 in model_index.json: {has_transformer_2}")
|
||||
|
||||
# Set default VAE scale factors (will be overridden if VAE is loaded)
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
if PipelineComponent.TOKENIZER not in skip_components:
|
||||
logger.info("Loading tokenizer...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.TOKENIZER,
|
||||
)
|
||||
|
||||
if PipelineComponent.TEXT_ENCODER not in skip_components:
|
||||
logger.info("Loading text encoder...")
|
||||
self.text_encoder = UMT5EncoderModel.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.TEXT_ENCODER,
|
||||
torch_dtype=self.model_config.torch_dtype,
|
||||
).to(device)
|
||||
|
||||
if PipelineComponent.VAE not in skip_components:
|
||||
logger.info("Loading VAE...")
|
||||
self.vae = AutoencoderKLWan.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.VAE,
|
||||
torch_dtype=torch.bfloat16, # load VAE in BF16 for memory saving
|
||||
).to(device)
|
||||
|
||||
self.vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 4)
|
||||
self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 8)
|
||||
|
||||
if PipelineComponent.SCHEDULER not in skip_components:
|
||||
logger.info("Loading scheduler...")
|
||||
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.SCHEDULER,
|
||||
)
|
||||
if not hasattr(self.scheduler.config, "shift") or self.scheduler.config.shift == 1.0:
|
||||
self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
|
||||
self.scheduler.config,
|
||||
shift=5.0,
|
||||
)
|
||||
|
||||
if self.transformer_2 is not None and self.boundary_ratio is None:
|
||||
raise RuntimeError(
|
||||
"transformer_2 exists but boundary_ratio is not set. "
|
||||
"This indicates an inconsistent pipeline configuration."
|
||||
)
|
||||
|
||||
# Load image encoder and processor (only for Wan 2.1)
|
||||
# Wan 2.2: Has both transformer_2 and boundary_ratio (two-stage denoising)
|
||||
if self.is_wan22:
|
||||
logger.info("Detected Wan 2.2 I2V (two-stage, no CLIP)")
|
||||
else:
|
||||
logger.info("Detected Wan 2.1 I2V (single-stage, uses CLIP)")
|
||||
|
||||
if PipelineComponent.IMAGE_ENCODER not in skip_components and not self.is_wan22:
|
||||
logger.info("Loading CLIP image encoder for I2V conditioning (Wan 2.1 only)...")
|
||||
self.image_encoder = CLIPVisionModel.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.IMAGE_ENCODER,
|
||||
torch_dtype=torch.float32, # Keep CLIP in FP32 for stability
|
||||
).to(device)
|
||||
|
||||
if PipelineComponent.IMAGE_PROCESSOR not in skip_components and not self.is_wan22:
|
||||
logger.info("Loading CLIP image processor...")
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(
|
||||
checkpoint_dir,
|
||||
subfolder=PipelineComponent.IMAGE_PROCESSOR,
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def load_weights(self, weights: dict) -> None:
|
||||
# Store weights for later use
|
||||
self._weights_dict = weights
|
||||
|
||||
# Check if weights dict has separate transformer/transformer_2 keys (Wan2.2)
|
||||
has_separate_weights = "transformer" in weights and "transformer_2" in weights
|
||||
|
||||
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
|
||||
logger.info("Loading transformer weights...")
|
||||
transformer_weights = weights.get("transformer", weights)
|
||||
self.transformer.load_weights(transformer_weights)
|
||||
logger.info("Transformer weights loaded successfully.")
|
||||
|
||||
# Wan2.2: Load weights for second transformer if it exists
|
||||
if self.transformer_2 is not None and hasattr(self.transformer_2, "load_weights"):
|
||||
logger.info("Loading transformer_2 weights for Wan2.2 I2V...")
|
||||
if has_separate_weights:
|
||||
transformer_2_weights = weights["transformer_2"]
|
||||
logger.info("Using separate transformer_2 weights from checkpoint")
|
||||
else:
|
||||
# For Wan 2.2, transformer_2 weights must exist
|
||||
raise ValueError(
|
||||
"Wan2.2 model requires separate 'transformer' and 'transformer_2' weights in checkpoint, "
|
||||
f"but only found: {list(weights.keys())}"
|
||||
)
|
||||
self.transformer_2.load_weights(transformer_2_weights)
|
||||
logger.info("Transformer_2 weights loaded successfully.")
|
||||
|
||||
# Cache the target dtype from model config (default: bfloat16)
|
||||
self._target_dtype = self.model_config.torch_dtype
|
||||
|
||||
# Set model to eval mode
|
||||
if self.transformer is not None:
|
||||
self.transformer.eval()
|
||||
if self.transformer_2 is not None:
|
||||
self.transformer_2.eval()
|
||||
if hasattr(self, "image_encoder") and self.image_encoder is not None:
|
||||
self.image_encoder.eval()
|
||||
|
||||
def post_load_weights(self) -> None:
|
||||
super().post_load_weights() # Calls transformer.post_load_weights() for FP8 scale transformations
|
||||
if self.transformer is not None:
|
||||
# Register TeaCache extractor for this model type
|
||||
register_extractor_from_config(
|
||||
ExtractorConfig(
|
||||
model_class_name="WanTransformer3DModel",
|
||||
timestep_embed_fn=self._compute_wan_timestep_embedding,
|
||||
return_dict_default=False, # Wan returns raw tensors, not wrapped outputs
|
||||
)
|
||||
)
|
||||
|
||||
# Enable TeaCache optimization with Wan I2V-specific coefficients
|
||||
self._setup_teacache(self.transformer, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
|
||||
# Save transformer backend before it gets overwritten
|
||||
self.transformer_cache_backend = self.cache_backend
|
||||
|
||||
# Wan2.2: Setup TeaCache for second transformer (low-noise stage)
|
||||
if self.transformer_2 is not None:
|
||||
if hasattr(self.transformer_2, "post_load_weights"):
|
||||
self.transformer_2.post_load_weights()
|
||||
|
||||
# Enable TeaCache for low-noise stage with same coefficients
|
||||
self._setup_teacache(self.transformer_2, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS)
|
||||
# Save transformer_2 backend
|
||||
self.transformer_2_cache_backend = self.cache_backend
|
||||
|
||||
def infer(self, req):
|
||||
"""Run inference with request parameters."""
|
||||
# Extract image from request (can be path, PIL Image, or torch.Tensor)
|
||||
if req.image is None:
|
||||
raise ValueError("I2V pipeline requires 'image' parameter")
|
||||
|
||||
image = req.image[0] if isinstance(req.image, list) else req.image
|
||||
last_image = req.last_image
|
||||
|
||||
if last_image is not None and isinstance(last_image, list):
|
||||
last_image = last_image[0] if last_image else None
|
||||
|
||||
return self.forward(
|
||||
image=image,
|
||||
prompt=req.prompt,
|
||||
negative_prompt=req.negative_prompt,
|
||||
height=req.height,
|
||||
width=req.width,
|
||||
num_frames=req.num_frames,
|
||||
num_inference_steps=req.num_inference_steps,
|
||||
guidance_scale=req.guidance_scale,
|
||||
guidance_scale_2=req.guidance_scale_2,
|
||||
boundary_ratio=req.boundary_ratio,
|
||||
seed=req.seed,
|
||||
max_sequence_length=req.max_sequence_length,
|
||||
last_image=last_image,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, torch.Tensor, str],
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
seed: int = 42,
|
||||
max_sequence_length: int = 512,
|
||||
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
|
||||
):
|
||||
pipeline_start = time.time()
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed)
|
||||
|
||||
# Use user-provided boundary_ratio if given, otherwise fall back to model config
|
||||
boundary_ratio = boundary_ratio if boundary_ratio is not None else self.boundary_ratio
|
||||
|
||||
if self.transformer_2 is not None and boundary_ratio is None:
|
||||
raise ValueError(
|
||||
"Wan 2.2 models require boundary_ratio to be set. "
|
||||
"boundary_ratio was not found in model config. "
|
||||
"Please pass boundary_ratio as a parameter."
|
||||
)
|
||||
|
||||
# Set default negative prompt if not provided
|
||||
if negative_prompt is None:
|
||||
negative_prompt = WAN_DEFAULT_NEGATIVE_PROMPT
|
||||
|
||||
# Set model-specific defaults based on Wan version
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 40 if self.is_wan22 else 50
|
||||
|
||||
if guidance_scale is None:
|
||||
guidance_scale = 4.0 if self.is_wan22 else 5.0
|
||||
|
||||
if self.is_wan22 and guidance_scale_2 is None:
|
||||
guidance_scale_2 = 3.0 # Wan2.2 recommended default
|
||||
|
||||
# Validate two-stage denoising configuration
|
||||
if guidance_scale_2 is not None and boundary_ratio is None:
|
||||
logger.warning(
|
||||
"guidance_scale_2 is specified but boundary_ratio is None. "
|
||||
"guidance_scale_2 will be ignored."
|
||||
"Set boundary_ratio in config or pass as parameter to enable two-stage denoising."
|
||||
)
|
||||
guidance_scale_2 = None
|
||||
|
||||
# Validate and adjust frame count for VAE compatibility
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` must be divisible by {self.vae_scale_factor_temporal}. "
|
||||
f"Rounding {num_frames} to nearest valid value."
|
||||
)
|
||||
num_frames = (
|
||||
num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
)
|
||||
num_frames = max(num_frames, 1)
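# Worked example: with vae_scale_factor_temporal = 4 the valid frame counts are
# 4*k + 1 (..., 77, 81, 85, ...); requests of 80 and 83 are adjusted to
# 80 // 4 * 4 + 1 = 81 and 83 // 4 * 4 + 1 = 81 respectively.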
|
||||
|
||||
# Validate and adjust resolution for transformer patchification
|
||||
patch_size = (
|
||||
self.transformer.config.patch_size
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size
|
||||
)
|
||||
h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
|
||||
w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
|
||||
calc_height = height // h_multiple_of * h_multiple_of
|
||||
calc_width = width // w_multiple_of * w_multiple_of
|
||||
if height != calc_height or width != calc_width:
|
||||
logger.warning(
|
||||
f"Height and width must be multiples of ({h_multiple_of}, {w_multiple_of}) for patchification. "
|
||||
f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
|
||||
)
|
||||
height, width = calc_height, calc_width
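# Worked example: with vae_scale_factor_spatial = 8 and a patch_size of (1, 2, 2)
# (the usual Wan setting), height and width must be multiples of 16; 480x832
# passes through unchanged, while 486x838 is floored to 480x832.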
|
||||
|
||||
# Encode Prompt
|
||||
logger.info("Encoding prompts...")
|
||||
encode_start = time.time()
|
||||
prompt_embeds, neg_prompt_embeds = self._encode_prompt(
|
||||
prompt, negative_prompt, max_sequence_length
|
||||
)
|
||||
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")
|
||||
|
||||
# Encode Image (I2V-specific)
|
||||
logger.info("Encoding input image...")
|
||||
image_encode_start = time.time()
|
||||
|
||||
# Determine model version
|
||||
model_version = "Wan 2.2" if self.is_wan22 else "Wan 2.1"
|
||||
logger.info(
|
||||
f"Running {model_version} I2V inference "
|
||||
f"(boundary_ratio={boundary_ratio}, has_transformer_2={self.transformer_2 is not None})"
|
||||
)
|
||||
|
||||
if not self.is_wan22:
|
||||
# Wan 2.1 I2V: Compute CLIP image embeddings
|
||||
image_embeds = self._encode_image(image, last_image)
|
||||
image_embeds = image_embeds.to(self.dtype)
|
||||
else:
|
||||
# Wan 2.2 I2V: No image embeddings needed
|
||||
image_embeds = None
|
||||
|
||||
logger.info(f"Image encoding completed in {time.time() - image_encode_start:.2f}s")
|
||||
|
||||
# Prepare Latents with image conditioning (I2V-specific)
|
||||
latents, condition_data = self._prepare_latents(
|
||||
image, height, width, num_frames, generator, last_image
|
||||
)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
|
||||
# Wan2.2: Calculate boundary timestep for two-stage denoising
|
||||
boundary_timestep = None
|
||||
if boundary_ratio is not None and self.transformer_2 is not None:
|
||||
boundary_timestep = boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
logger.info(
|
||||
f"Wan2.2 I2V two-stage denoising: boundary_timestep={boundary_timestep:.1f}, "
|
||||
f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}"
|
||||
)
|
||||
|
||||
# Denoising with two-stage support
|
||||
# Track which model was used in last step (for logging model transitions)
|
||||
last_model_used = [None]
|
||||
|
||||
def forward_fn(
|
||||
latents_input, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
|
||||
):
|
||||
"""Forward function for WAN I2V transformer with two-stage support.
|
||||
|
||||
Both Wan 2.1 and Wan 2.2 14B use concatenation approach: [latents, condition].
|
||||
Difference: Wan 2.1 passes image_embeds, Wan 2.2 passes None.
|
||||
"""
|
||||
# Select model based on timestep (if two-stage denoising is enabled)
|
||||
if boundary_timestep is not None and self.transformer_2 is not None:
|
||||
# Extract scalar timestep for comparison
|
||||
current_t = timestep if timestep.dim() == 0 else timestep[0]
|
||||
if current_t >= boundary_timestep:
|
||||
current_model = self.transformer
|
||||
model_name = "transformer"
|
||||
else:
|
||||
current_model = self.transformer_2
|
||||
model_name = "transformer_2"
|
||||
|
||||
# Log when switching models
|
||||
if last_model_used[0] != model_name:
|
||||
if self.rank == 0:
|
||||
logger.info(
|
||||
f"[TRTLLM] Switched to {model_name} at timestep {current_t:.1f}"
|
||||
)
|
||||
last_model_used[0] = model_name
|
||||
else:
|
||||
current_model = self.transformer
|
||||
|
||||
# Wan 2.1 & Wan 2.2 14B: concatenate latents and condition
|
||||
# Handle CFG: duplicate condition if batch dimension doubled
|
||||
if latents_input.shape[0] != condition_data.shape[0]:
|
||||
condition_to_use = torch.cat([condition_data] * 2)
|
||||
else:
|
||||
condition_to_use = condition_data
|
||||
|
||||
latent_model_input = torch.cat([latents_input, condition_to_use], dim=1).to(self.dtype)
|
||||
timestep_input = timestep.expand(latents_input.shape[0])
|
||||
|
||||
# Forward pass with I2V conditioning
|
||||
# Wan 2.1: image_embeds is not None (CLIP embeddings)
|
||||
# Wan 2.2 14B: image_embeds is None (no CLIP)
|
||||
return current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep_input,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
)
|
||||
|
||||
# Two-stage denoising: model switching in forward_fn, guidance scale switching in denoise()
|
||||
latents = self.denoise(
|
||||
latents=latents,
|
||||
scheduler=self.scheduler,
|
||||
prompt_embeds=prompt_embeds,
|
||||
neg_prompt_embeds=neg_prompt_embeds,
|
||||
guidance_scale=guidance_scale,
|
||||
forward_fn=forward_fn,
|
||||
guidance_scale_2=guidance_scale_2,
|
||||
boundary_timestep=boundary_timestep,
|
||||
)
|
||||
|
||||
# Log TeaCache statistics - show stats for each transformer separately
|
||||
if self.rank == 0 and self.model_config.teacache.enable_teacache:
|
||||
logger.info("=" * 80)
|
||||
logger.info("TeaCache Statistics:")
|
||||
|
||||
# Stats for transformer (high-noise)
|
||||
if hasattr(self, "transformer_cache_backend") and self.transformer_cache_backend:
|
||||
stats = self.transformer_cache_backend.get_stats()
|
||||
total_steps = stats.get("total_steps", 0)
|
||||
cache_hits = stats.get("cached_steps", 0)
|
||||
cache_misses = stats.get("compute_steps", 0)
|
||||
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
|
||||
|
||||
logger.info(" Transformer (High-Noise):")
|
||||
logger.info(f" Total steps: {total_steps}")
|
||||
logger.info(f" Cache hits: {cache_hits}")
|
||||
logger.info(f" Cache misses: {cache_misses}")
|
||||
logger.info(f" Hit rate: {hit_rate:.1f}%")
|
||||
|
||||
# Stats for transformer_2 (low-noise)
|
||||
if hasattr(self, "transformer_2_cache_backend") and self.transformer_2_cache_backend:
|
||||
stats = self.transformer_2_cache_backend.get_stats()
|
||||
total_steps = stats.get("total_steps", 0)
|
||||
cache_hits = stats.get("cached_steps", 0)
|
||||
cache_misses = stats.get("compute_steps", 0)
|
||||
hit_rate = (cache_hits / total_steps * 100) if total_steps > 0 else 0.0
|
||||
|
||||
logger.info(" Transformer_2 (Low-Noise):")
|
||||
logger.info(f" Total steps: {total_steps}")
|
||||
logger.info(f" Cache hits: {cache_hits}")
|
||||
logger.info(f" Cache misses: {cache_misses}")
|
||||
logger.info(f" Hit rate: {hit_rate:.1f}%")
|
||||
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Decode
|
||||
logger.info("Decoding video...")
|
||||
decode_start = time.time()
|
||||
video = self.decode_latents(latents, self._decode_latents)
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
|
||||
logger.info(f"Total pipeline time: {time.time() - pipeline_start:.2f}s")
|
||||
|
||||
return MediaOutput(video=video)
|
||||
|
||||
def _encode_prompt(self, prompt, negative_prompt, max_sequence_length):
|
||||
"""Encode text prompts to embeddings (same as T2V)."""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
def get_embeds(texts):
|
||||
text_inputs = self.tokenizer(
|
||||
texts,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = text_inputs.input_ids.to(self.device)
|
||||
attention_mask = text_inputs.attention_mask.to(self.device)
|
||||
|
||||
embeds = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
|
||||
embeds = embeds.to(self.dtype)
|
||||
|
||||
# Zero-out padded tokens based on mask
|
||||
seq_lens = attention_mask.gt(0).sum(dim=1).long()
|
||||
cleaned_embeds = []
|
||||
for u, v in zip(embeds, seq_lens):
|
||||
real_content = u[:v]
|
||||
pad_len = max_sequence_length - real_content.size(0)
|
||||
if pad_len > 0:
|
||||
padded = torch.cat(
|
||||
[real_content, real_content.new_zeros(pad_len, real_content.size(1))]
|
||||
)
|
||||
else:
|
||||
padded = real_content
|
||||
cleaned_embeds.append(padded)
|
||||
|
||||
return torch.stack(cleaned_embeds, dim=0)
|
||||
|
||||
prompt_embeds = get_embeds(prompt)
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
neg_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if len(neg_prompt) == 1 and len(prompt) > 1:
|
||||
neg_prompt = neg_prompt * len(prompt)
|
||||
|
||||
neg_embeds = get_embeds(neg_prompt)
|
||||
|
||||
return prompt_embeds, neg_embeds
|
||||
|
||||
def _encode_image(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, torch.Tensor, str],
|
||||
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Encode image(s) using CLIP image encoder (Wan 2.1 I2V only)."""
|
||||
if isinstance(image, str):
|
||||
image = PIL.Image.open(image).convert("RGB")
|
||||
if isinstance(last_image, str):
|
||||
last_image = PIL.Image.open(last_image).convert("RGB")
|
||||
|
||||
images_to_encode = [image] if last_image is None else [image, last_image]
|
||||
|
||||
image_inputs = self.image_processor(images=images_to_encode, return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image_embeds = self.image_encoder(**image_inputs, output_hidden_states=True)
|
||||
|
||||
return image_embeds.hidden_states[-2]
|
||||
|
||||
def _prepare_latents(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, torch.Tensor, str],
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
generator: torch.Generator,
|
||||
last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare latents with image conditioning for I2V generation."""
|
||||
num_channels_latents = 16
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
|
||||
# Create random noise latents
|
||||
shape = (1, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
|
||||
|
||||
# Load and preprocess image(s)
|
||||
if isinstance(image, str):
|
||||
image = PIL.Image.open(image).convert("RGB")
|
||||
image = self.video_processor.preprocess(image, height=height, width=width).to(
|
||||
self.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if last_image is not None:
|
||||
if isinstance(last_image, str):
|
||||
last_image = PIL.Image.open(last_image).convert("RGB")
|
||||
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
|
||||
self.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
image = image.unsqueeze(2)
|
||||
|
||||
# Create video conditioning tensor (same for both Wan 2.1 and Wan 2.2 14B)
|
||||
if last_image is None:
|
||||
# First frame + zeros
|
||||
video_condition = torch.cat(
|
||||
[
|
||||
image,
|
||||
image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width),
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
# First frame + zeros + last frame (interpolation)
|
||||
last_image = last_image.unsqueeze(2)
|
||||
video_condition = torch.cat(
|
||||
[
|
||||
image,
|
||||
image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width),
|
||||
last_image,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
# Encode video condition through VAE
|
||||
video_condition = video_condition.to(device=self.device, dtype=self.vae.dtype)
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
||||
latent_condition = latent_condition.to(self.dtype)
|
||||
|
||||
# Normalize latents to match diffusion model's latent space
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
|
||||
1, self.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
# Create mask in video frame space
|
||||
# Reshaping is required to match the transformer's expected input format
|
||||
mask_lat_size = torch.ones(1, 1, num_frames, latent_height, latent_width)
|
||||
|
||||
if last_image is None:
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
else:
|
||||
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal
|
||||
)
|
||||
|
||||
mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
1, -1, self.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
|
||||
mask_lat_size = mask_lat_size.to(self.device, dtype=self.dtype)
|
||||
|
||||
# Concatenate mask and condition along channel dimension
|
||||
condition = torch.cat([mask_lat_size, latent_condition], dim=1)
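# Shape walk-through (assuming num_frames = 81, temporal factor 4, z_dim = 16):
# the frame-space mask (1, 1, 81, H', W') keeps only frame 0 (plus the last frame
# when last_image is given); the first-frame slice is repeated 4x to give
# (1, 1, 84, H', W'), reshaped to (1, 21, 4, H', W') and transposed to
# (1, 4, 21, H', W'). Concatenated with the 16-channel latent condition this
# yields a 20-channel conditioning tensor, which forward_fn later concatenates
# with the 16-channel noise latents along dim=1 before the transformer call.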
|
||||
return latents, condition
|
||||
|
||||
def _decode_latents(self, latents):
|
||||
"""Decode latents to video (same as T2V)."""
|
||||
latents = latents.to(self.vae.dtype)
|
||||
|
||||
# Denormalization
|
||||
if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"):
|
||||
if not hasattr(self, "_latents_mean"):
|
||||
self._latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(self.device, self.vae.dtype)
|
||||
)
|
||||
self._latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(self.device, self.vae.dtype)
|
||||
)
|
||||
latents = (latents * self._latents_std) + self._latents_mean
|
||||
else:
|
||||
scaling_factor = self.vae.config.get("scaling_factor", 1.0)
|
||||
latents = latents / scaling_factor
|
||||
|
||||
# VAE decode: returns (B, C, T, H, W)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# Post-process video tensor: (B, C, T, H, W) -> (T, H, W, C) uint8
|
||||
video = postprocess_video_tensor(video, remove_batch_dim=True)
|
||||
|
||||
return video
|
||||
756
tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
Normal file
@ -0,0 +1,756 @@
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from tqdm import tqdm
|
||||
from transformers.modeling_utils import get_parameter_device
|
||||
|
||||
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
|
||||
from tensorrt_llm._torch.modules.linear import Linear
|
||||
from tensorrt_llm._torch.modules.mlp import MLP
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
|
||||
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
|
||||
from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism
|
||||
from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
# =========================================================================
|
||||
# 1. Rotary Positional Embeddings
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Split logic matches Hugging Face exactly
|
||||
self.h_dim = 2 * (attention_head_dim // 6)
|
||||
self.w_dim = 2 * (attention_head_dim // 6)
|
||||
self.t_dim = attention_head_dim - self.h_dim - self.w_dim
|
||||
|
||||
freqs_cos, freqs_sin = [], []
|
||||
|
||||
# Order: Time, Height, Width
|
||||
for dim in [self.t_dim, self.h_dim, self.w_dim]:
|
||||
# High precision generation
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
|
||||
t = torch.arange(max_seq_len, dtype=torch.float64)
|
||||
freqs = torch.outer(t, freqs)
|
||||
|
||||
# Interleaved Pattern [c0, c0, c1, c1]
|
||||
freqs_cos.append(freqs.cos().repeat_interleave(2, dim=-1).float())
|
||||
freqs_sin.append(freqs.sin().repeat_interleave(2, dim=-1).float())
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Robust shape unpacking
|
||||
b, c, f, h, w = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = f // p_t, h // p_h, w // p_w
|
||||
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
# Broadcast frequencies to 3D grid: [Time, Height, Width]
|
||||
f_cos_t = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
f_sin_t = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
f_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
f_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
f_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
f_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
# Concatenate and flatten for Attention [1, SeqLen, 1, Dim] (SHD format)
|
||||
# New Attention module applies RoPE in [B, S, H, D] layout before reshaping to [B, H, S, D]
|
||||
return (
|
||||
torch.cat([f_cos_t, f_cos_h, f_cos_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
|
||||
torch.cat([f_sin_t, f_sin_h, f_sin_w], dim=-1).flatten(0, 2).unsqueeze(0).unsqueeze(2),
|
||||
)
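# Worked example (illustrative sizes): a 1.3B T2V run at 480x832 with 81 frames
# has a latent grid of f=21, h=60, w=104; with patch_size = (1, 2, 2) this gives
# ppf=21, pph=30, ppw=52, i.e. 21*30*52 = 32760 tokens. With attention_head_dim
# = 128 the per-axis split is t_dim = 44, h_dim = 42, w_dim = 42.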
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# 2. Embeddings & Attention
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class WanImageEmbedding(nn.Module):
|
||||
"""Image embedding for I2V models (Wan 2.1/2.2)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
pos_embed_seq_len: int = None,
|
||||
model_config: DiffusionModelConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
dtype = model_config.torch_dtype if model_config else None
|
||||
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
|
||||
self.norm1 = LayerNorm(
|
||||
hidden_size=in_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
|
||||
)
|
||||
|
||||
# Match HF FeedForward structure: Linear(in, in) → GELU → Linear(in, out)
|
||||
self.ff_in = Linear(
|
||||
in_features,
|
||||
in_features,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping if model_config else None,
|
||||
quant_config=model_config.quant_config if model_config else None,
|
||||
skip_create_weights_in_init=model_config.skip_create_weights_in_init
|
||||
if model_config
|
||||
else False,
|
||||
force_dynamic_quantization=model_config.force_dynamic_quantization
|
||||
if model_config
|
||||
else False,
|
||||
)
|
||||
self.ff_out = Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping if model_config else None,
|
||||
quant_config=model_config.quant_config if model_config else None,
|
||||
skip_create_weights_in_init=model_config.skip_create_weights_in_init
|
||||
if model_config
|
||||
else False,
|
||||
force_dynamic_quantization=model_config.force_dynamic_quantization
|
||||
if model_config
|
||||
else False,
|
||||
)
|
||||
|
||||
self.norm2 = LayerNorm(
|
||||
hidden_size=out_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True
|
||||
)
|
||||
|
||||
if pos_embed_seq_len is not None:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
|
||||
if self.pos_embed is not None:
|
||||
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
|
||||
encoder_hidden_states_image = encoder_hidden_states_image.view(
|
||||
-1, 2 * seq_len, embed_dim
|
||||
)
|
||||
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
|
||||
|
||||
hidden_states = self.norm1(encoder_hidden_states_image)
|
||||
hidden_states = self.ff_in(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states)
|
||||
hidden_states = self.ff_out(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
time_freq_dim,
|
||||
time_proj_dim,
|
||||
text_embed_dim,
|
||||
model_config: DiffusionModelConfig,
|
||||
image_embed_dim: int = None,
|
||||
pos_embed_seq_len: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
dtype = model_config.torch_dtype
|
||||
quant_config = model_config.quant_config
|
||||
skip_create_weights = model_config.skip_create_weights_in_init
|
||||
force_dynamic_quant = model_config.force_dynamic_quantization
|
||||
|
||||
self.timesteps_proj = Timesteps(
|
||||
num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
self.time_proj = Linear(
|
||||
dim,
|
||||
time_proj_dim,
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=skip_create_weights,
|
||||
force_dynamic_quantization=force_dynamic_quant,
|
||||
)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
self.image_embedder = None
|
||||
if image_embed_dim is not None:
|
||||
self.image_embedder = WanImageEmbedding(
|
||||
image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len, model_config=model_config
|
||||
)
|
||||
|
||||
def forward(self, timestep, encoder_hidden_states, encoder_hidden_states_image=None):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
|
||||
# Get time_embedder dtype
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
]:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
|
||||
temb_proj = self.time_proj(self.act_fn(temb))
|
||||
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
|
||||
if encoder_hidden_states_image is not None and self.image_embedder is not None:
|
||||
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
||||
|
||||
return temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image
|
||||
|
||||
|
||||
class WanBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: DiffusionModelConfig,
|
||||
_layer_idx: int,
|
||||
added_kv_proj_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
|
||||
if hasattr(config, "hidden_size"):
|
||||
hidden_size = config.hidden_size
|
||||
elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
|
||||
hidden_size = config.attention_head_dim * config.num_attention_heads
|
||||
else:
|
||||
hidden_size = 1536
|
||||
|
||||
# Wan 2.1 1.3B defaults
|
||||
num_heads = getattr(config, "num_attention_heads", 12)
|
||||
head_dim = getattr(config, "attention_head_dim", 128)
|
||||
ffn_dim = getattr(config, "ffn_dim", 8960)
|
||||
eps = getattr(config, "eps", 1e-6)
|
||||
|
||||
dtype = model_config.torch_dtype
|
||||
quant_config = model_config.quant_config
|
||||
skip_create_weights = model_config.skip_create_weights_in_init
|
||||
force_dynamic_quant = model_config.force_dynamic_quantization
|
||||
|
||||
# Store for I2V reshaping logic
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
|
||||
self.norm1 = LayerNorm(
|
||||
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
|
||||
)
|
||||
|
||||
# Self-attention with fused QKV
|
||||
self.attn1 = Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
qkv_mode=QKVMode.FUSE_QKV,
|
||||
qk_norm=True,
|
||||
eps=eps,
|
||||
config=model_config,
|
||||
layer_idx=_layer_idx,
|
||||
)
|
||||
|
||||
# Cross-attention with separate Q, K, V
|
||||
self.attn2 = Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
qkv_mode=QKVMode.SEPARATE_QKV,
|
||||
qk_norm=True,
|
||||
eps=eps,
|
||||
config=model_config,
|
||||
layer_idx=_layer_idx,
|
||||
)
|
||||
|
||||
self.norm2 = LayerNorm(
|
||||
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=True, has_bias=True
|
||||
)
|
||||
self.norm3 = LayerNorm(
|
||||
hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False
|
||||
)
|
||||
|
||||
self.ffn = MLP(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=ffn_dim,
|
||||
bias=True,
|
||||
activation=lambda x: F.gelu(x, approximate="tanh"),
|
||||
dtype=dtype,
|
||||
config=model_config,
|
||||
layer_idx=_layer_idx,
|
||||
reduce_output=False,
|
||||
)
|
||||
|
||||
# I2V: Additional K/V projections for image embeddings
|
||||
self.add_k_proj = self.add_v_proj = None
|
||||
self.norm_added_k = None
|
||||
if added_kv_proj_dim is not None:
|
||||
self.add_k_proj = Linear(
|
||||
added_kv_proj_dim,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=skip_create_weights,
|
||||
force_dynamic_quantization=force_dynamic_quant,
|
||||
)
|
||||
self.add_v_proj = Linear(
|
||||
added_kv_proj_dim,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=skip_create_weights,
|
||||
force_dynamic_quantization=force_dynamic_quant,
|
||||
)
|
||||
self.norm_added_k = RMSNorm(
|
||||
hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True
|
||||
)
|
||||
|
||||
# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.empty(1, 6, hidden_size).normal_(std=hidden_size**-0.5)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
freqs_cos,
|
||||
freqs_sin,
|
||||
):
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.float() + temb.float()
|
||||
).chunk(6, dim=1)
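# AdaLN-style modulation: the learned scale_shift_table plus the per-timestep projection yields
# (shift, scale, gate) for the self-attention branch and (c_shift, c_scale, c_gate) for the FFN branch.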
|
||||
|
||||
normed = self.norm1(x.float()) * (1 + scale_msa) + shift_msa
|
||||
normed = normed.to(x.dtype)
|
||||
|
||||
# Prepare frequencies for Attention
|
||||
freqs = (freqs_cos, freqs_sin) if freqs_cos is not None and freqs_sin is not None else None
|
||||
|
||||
# Self-attention with RoPE
|
||||
x = (
|
||||
x.float()
|
||||
+ self.attn1(
|
||||
normed,
|
||||
freqs=freqs,
|
||||
).float()
|
||||
* gate_msa
|
||||
).to(x.dtype)
|
||||
|
||||
norm_x = self.norm2(x.float()).to(x.dtype)
|
||||
|
||||
# I2V: Split encoder_hidden_states into image and text parts if needed
|
||||
encoder_hidden_states_img = None
|
||||
encoder_hidden_states_text = encoder_hidden_states
|
||||
if self.add_k_proj is not None:
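# I2V checkpoints pack encoder_hidden_states as [image context | text context]; the text portion
# is assumed here to be a fixed 512 tokens, so everything before the last 512 tokens is image context.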
|
||||
image_context_length = encoder_hidden_states.shape[1] - 512
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
|
||||
encoder_hidden_states_text = encoder_hidden_states[:, image_context_length:]
|
||||
|
||||
# Text cross-attention
|
||||
attn2_output = self.attn2(norm_x, encoder_hidden_states=encoder_hidden_states_text)
|
||||
|
||||
# I2V: Additional image cross-attention if image embeddings are present
|
||||
if encoder_hidden_states_img is not None:
|
||||
batch_size, seq_len = norm_x.shape[:2]
|
||||
|
||||
query = self.attn2.get_qkv(norm_x, None)[0] # Q only
|
||||
query, _ = self.attn2.apply_qk_norm(query, query)
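# apply_qk_norm returns (norm_q(q), norm_k(q)); only the normalized query is kept,
# since the image keys are normalized separately by norm_added_k below.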
|
||||
|
||||
key_img = self.add_k_proj(encoder_hidden_states_img)
|
||||
value_img = self.add_v_proj(encoder_hidden_states_img)
|
||||
key_img = self.norm_added_k(key_img)
|
||||
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
key_img = key_img.view(
|
||||
batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
|
||||
)
|
||||
value_img = value_img.view(
|
||||
batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim
|
||||
)
|
||||
|
||||
attn_img_output = self.attn2._attn_impl(
|
||||
query,
|
||||
key_img,
|
||||
value_img,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
kv_seq_len=encoder_hidden_states_img.shape[1],
|
||||
)
|
||||
|
||||
attn2_output = attn2_output + attn_img_output
|
||||
|
||||
x = x + attn2_output
|
||||
|
||||
# Feed-forward
|
||||
normed = self.norm3(x.float()) * (1 + c_scale_msa) + c_shift_msa
|
||||
normed = normed.to(x.dtype)
|
||||
|
||||
x = (x.float() + self.ffn(normed).float() * c_gate_msa).to(x.dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanTransformer3DModel(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: DiffusionModelConfig,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
|
||||
# Validate no tensor parallelism
|
||||
if model_config.parallel.dit_tp_size > 1:
|
||||
raise ValueError(
|
||||
f"WAN does not support tensor parallelism. "
|
||||
f"Got dit_tp_size={model_config.parallel.dit_tp_size}"
|
||||
)
|
||||
|
||||
# Setup sequence parallelism (Ulysses)
|
||||
num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 12)
|
||||
self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = (
|
||||
setup_sequence_parallelism(
|
||||
model_config=model_config,
|
||||
num_attention_heads=num_heads,
|
||||
)
|
||||
)
|
||||
|
||||
config = model_config.pretrained_config
|
||||
|
||||
dtype = model_config.torch_dtype
|
||||
quant_config = model_config.quant_config
|
||||
skip_create_weights = model_config.skip_create_weights_in_init
|
||||
force_dynamic_quant = model_config.force_dynamic_quantization
|
||||
|
||||
if hasattr(config, "hidden_size"):
|
||||
hidden_size = config.hidden_size
|
||||
elif hasattr(config, "attention_head_dim") and hasattr(config, "num_attention_heads"):
|
||||
hidden_size = config.attention_head_dim * config.num_attention_heads
|
||||
else:
|
||||
hidden_size = 1536 # Wan 1.3B default
|
||||
|
||||
num_layers = getattr(config, "num_layers", 30)
|
||||
attention_head_dim = getattr(config, "attention_head_dim", 128)
|
||||
in_channels = getattr(config, "in_channels", 16)
|
||||
out_channels = getattr(config, "out_channels", 16)
|
||||
text_dim = getattr(config, "text_dim", 4096)
|
||||
freq_dim = getattr(config, "freq_dim", 256)
|
||||
patch_size = getattr(config, "patch_size", [1, 2, 2])
|
||||
image_embed_dim = getattr(config, "image_dim", None) # e.g., 1280 for I2V
|
||||
added_kv_proj_dim = getattr(config, "added_kv_proj_dim", None)
|
||||
pos_embed_seq_len = getattr(config, "pos_embed_seq_len", None)
|
||||
|
||||
# Calculate FFN Dim
|
||||
ffn_dim = getattr(config, "ffn_dim", None)
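# Known Wan 2.1 widths (assumed here): hidden_size 5120 (14B) pairs with ffn_dim 13824,
# 1536 (1.3B) pairs with 8960; any other width falls back to 4x hidden_size.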
|
||||
if ffn_dim is None:
|
||||
ffn_dim = (
|
||||
13824
|
||||
if hidden_size == 5120
|
||||
else (8960 if hidden_size == 1536 else int(hidden_size * 4))
|
||||
)
|
||||
|
||||
# Store config for unpatchify and pipeline compatibility
|
||||
self.config = type(
|
||||
"Config",
|
||||
(),
|
||||
{
|
||||
"patch_size": patch_size,
|
||||
"hidden_size": hidden_size,
|
||||
"image_dim": image_embed_dim,
|
||||
"in_channels": in_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_layers": num_layers,
|
||||
},
|
||||
)()
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
dtype=dtype, # use model's target dtype (bf16)
|
||||
)
|
||||
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=hidden_size,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=hidden_size * 6,
|
||||
text_embed_dim=text_dim,
|
||||
model_config=model_config,
|
||||
image_embed_dim=image_embed_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanBlock(
|
||||
model_config=model_config,
|
||||
_layer_idx=i,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len=1024)
|
||||
|
||||
# LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch).
|
||||
self.norm_out = LayerNorm(
|
||||
hidden_size=hidden_size,
|
||||
eps=1e-6,
|
||||
dtype=torch.float32,
|
||||
has_weights=False,
|
||||
has_bias=False,
|
||||
)
|
||||
|
||||
self.proj_out = Linear(
|
||||
hidden_size,
|
||||
out_channels * math.prod(patch_size),
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=skip_create_weights,
|
||||
force_dynamic_quantization=force_dynamic_quant,
|
||||
)
|
||||
# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.empty(1, 2, hidden_size).normal_(std=hidden_size**-0.5)
|
||||
)
|
||||
|
||||
self.__post_init__()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return get_parameter_device(self)
|
||||
|
||||
def __post_init__(self):
|
||||
self.apply_quant_config_exclude_modules()
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if callable(getattr(module, "create_weights", None)):
|
||||
module.create_weights()
|
||||
|
||||
def apply_quant_config_exclude_modules(self):
|
||||
quant_config = self.model_config.quant_config
|
||||
if quant_config is None or quant_config.exclude_modules is None:
|
||||
return
|
||||
|
||||
kv_cache_quant_algo = quant_config.kv_cache_quant_algo if quant_config else None
|
||||
no_quant_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
|
||||
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
is_excluded = quant_config.is_module_excluded_from_quantization(name)
|
||||
if is_excluded and getattr(module, "quant_config", None) is not None:
|
||||
module.quant_config = no_quant_config
|
||||
|
||||
def unpatchify(self, x, original_shape):
|
||||
N, C, T, H, W = original_shape
|
||||
pt, ph, pw = self.config.patch_size
|
||||
gt, gh, gw = T // pt, H // ph, W // pw
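# (gt, gh, gw) is the latent token grid; each token folds back into a (pt, ph, pw) patch
# of out_channels values, inverting patch_embedding.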
|
||||
# Use output channels instead of input channels for unpatchifying
|
||||
out_channels = self.proj_out.out_features // (pt * ph * pw)
|
||||
return (
|
||||
x.view(N, gt, gh, gw, pt, ph, pw, out_channels)
|
||||
.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
.reshape(N, out_channels, T, H, W)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_image=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with optional Ulysses sequence parallelism.
|
||||
|
||||
With Ulysses enabled (ulysses_size > 1):
|
||||
1. Shard input sequence across ranks: [B, S] -> [B, S/P]
|
||||
2. Each block's attention does internal all-to-all for full sequence
|
||||
3. Gather output sequence: [B, S/P] -> [B, S]
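Example (illustrative numbers): with ulysses_size=4 and an 8192-token latent sequence,
each rank runs the transformer blocks on a 2048-token shard.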
|
||||
|
||||
When TeaCache is enabled, TeaCacheHook intercepts and replaces this call.
|
||||
"""
|
||||
original_shape = hidden_states.shape
|
||||
B, C, T, H, W = original_shape
|
||||
pt, ph, pw = self.config.patch_size
|
||||
|
||||
# Generate WAN RoPE frequencies
|
||||
freqs_cos, freqs_sin = self.rope(hidden_states)
|
||||
|
||||
# Patchify and flatten: [B, C, T, H, W] -> [B, S, hidden_size]
|
||||
x = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2)
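# Token count after the strided Conv3d patchify: S = (T // pt) * (H // ph) * (W // pw).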
|
||||
|
||||
# Shard sequence for Ulysses parallelism: [B, S] -> [B, S/P]
|
||||
if self.use_ulysses:
|
||||
seq_len = x.shape[1]
|
||||
if seq_len % self.ulysses_size != 0:
|
||||
raise ValueError(
|
||||
f"Sequence length ({seq_len}) is not divisible by ulysses_size ({self.ulysses_size}). "
|
||||
f"Adjust video dimensions or use a different ulysses_size."
|
||||
)
|
||||
|
||||
chunk_size = seq_len // self.ulysses_size
|
||||
x = x[:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size, :]
|
||||
|
||||
# Shard RoPE frequencies to match sequence sharding
|
||||
# RoPE freqs shape: [B, S, ...], so shard along dim 1 (sequence dimension)
|
||||
if freqs_cos is not None and freqs_sin is not None:
|
||||
freqs_cos = freqs_cos[
|
||||
:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
|
||||
]
|
||||
freqs_sin = freqs_sin[
|
||||
:, self.ulysses_rank * chunk_size : (self.ulysses_rank + 1) * chunk_size
|
||||
]
|
||||
|
||||
# Time and text/image embeddings
|
||||
temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image = (
|
||||
self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
|
||||
)
|
||||
temb_proj = temb_proj.view(-1, 6, self.config.hidden_size)
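# time_proj emits 6 * hidden_size values per sample, reshaped here into the six
# per-block modulation vectors consumed by each WanBlock.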
|
||||
|
||||
# I2V: Concatenate image and text embeddings if image embeddings are provided
|
||||
if encoder_hidden_states_image is not None:
|
||||
# Handle CFG: duplicate image embeddings if batch dimension is doubled
|
||||
if encoder_hidden_states_image.shape[0] != encoder_hidden_states.shape[0]:
|
||||
batch_multiplier = (
|
||||
encoder_hidden_states.shape[0] // encoder_hidden_states_image.shape[0]
|
||||
)
|
||||
encoder_hidden_states_image = encoder_hidden_states_image.repeat(
|
||||
batch_multiplier, 1, 1
|
||||
)
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states_image, encoder_hidden_states], dim=1
|
||||
)
|
||||
|
||||
# Transformer blocks (attention handles all-to-all internally for Ulysses)
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
temb_proj,
|
||||
freqs_cos,
|
||||
freqs_sin,
|
||||
)
|
||||
|
||||
# Gather sequence from all ranks: [B, S/P] -> [B, S]
|
||||
if self.use_ulysses:
|
||||
# Ensure tensor is contiguous before all_gather
|
||||
x = x.contiguous()
|
||||
x_list = [torch.zeros_like(x) for _ in range(self.ulysses_size)]
|
||||
torch.distributed.all_gather(x_list, x, group=self.ulysses_pg)
|
||||
x = torch.cat(x_list, dim=1)
|
||||
|
||||
# Output projection and unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = self.norm_out(x) * (1 + scale) + shift
|
||||
x = x.to(hidden_states.dtype)
|
||||
|
||||
return self.unpatchify(self.proj_out(x), original_shape)
|
||||
|
||||
def load_weights(self, weights: dict) -> None:
|
||||
# Remap checkpoint keys to match model structure
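# e.g. a diffusers-style key such as 'blocks.0.ffn.net.0.proj.weight' maps to
# 'blocks.0.ffn.up_proj.weight' in this module layout (illustrative key names).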
|
||||
remapped_weights = {}
|
||||
for key, value in weights.items():
|
||||
# Remap transformer block FFN keys
|
||||
if ".ffn.net.0.proj." in key:
|
||||
new_key = key.replace(".ffn.net.0.proj.", ".ffn.up_proj.")
|
||||
remapped_weights[new_key] = value
|
||||
elif ".ffn.net.2." in key:
|
||||
new_key = key.replace(".ffn.net.2.", ".ffn.down_proj.")
|
||||
remapped_weights[new_key] = value
|
||||
# Remap image embedder FF keys
|
||||
elif ".image_embedder.ff.net.0.proj." in key:
|
||||
new_key = key.replace(".image_embedder.ff.net.0.proj.", ".image_embedder.ff_in.")
|
||||
remapped_weights[new_key] = value
|
||||
elif ".image_embedder.ff.net.2." in key:
|
||||
new_key = key.replace(".image_embedder.ff.net.2.", ".image_embedder.ff_out.")
|
||||
remapped_weights[new_key] = value
|
||||
# Remap I2V attention keys
|
||||
elif ".attn2.add_k_proj." in key:
|
||||
new_key = key.replace(".attn2.add_k_proj.", ".add_k_proj.")
|
||||
remapped_weights[new_key] = value
|
||||
elif ".attn2.add_v_proj." in key:
|
||||
new_key = key.replace(".attn2.add_v_proj.", ".add_v_proj.")
|
||||
remapped_weights[new_key] = value
|
||||
elif ".attn2.norm_added_k." in key:
|
||||
new_key = key.replace(".attn2.norm_added_k.", ".norm_added_k.")
|
||||
remapped_weights[new_key] = value
|
||||
else:
|
||||
remapped_weights[key] = value
|
||||
|
||||
weights = remapped_weights
|
||||
|
||||
# Handle root-level parameters (filter_weights doesn't work for empty prefix)
|
||||
for param_name, param in self._parameters.items():
|
||||
if param is not None and param_name in weights:
|
||||
param.data.copy_(weights[param_name].to(self.model_config.torch_dtype))
|
||||
|
||||
params_map = {
|
||||
"qkv_proj": ["to_q", "to_k", "to_v"],
|
||||
}
|
||||
loader = DynamicLinearWeightLoader(self.model_config, params_map=params_map)
|
||||
|
||||
for name, module in tqdm(self.named_modules(), desc="Loading weights"):
|
||||
if len(module._parameters) == 0:
|
||||
continue
|
||||
|
||||
if isinstance(module, Linear):
|
||||
weight_dicts = loader.get_linear_weights(module, name, weights)
|
||||
|
||||
if weight_dicts:
|
||||
loader.load_linear_weights(module, name, weight_dicts)
|
||||
elif "add_k_proj" in name or "add_v_proj" in name:
|
||||
logger.info(f"[Weight Loading] No weights found for I2V module: {name}")
|
||||
else:
|
||||
module_weights = loader.filter_weights(name, weights)
|
||||
for param_name, param in module._parameters.items():
|
||||
if param is not None and param_name in module_weights:
|
||||
param.data.copy_(
|
||||
module_weights[param_name].to(self.model_config.torch_dtype)
|
||||
)
|
||||
|
||||
def post_load_weights(self) -> None:
|
||||
"""Call post_load_weights on all Linear modules and convert embedders to target dtype."""
|
||||
# Convert condition_embedder components to target dtype
|
||||
target_dtype = self.model_config.torch_dtype
|
||||
if hasattr(self.condition_embedder, "time_embedder"):
|
||||
self.condition_embedder.time_embedder.to(target_dtype)
|
||||
if hasattr(self.condition_embedder, "text_embedder"):
|
||||
self.condition_embedder.text_embedder.to(target_dtype)
|
||||
|
||||
# Call post_load_weights on all Linear modules
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
module.post_load_weights()
|
||||
26
tensorrt_llm/_torch/visual_gen/modules/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Visual Generation Modules
|
||||
|
||||
This module provides modular neural network components for visual generation models.
|
||||
"""
|
||||
|
||||
from .attention import Attention, QKVMode
|
||||
|
||||
__all__ = [
|
||||
"Attention",
|
||||
"QKVMode",
|
||||
]
|
||||
284
tensorrt_llm/_torch/visual_gen/modules/attention.py
Normal file
@ -0,0 +1,284 @@
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...modules.linear import Linear, WeightMode, WeightsLoadingConfig
|
||||
from ...modules.rms_norm import RMSNorm
|
||||
from ..attention_backend.interface import AttentionTensorLayout
|
||||
from ..attention_backend.utils import create_attention
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import DiffusionModelConfig
|
||||
|
||||
|
||||
class QKVMode(str, Enum):
|
||||
FUSE_QKV = "fuse_qkv"
|
||||
FUSE_KV = "fuse_kv"
|
||||
SEPARATE_QKV = "separate"
|
||||
|
||||
|
||||
# TODO: torch compile
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
freqs_cos = freqs_cos.to(x.dtype)
|
||||
freqs_sin = freqs_sin.to(x.dtype)
|
||||
x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) # [B, S, H, D/2]
|
||||
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
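# freqs carry the cos/sin values for each rotary pair in an interleaved layout; the strided
# slices pick one value per pair before rotating (x1, x2) as a 2D pair.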
|
||||
|
||||
return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Attention module for visual generation models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
qkv_mode: QKVMode = QKVMode.FUSE_QKV,
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6, # TODO: remove this, we should add this to the config
|
||||
config: Optional["DiffusionModelConfig"] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# DiffusionModelConfig is only imported under TYPE_CHECKING above; import it
# lazily so the default-config fallback also works at runtime.
if config is None:
from ..config import DiffusionModelConfig
config = DiffusionModelConfig()
|
||||
self.dtype = config.torch_dtype
|
||||
self.quant_config = config.quant_config
|
||||
self.skip_create_weights_in_init = config.skip_create_weights_in_init
|
||||
self.force_dynamic_quantization = config.force_dynamic_quantization
|
||||
self.mapping = getattr(config, "mapping", None)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads or num_attention_heads
|
||||
self.head_dim = head_dim or (hidden_size // num_attention_heads)
|
||||
self.qkv_mode = QKVMode(qkv_mode) if isinstance(qkv_mode, str) else qkv_mode
|
||||
|
||||
# Select compute backend (orthogonal to parallelism)
|
||||
ulysses_size = config.parallel.dit_ulysses_size
|
||||
base_backend = config.attention.backend
|
||||
|
||||
if self.qkv_mode == QKVMode.SEPARATE_QKV:
|
||||
backend_name = "VANILLA" # Cross-attention requires VANILLA
|
||||
else:
|
||||
backend_name = base_backend
|
||||
self.attn_backend = backend_name
|
||||
self.qk_norm = qk_norm
|
||||
self.layer_idx = layer_idx if layer_idx is not None else 0
|
||||
self.eps = eps
|
||||
|
||||
self.q_dim = self.num_attention_heads * self.head_dim
|
||||
self.kv_dim = self.num_key_value_heads * self.head_dim
|
||||
|
||||
self._init_qkv_proj()
|
||||
|
||||
if self.qk_norm:
|
||||
self.norm_q = RMSNorm(
|
||||
hidden_size=self.q_dim, eps=self.eps, dtype=self.dtype, has_weights=True
|
||||
)
|
||||
self.norm_k = RMSNorm(
|
||||
hidden_size=self.kv_dim, eps=self.eps, dtype=self.dtype, has_weights=True
|
||||
)
|
||||
|
||||
# TODO: Use weight mapper to create just a Linear module
|
||||
self.to_out = nn.ModuleList(
|
||||
[
|
||||
Linear(
|
||||
self.q_dim,
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
mapping=self.mapping,
|
||||
quant_config=self.quant_config,
|
||||
skip_create_weights_in_init=self.skip_create_weights_in_init,
|
||||
force_dynamic_quantization=self.force_dynamic_quantization,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Compute head counts for the backend
|
||||
# Ulysses shards heads across workers; inner backend sees sharded count
|
||||
if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
|
||||
backend_num_heads = self.num_attention_heads // ulysses_size
|
||||
backend_num_kv_heads = self.num_key_value_heads // ulysses_size
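# e.g. 12 heads with ulysses_size=2: the inner backend on each rank computes 6 heads,
# since Ulysses trades sequence sharding for head sharding via all-to-all.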
|
||||
else:
|
||||
backend_num_heads = self.num_attention_heads
|
||||
backend_num_kv_heads = self.num_key_value_heads
|
||||
|
||||
# Create compute backend
|
||||
self.attn = create_attention(
|
||||
backend=backend_name,
|
||||
layer_idx=self.layer_idx,
|
||||
num_heads=backend_num_heads,
|
||||
head_dim=self.head_dim,
|
||||
num_kv_heads=backend_num_kv_heads,
|
||||
quant_config=self.quant_config,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# Wrap with parallelism strategy (orthogonal to backend choice)
|
||||
if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV:
|
||||
from ..attention_backend.parallel import UlyssesAttention
|
||||
|
||||
process_group = getattr(config, "ulysses_process_group", None)
|
||||
self.attn = UlyssesAttention(
|
||||
inner_backend=self.attn,
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
def _init_qkv_proj(self) -> None:
|
||||
if self.qkv_mode == QKVMode.FUSE_QKV:
|
||||
qkv_out_dim = self.q_dim + 2 * self.kv_dim
|
||||
self.qkv_proj = Linear(
|
||||
self.hidden_size,
|
||||
qkv_out_dim,
|
||||
dtype=self.dtype,
|
||||
mapping=self.mapping,
|
||||
quant_config=self.quant_config,
|
||||
skip_create_weights_in_init=self.skip_create_weights_in_init,
|
||||
force_dynamic_quantization=self.force_dynamic_quantization,
|
||||
weights_loading_config=WeightsLoadingConfig(
|
||||
weight_mode=WeightMode.FUSED_QKV_LINEAR
|
||||
),
|
||||
fused_weight_shard_indices_mapping={
|
||||
"q": (0, self.q_dim),
|
||||
"k": (self.q_dim, self.kv_dim),
|
||||
"v": (self.q_dim + self.kv_dim, self.kv_dim),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.to_q = Linear(
|
||||
self.hidden_size,
|
||||
self.q_dim,
|
||||
dtype=self.dtype,
|
||||
mapping=self.mapping,
|
||||
quant_config=self.quant_config,
|
||||
skip_create_weights_in_init=self.skip_create_weights_in_init,
|
||||
force_dynamic_quantization=self.force_dynamic_quantization,
|
||||
)
|
||||
self.to_k = Linear(
|
||||
self.hidden_size,
|
||||
self.kv_dim,
|
||||
dtype=self.dtype,
|
||||
mapping=self.mapping,
|
||||
quant_config=self.quant_config,
|
||||
skip_create_weights_in_init=self.skip_create_weights_in_init,
|
||||
force_dynamic_quantization=self.force_dynamic_quantization,
|
||||
)
|
||||
self.to_v = Linear(
|
||||
self.hidden_size,
|
||||
self.kv_dim,
|
||||
dtype=self.dtype,
|
||||
mapping=self.mapping,
|
||||
quant_config=self.quant_config,
|
||||
skip_create_weights_in_init=self.skip_create_weights_in_init,
|
||||
force_dynamic_quantization=self.force_dynamic_quantization,
|
||||
)
|
||||
|
||||
def get_qkv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.qkv_mode == QKVMode.FUSE_QKV:
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
|
||||
else:
|
||||
kv_source = (
|
||||
encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
)
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(kv_source)
|
||||
v = self.to_v(kv_source)
|
||||
return q, k, v
|
||||
|
||||
def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.qk_norm:
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
return q, k
|
||||
|
||||
def _attn_impl(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
kv_seq_len: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Call attention backend with appropriate tensor layout.
|
||||
|
||||
Two layout paths:
|
||||
1. HND backends (VANILLA): [B, S, H*D] -> [B, H, S, D]
|
||||
2. NHD backends (TRTLLM, UlyssesAttention): [B, S, H*D] -> [B, S, H, D]
|
||||
"""
|
||||
backend_layout = getattr(self.attn, "preferred_layout", AttentionTensorLayout.NHD)
|
||||
|
||||
batch_size = batch_size or q.shape[0]
|
||||
seq_len = seq_len or q.shape[1]
|
||||
kv_seq_len = kv_seq_len or k.shape[1]
|
||||
|
||||
# Reshape inputs: [B, S, H*D] -> backend's preferred 4D layout
|
||||
if backend_layout == AttentionTensorLayout.HND:
|
||||
q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
else:
|
||||
q = q.view(batch_size, -1, self.num_attention_heads, self.head_dim)
|
||||
k = k.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
|
||||
v = v.view(batch_size, -1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# Call backend
|
||||
out = self.attn.forward(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
seq_len_kv=kv_seq_len if kv_seq_len != seq_len else None,
|
||||
)
|
||||
|
||||
# Flatten back to [B, S, H*D]
|
||||
if backend_layout == AttentionTensorLayout.HND:
|
||||
return out.transpose(1, 2).flatten(2)
|
||||
else:
|
||||
return out.flatten(2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.ndim == 3, "hidden_states must be a 3D tensor"
|
||||
batch_size, seq_len = hidden_states.shape[:2]
|
||||
kv_seq_len = (
|
||||
encoder_hidden_states.shape[1] if encoder_hidden_states is not None else seq_len
|
||||
)
|
||||
|
||||
q, k, v = self.get_qkv(hidden_states, encoder_hidden_states)
|
||||
q, k = self.apply_qk_norm(q, k)
|
||||
|
||||
# Apply RoPE if provided (model handles RoPE, not attention backend)
|
||||
if freqs is not None:
|
||||
freqs_cos, freqs_sin = freqs
|
||||
q = q.view(batch_size, seq_len, self.num_attention_heads, self.head_dim) # [B, S, H, D]
|
||||
k = k.view(batch_size, kv_seq_len, self.num_key_value_heads, self.head_dim)
|
||||
q = apply_rotary_emb(q, freqs_cos, freqs_sin)
|
||||
k = apply_rotary_emb(k, freqs_cos, freqs_sin)
|
||||
q = q.flatten(2)
|
||||
k = k.flatten(2)
|
||||
|
||||
out = self._attn_impl(q, k, v, batch_size, seq_len, kv_seq_len)
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
29
tensorrt_llm/_torch/visual_gen/output.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Output dataclass for visual generation models."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaOutput:
|
||||
"""Unified output for all visual generation models.
|
||||
|
||||
Different models populate different fields:
|
||||
- FLUX2: image only
|
||||
- WAN: video only
|
||||
- LTX2: video + audio
|
||||
|
||||
Attributes:
|
||||
image: Generated image as torch tensor with shape (height, width, channels) and dtype uint8.
|
||||
Populated by FLUX2 for text-to-image generation.
|
||||
video: Generated video frames as torch tensor with shape (num_frames, height, width, channels) and dtype uint8.
|
||||
Populated by WAN and LTX2 for text-to-video generation.
|
||||
audio: Generated audio as torch tensor with dtype float32.
|
||||
Populated by LTX2 for text-to-video-with-audio generation.
|
||||
"""
|
||||
|
||||
image: Optional[torch.Tensor] = None
|
||||
video: Optional[torch.Tensor] = None
|
||||
audio: Optional[torch.Tensor] = None
|
||||
100
tensorrt_llm/_torch/visual_gen/parallelism.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""Utilities for distributed parallelism setup in diffusion models."""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
|
||||
|
||||
|
||||
def setup_sequence_parallelism(
|
||||
model_config: DiffusionModelConfig,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[bool, int, Optional[dist.ProcessGroup], int]:
|
||||
"""
|
||||
Setup sequence parallelism (currently Ulysses only) with CFG support.
|
||||
|
||||
Creates nested process groups where each CFG group has its own Ulysses group.
|
||||
Example with cfg_size=2, ulysses_size=2, world_size=4:
|
||||
GPU 0-1: CFG group 0, Ulysses group 0
|
||||
GPU 2-3: CFG group 1, Ulysses group 1
|
||||
|
||||
Args:
|
||||
model_config: Model configuration containing parallel settings
|
||||
num_attention_heads: Number of attention heads in the model
|
||||
|
||||
Returns:
|
||||
Tuple of (use_parallelism, parallelism_size, parallelism_pg, parallelism_rank):
|
||||
- use_parallelism: Whether sequence parallelism is enabled
|
||||
- parallelism_size: The sequence parallelism degree
|
||||
- parallelism_pg: The process group for this rank (or None)
|
||||
- parallelism_rank: This rank's position within its parallelism group
|
||||
|
||||
Raises:
|
||||
RuntimeError: If torch.distributed is not initialized
|
||||
ValueError: If configuration is invalid (incompatible sizes, head count not divisible, etc.)
|
||||
NotImplementedError: If Ring attention is requested (not yet implemented)
|
||||
|
||||
Side Effects:
|
||||
- Sets model_config.ulysses_process_group to the created process group
|
||||
|
||||
Note:
|
||||
Both num_attention_heads and sequence length must be divisible by ulysses_size.
|
||||
Head count is validated here; sequence length is validated at runtime during forward pass.
|
||||
"""
|
||||
ulysses_size = model_config.parallel.dit_ulysses_size
|
||||
ring_size = model_config.parallel.dit_ring_size
|
||||
cfg_size = model_config.parallel.dit_cfg_size
|
||||
|
||||
# Check for ring attention (not yet implemented)
|
||||
if ring_size > 1:
|
||||
raise NotImplementedError("Ring attention parallelism is not yet implemented")
|
||||
|
||||
# Early exit if not using sequence parallelism
|
||||
if ulysses_size <= 1:
|
||||
model_config.ulysses_process_group = None
|
||||
return False, 1, None, 0
|
||||
|
||||
# Validate distributed initialization
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError(
|
||||
"torch.distributed.init_process_group() must be called before "
|
||||
"setting up sequence parallelism"
|
||||
)
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Validate total parallelism capacity
|
||||
total_parallel = cfg_size * ulysses_size
|
||||
if total_parallel > world_size:
|
||||
raise ValueError(
|
||||
f"cfg_size ({cfg_size}) * ulysses_size ({ulysses_size}) = "
|
||||
f"{total_parallel} exceeds world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Validate head count divisibility
|
||||
if num_attention_heads % ulysses_size != 0:
|
||||
raise ValueError(
|
||||
f"num_attention_heads ({num_attention_heads}) must be divisible by "
|
||||
f"ulysses_size ({ulysses_size})"
|
||||
)
|
||||
|
||||
# Create nested process groups
|
||||
# Each CFG group has its own Ulysses group
|
||||
ulysses_pg = None
|
||||
ulysses_rank = 0
|
||||
|
||||
for cfg_id in range(cfg_size):
|
||||
ulysses_ranks = list(range(cfg_id * ulysses_size, (cfg_id + 1) * ulysses_size))
|
||||
pg = dist.new_group(ulysses_ranks, use_local_synchronization=True)
|
||||
|
||||
# Store if this rank belongs to this group
|
||||
if rank in ulysses_ranks:
|
||||
ulysses_pg = pg
|
||||
ulysses_rank = rank - cfg_id * ulysses_size
|
||||
|
||||
# Store in config for Attention modules
|
||||
model_config.ulysses_process_group = ulysses_pg
|
||||
|
||||
return True, ulysses_size, ulysses_pg, ulysses_rank
|
||||
544
tensorrt_llm/_torch/visual_gen/pipeline.py
Normal file
@ -0,0 +1,544 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from .teacache import TeaCacheBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import DiffusionModelConfig
|
||||
|
||||
|
||||
class BasePipeline(nn.Module):
|
||||
"""
|
||||
Base class for diffusion pipelines.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "DiffusionModelConfig"):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
self.config = model_config.pretrained_config
|
||||
self.mapping: Mapping = getattr(model_config, "mapping", None) or Mapping()
|
||||
|
||||
# Components
|
||||
self.transformer: Optional[nn.Module] = None
|
||||
self.vae: Optional[nn.Module] = None
|
||||
self.text_encoder: Optional[nn.Module] = None
|
||||
self.tokenizer: Optional[Any] = None
|
||||
self.scheduler: Optional[Any] = None
|
||||
|
||||
# Initialize transformer
|
||||
self._init_transformer()
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if hasattr(self, "transformer"):
|
||||
return next(self.transformer.parameters()).dtype
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
def infer(self, req: Any):
|
||||
raise NotImplementedError
|
||||
|
||||
def _init_transformer(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_standard_components(
|
||||
self,
|
||||
checkpoint_dir: str,
|
||||
device: torch.device,
|
||||
skip_components: Optional[list] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
|
||||
if self.transformer is not None and hasattr(self.transformer, "load_weights"):
|
||||
self.transformer.load_weights(weights)
|
||||
|
||||
def post_load_weights(self) -> None:
|
||||
if self.transformer is not None and hasattr(self.transformer, "post_load_weights"):
|
||||
self.transformer.post_load_weights()
|
||||
|
||||
def _setup_teacache(self, model, coefficients: Optional[Dict] = None):
|
||||
"""Setup TeaCache optimization for the transformer model.
|
||||
|
||||
TeaCache caches transformer block outputs when timestep embeddings change slowly,
|
||||
reducing computation during the denoising loop.
|
||||
|
||||
Args:
|
||||
model: The transformer model to optimize
|
||||
coefficients: Optional dict of model-specific polynomial coefficients for cache decisions
|
||||
Format: {model_size: {"ret_steps": [...], "standard": [...]}}
|
||||
"""
|
||||
self.cache_backend = None
|
||||
|
||||
# Get teacache config from model_config (always present now)
|
||||
teacache_cfg = self.model_config.teacache
|
||||
if not teacache_cfg.enable_teacache:
|
||||
return
|
||||
|
||||
# Apply model-specific polynomial coefficients
|
||||
# Coefficients are used to rescale embedding distances for cache decisions
|
||||
if coefficients:
|
||||
checkpoint_path = (
|
||||
getattr(self.model_config.pretrained_config, "_name_or_path", "") or ""
|
||||
)
|
||||
for model_size, coeff_data in coefficients.items():
|
||||
# Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev")
|
||||
if model_size.lower() in checkpoint_path.lower():
|
||||
if isinstance(coeff_data, dict):
|
||||
# Select coefficient set based on warmup mode
|
||||
mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard"
|
||||
if mode in coeff_data:
|
||||
teacache_cfg.coefficients = coeff_data[mode]
|
||||
logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)")
|
||||
else:
|
||||
# Single coefficient list (no mode distinction)
|
||||
teacache_cfg.coefficients = coeff_data
|
||||
logger.info(f"TeaCache: Using {model_size} coefficients")
|
||||
break
|
||||
|
||||
# Initialize and enable TeaCache backend
|
||||
logger.info("TeaCache: Initializing...")
|
||||
self.cache_backend = TeaCacheBackend(teacache_cfg)
|
||||
self.cache_backend.enable(model)
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
decode_fn: Callable[[torch.Tensor], Any],
|
||||
extra_latents: Optional[Dict[str, Tuple[torch.Tensor, Callable]]] = None,
|
||||
):
|
||||
"""Execute VAE decoding. Only rank 0 performs decoding.
|
||||
|
||||
Args:
|
||||
latents: Primary latents to decode (e.g., video)
|
||||
decode_fn: Decoder function for primary latents
|
||||
extra_latents: Optional dict of additional latents to decode.
|
||||
Format: {name: (latents_tensor, decode_fn)}
|
||||
Example: {"audio": (audio_latents, audio_decode_fn)}
|
||||
|
||||
Returns:
|
||||
Single result if no extra_latents, tuple of results if extra_latents provided.
|
||||
Non-rank-0 processes return None placeholders.
|
||||
"""
|
||||
if self.rank == 0:
|
||||
primary_result = decode_fn(latents)
|
||||
|
||||
if extra_latents:
|
||||
extra_results = []
|
||||
for name, (extra_latent, extra_decode_fn) in extra_latents.items():
|
||||
extra_results.append(extra_decode_fn(extra_latent))
|
||||
return (primary_result,) + tuple(extra_results)
|
||||
|
||||
return primary_result
|
||||
|
||||
# Return None placeholders for non-rank-0 processes
|
||||
n_results = 1 + (len(extra_latents) if extra_latents else 0)
|
||||
return (None,) * n_results if n_results > 1 else None
|
||||
|
||||
@staticmethod
|
||||
def _rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""Rescale noise to fix overexposure (https://huggingface.co/papers/2305.08891)."""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
|
||||
def _setup_cfg_config(
|
||||
self, guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors=None
|
||||
):
|
||||
"""Setup CFG parallel configuration.
|
||||
|
||||
Args:
|
||||
guidance_scale: CFG guidance scale
|
||||
prompt_embeds: Positive prompt embeddings
|
||||
neg_prompt_embeds: Negative prompt embeddings (None if already concatenated)
|
||||
extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel.
|
||||
Format: {name: (positive_tensor, negative_tensor)}
|
||||
Example: {"audio_embeds": (pos_audio, neg_audio),
|
||||
"attention_mask": (pos_mask, neg_mask)}
|
||||
|
||||
Returns:
|
||||
Dict with CFG configuration including split tensors
|
||||
"""
|
||||
# Access parallel config directly (always present now)
|
||||
cfg_size = self.model_config.parallel.dit_cfg_size
|
||||
ulysses_size = self.model_config.parallel.dit_ulysses_size
|
||||
|
||||
cfg_group = self.rank // ulysses_size
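# Ranks [0, ulysses_size) form CFG group 0 and handle the positive (conditional) branch;
# ranks [ulysses_size, 2 * ulysses_size) form CFG group 1 and handle the negative branch.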
|
||||
is_split_embeds = neg_prompt_embeds is not None
|
||||
do_cfg_parallel = cfg_size >= 2 and guidance_scale > 1.0
|
||||
|
||||
local_extras = {}
|
||||
|
||||
if do_cfg_parallel:
|
||||
if self.rank == 0:
|
||||
logger.info(f"CFG Parallel: cfg_size={cfg_size}, ulysses_size={ulysses_size}")
|
||||
|
||||
# Split main embeddings
|
||||
if is_split_embeds:
|
||||
pos_embeds, neg_embeds = prompt_embeds, neg_prompt_embeds
|
||||
else:
|
||||
neg_embeds, pos_embeds = prompt_embeds.chunk(2)
|
||||
|
||||
local_embeds = pos_embeds if cfg_group == 0 else neg_embeds
|
||||
|
||||
# Split extra tensors if provided
|
||||
if extra_cfg_tensors:
|
||||
for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
|
||||
if pos_tensor is not None and neg_tensor is not None:
|
||||
local_extras[name] = pos_tensor if cfg_group == 0 else neg_tensor
|
||||
elif pos_tensor is not None:
|
||||
# Only positive provided, use it for both
|
||||
local_extras[name] = pos_tensor
|
||||
else:
|
||||
local_embeds = None
|
||||
if is_split_embeds and guidance_scale > 1.0:
|
||||
prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds])
|
||||
|
||||
# For standard CFG, concatenate extra tensors
|
||||
if extra_cfg_tensors:
|
||||
for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items():
|
||||
if pos_tensor is not None and neg_tensor is not None and guidance_scale > 1.0:
|
||||
local_extras[name] = torch.cat([neg_tensor, pos_tensor], dim=0)
|
||||
elif pos_tensor is not None:
|
||||
local_extras[name] = pos_tensor
|
||||
|
||||
return {
|
||||
"enabled": do_cfg_parallel,
|
||||
"cfg_size": cfg_size,
|
||||
"ulysses_size": ulysses_size,
|
||||
"cfg_group": cfg_group,
|
||||
"local_embeds": local_embeds,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"local_extras": local_extras,
|
||||
}
|
||||
|
||||
def _denoise_step_cfg_parallel(
|
||||
self,
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
timestep,
|
||||
local_embeds,
|
||||
forward_fn,
|
||||
guidance_scale,
|
||||
guidance_rescale,
|
||||
ulysses_size,
|
||||
local_extras,
|
||||
):
|
||||
"""Execute single denoising step with CFG parallel."""
|
||||
t_start = time.time()
|
||||
result = forward_fn(latents, extra_stream_latents, timestep, local_embeds, local_extras)
|
||||
|
||||
# Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
|
||||
if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
|
||||
noise_pred_local, extra_noise_locals = result
|
||||
else:
|
||||
noise_pred_local = result
|
||||
extra_noise_locals = {}
|
||||
|
||||
t_transformer = time.time() - t_start
|
||||
|
||||
c_start = time.time()
|
||||
|
||||
# All-gather primary noise
|
||||
gather_list = [torch.empty_like(noise_pred_local) for _ in range(self.world_size)]
|
||||
dist.all_gather(gather_list, noise_pred_local)
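# Every rank already holds a full-sequence prediction (the transformer gathers Ulysses shards
# internally), so entry 0 is the conditional branch and entry ulysses_size the unconditional one.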
|
||||
noise_cond = gather_list[0]
|
||||
noise_uncond = gather_list[ulysses_size]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
|
||||
|
||||
# All-gather extra stream noises
|
||||
extra_noise_preds = {}
|
||||
for name, noise_local in extra_noise_locals.items():
|
||||
gather_list_extra = [torch.empty_like(noise_local) for _ in range(self.world_size)]
|
||||
dist.all_gather(gather_list_extra, noise_local)
|
||||
noise_cond_extra = gather_list_extra[0]
|
||||
noise_uncond_extra = gather_list_extra[ulysses_size]
|
||||
extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
|
||||
noise_cond_extra - noise_uncond_extra
|
||||
)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
extra_noise_preds[name] = self._rescale_noise_cfg(
|
||||
extra_noise_preds[name], noise_cond_extra, guidance_rescale
|
||||
)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
noise_pred = self._rescale_noise_cfg(noise_pred, noise_cond, guidance_rescale)
|
||||
|
||||
t_cfg = time.time() - c_start
|
||||
return noise_pred, extra_noise_preds, t_transformer, t_cfg
|
||||
|
||||
def _denoise_step_standard(
|
||||
self,
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
forward_fn,
|
||||
guidance_scale,
|
||||
guidance_rescale,
|
||||
local_extras,
|
||||
):
|
||||
"""Execute single denoising step without CFG parallel."""
|
||||
if guidance_scale > 1.0:
|
||||
latent_input = torch.cat([latents] * 2)
|
||||
# Duplicate extra stream latents for CFG
|
||||
extra_stream_input = {
|
||||
name: torch.cat([stream_latents] * 2)
|
||||
for name, stream_latents in extra_stream_latents.items()
|
||||
}
|
||||
else:
|
||||
latent_input = latents
|
||||
extra_stream_input = extra_stream_latents
|
||||
|
||||
timestep_expanded = timestep.expand(latent_input.shape[0])
|
||||
|
||||
t_start = time.time()
|
||||
result = forward_fn(
|
||||
latent_input, extra_stream_input, timestep_expanded, prompt_embeds, local_extras
|
||||
)
|
||||
|
||||
# Handle return format: (primary_noise, extra_noises_dict) or just primary_noise
|
||||
if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
|
||||
noise_pred, extra_noise_preds = result
|
||||
else:
|
||||
noise_pred = result
|
||||
extra_noise_preds = {}
|
||||
|
||||
t_transformer = time.time() - t_start
|
||||
|
||||
c_start = time.time()
|
||||
if guidance_scale > 1.0:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# Apply CFG to extra streams
|
||||
for name, noise_extra in extra_noise_preds.items():
|
||||
noise_uncond_extra, noise_text_extra = noise_extra.chunk(2)
|
||||
extra_noise_preds[name] = noise_uncond_extra + guidance_scale * (
|
||||
noise_text_extra - noise_uncond_extra
|
||||
)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
extra_noise_preds[name] = self._rescale_noise_cfg(
|
||||
extra_noise_preds[name], noise_text_extra, guidance_rescale
|
||||
)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
noise_pred = self._rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
|
||||
|
||||
t_cfg = time.time() - c_start
|
||||
else:
|
||||
t_cfg = 0.0
|
||||
|
||||
return noise_pred, extra_noise_preds, t_transformer, t_cfg
|
||||
|
||||
def _scheduler_step(
|
||||
self,
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
noise_pred,
|
||||
extra_noise_preds,
|
||||
timestep,
|
||||
scheduler,
|
||||
extra_stream_schedulers,
|
||||
):
|
||||
"""Execute scheduler step for all streams."""
|
||||
t_start = time.time()
|
||||
latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
||||
|
||||
# Step schedulers for extra streams
|
||||
for name, noise_extra in extra_noise_preds.items():
|
||||
if name in extra_stream_schedulers:
|
||||
extra_stream_latents[name] = extra_stream_schedulers[name].step(
|
||||
noise_extra, timestep, extra_stream_latents[name], return_dict=False
|
||||
)[0]
|
||||
|
||||
t_sched = time.time() - t_start
|
||||
return latents, extra_stream_latents, t_sched
|
||||
|
||||
def denoise(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
scheduler: Any,
|
||||
prompt_embeds: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
forward_fn: Callable,
|
||||
timesteps: Optional[torch.Tensor] = None,
|
||||
neg_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
extra_cfg_tensors: Optional[Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
|
||||
extra_streams: Optional[Dict[str, Tuple[torch.Tensor, Any]]] = None,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
boundary_timestep: Optional[float] = None,
|
||||
):
|
||||
"""Execute denoising loop with optional CFG parallel and TeaCache support.
|
||||
|
||||
Args:
|
||||
latents: Initial noise latents (primary stream, e.g., video)
|
||||
scheduler: Diffusion scheduler for primary stream
|
||||
prompt_embeds: Text embeddings (positive)
|
||||
guidance_scale: CFG strength (1.0 = no guidance)
|
||||
forward_fn: Transformer forward function
|
||||
Signature: forward_fn(latents, extra_stream_latents, timestep,
|
||||
encoder_hidden_states, extra_tensors_dict)
|
||||
Returns: (primary_noise, extra_stream_noises_dict) or just primary_noise
|
||||
timesteps: Optional custom timesteps (defaults to scheduler.timesteps)
|
||||
neg_prompt_embeds: Optional negative text embeddings for CFG
|
||||
guidance_rescale: CFG rescale factor to prevent overexposure
|
||||
extra_cfg_tensors: Optional dict of additional tensors to split for CFG parallel
|
||||
Format: {name: (positive_tensor, negative_tensor)}
|
||||
Example: {"audio_embeds": (pos_audio, neg_audio)}
|
||||
extra_streams: Optional dict of additional streams to denoise in parallel
|
||||
Format: {name: (stream_latents, stream_scheduler)}
|
||||
Example: {"audio": (audio_latents, audio_scheduler)}
|
||||
guidance_scale_2: Optional guidance scale for two-stage denoising.
|
||||
When provided with boundary_timestep, switches from guidance_scale
|
||||
to guidance_scale_2 when timestep < boundary_timestep.
|
||||
boundary_timestep: Optional timestep boundary for two-stage denoising.
|
||||
Switches guidance scale when crossing this threshold.
|
||||
|
||||
Returns:
|
||||
Single latents if no extra_streams
|
||||
Tuple (primary_latents, extra_streams_dict) if extra_streams provided
|
||||
"""
|
||||
if timesteps is None:
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
total_steps = len(timesteps)
|
||||
has_extra_streams = extra_streams is not None and len(extra_streams) > 0
|
||||
|
||||
# Reset TeaCache state for new generation
|
||||
# Sets warmup/cutoff steps based on total_steps
|
||||
if (
|
||||
hasattr(self, "cache_backend")
|
||||
and self.cache_backend
|
||||
and self.cache_backend.is_enabled()
|
||||
):
|
||||
self.cache_backend.refresh(total_steps)
|
||||
|
||||
if self.rank == 0:
|
||||
if has_extra_streams:
|
||||
stream_names = ", ".join(["primary"] + list(extra_streams.keys()))
|
||||
logger.info(
|
||||
f"Denoising [{stream_names}]: {total_steps} steps, guidance={guidance_scale}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Denoising: {total_steps} steps, guidance={guidance_scale}")
|
||||
|
||||
cfg_config = self._setup_cfg_config(
|
||||
guidance_scale, prompt_embeds, neg_prompt_embeds, extra_cfg_tensors
|
||||
)
|
||||
do_cfg_parallel = cfg_config["enabled"]
|
||||
prompt_embeds = cfg_config["prompt_embeds"]
|
||||
local_extras = cfg_config["local_extras"]
|
||||
|
||||
# Extract extra stream latents and schedulers
|
||||
extra_stream_latents = {}
|
||||
extra_stream_schedulers = {}
|
||||
if extra_streams:
|
||||
for name, (stream_latents, stream_scheduler) in extra_streams.items():
|
||||
extra_stream_latents[name] = stream_latents
|
||||
extra_stream_schedulers[name] = stream_scheduler
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
step_start = time.time()
|
||||
|
||||
# Two-stage denoising: switch guidance scale at boundary
|
||||
current_guidance_scale = guidance_scale
|
||||
if guidance_scale_2 is not None and boundary_timestep is not None:
|
||||
t_scalar = t.item() if t.dim() == 0 else t[0].item()
|
||||
if t_scalar < boundary_timestep:
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
# Denoise
|
||||
if do_cfg_parallel:
|
||||
timestep = t.expand(latents.shape[0])
|
||||
noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_cfg_parallel(
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
timestep,
|
||||
cfg_config["local_embeds"],
|
||||
forward_fn,
|
||||
current_guidance_scale,
|
||||
guidance_rescale,
|
||||
cfg_config["ulysses_size"],
|
||||
local_extras,
|
||||
)
|
||||
else:
|
||||
noise_pred, extra_noise_preds, t_trans, t_cfg = self._denoise_step_standard(
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
t,
|
||||
prompt_embeds,
|
||||
forward_fn,
|
||||
current_guidance_scale,
|
||||
guidance_rescale,
|
||||
local_extras,
|
||||
)
|
||||
|
||||
# Scheduler step for all streams
|
||||
latents, extra_stream_latents, t_sched = self._scheduler_step(
|
||||
latents,
|
||||
extra_stream_latents,
|
||||
noise_pred,
|
||||
extra_noise_preds,
|
||||
t,
|
||||
scheduler,
|
||||
extra_stream_schedulers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
if self.rank == 0:
|
||||
step_time = time.time() - step_start
|
||||
avg_time = (time.time() - start_time) / (i + 1)
|
||||
eta = avg_time * (total_steps - i - 1)
|
||||
logger.info(
|
||||
f"Step {i + 1}/{total_steps} | {step_time:.2f}s "
|
||||
f"(trans={t_trans:.2f}s cfg={t_cfg:.3f}s sched={t_sched:.3f}s) | "
|
||||
f"Avg={avg_time:.2f}s/step ETA={eta:.1f}s"
|
||||
)
|
||||
|
||||
if self.rank == 0:
|
||||
total_time = time.time() - start_time
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"Denoising done: {total_time:.2f}s ({total_time / total_steps:.2f}s/step)")
|
||||
|
||||
# Log TeaCache performance statistics
|
||||
# Shows how many transformer steps were skipped (cache hits) vs computed
|
||||
if (
|
||||
hasattr(self, "cache_backend")
|
||||
and self.cache_backend
|
||||
and self.cache_backend.is_enabled()
|
||||
):
|
||||
stats = self.cache_backend.get_stats()
|
||||
if stats:
|
||||
logger.info(
|
||||
f"TeaCache: {stats['hit_rate']:.1%} hit rate ({stats['cached']}/{stats['total']} steps)"
|
||||
)
|
||||
|
||||
return (latents, extra_stream_latents) if has_extra_streams else latents
|
||||
228
tensorrt_llm/_torch/visual_gen/pipeline_loader.py
Normal file
@ -0,0 +1,228 @@
|
||||
"""
|
||||
Model loader for diffusion pipelines.
|
||||
|
||||
Flow:
|
||||
1. Load config via DiffusionModelConfig.from_pretrained()
|
||||
2. Create pipeline via AutoPipeline.from_config() with MetaInit
|
||||
3. Load weights with on-the-fly quantization if dynamic_weight_quant=True
|
||||
4. Call pipeline.post_load_weights()
|
||||
|
||||
Dynamic Quantization:
|
||||
- If quant_config specifies FP8/NVFP4 and dynamic_weight_quant=True:
|
||||
- Model Linear layers are created with FP8/NVFP4 buffers
|
||||
- BF16 checkpoint weights are quantized on-the-fly during loading
|
||||
- Quantized weights are copied into model buffers
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
|
||||
from tensorrt_llm.llmapi.utils import download_hf_model
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from .checkpoints import WeightLoader
|
||||
from .config import DiffusionArgs, DiffusionModelConfig, PipelineComponent
|
||||
from .models import AutoPipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import BasePipeline
|
||||
|
||||
|
||||
class PipelineLoader:
|
||||
"""
|
||||
Loader for diffusion pipelines.
|
||||
|
||||
Supports dynamic quantization: when quant_config specifies FP8/NVFP4,
|
||||
model is built with quantized buffers and BF16 weights are quantized
|
||||
on-the-fly during loading.
|
||||
|
||||
Example:
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/path/to/model",
|
||||
linear=LinearConfig(type="trtllm-fp8-blockwise"),
|
||||
parallel=ParallelConfig(dit_tp_size=2),
|
||||
)
|
||||
pipeline = PipelineLoader(args).load()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: Optional[DiffusionArgs] = None,
|
||||
*,
|
||||
mapping: Optional[Mapping] = None,
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Initialize model loader.
|
||||
|
||||
Args:
|
||||
args: DiffusionArgs containing all configuration (preferred)
|
||||
mapping: Tensor parallel mapping (fallback if args is None)
|
||||
device: Device to load model on (fallback if args is None)
|
||||
"""
|
||||
self.args = args
|
||||
if args is not None:
|
||||
self.mapping = args.to_mapping()
|
||||
self.device = torch.device(args.device)
|
||||
else:
|
||||
self.mapping = mapping or Mapping()
|
||||
self.device = torch.device(device)
|
||||
|
||||
def _resolve_checkpoint_dir(self, checkpoint_dir: str) -> str:
|
||||
"""Resolve checkpoint_dir to a local directory path.
|
||||
|
||||
If checkpoint_dir is an existing local path, returns it unchanged.
|
||||
Otherwise, attempts to download from HuggingFace Hub using the
|
||||
file-lock-protected ``download_hf_model`` utility (safe for
|
||||
concurrent multi-process access).
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Local path or HuggingFace Hub model ID.
|
||||
|
||||
Returns:
|
||||
Path to local directory containing the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path cannot be resolved (invalid repo ID,
|
||||
authentication failure, offline with no cache, etc.)
|
||||
"""
|
||||
if os.path.exists(checkpoint_dir):
|
||||
return checkpoint_dir
|
||||
|
||||
revision = self.args.revision if self.args else None
|
||||
logger.info(
|
||||
f"'{checkpoint_dir}' not found locally; "
|
||||
f"attempting HuggingFace Hub download (revision={revision})"
|
||||
)
|
||||
try:
|
||||
local_dir = download_hf_model(checkpoint_dir, revision=revision)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not resolve '{checkpoint_dir}' as a local path or "
|
||||
f"HuggingFace Hub model ID: {e}"
|
||||
) from e
|
||||
return str(local_dir)
|
||||
|
||||
def load(
|
||||
self,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
) -> "BasePipeline":
|
||||
"""
|
||||
Load a diffusion pipeline with optional dynamic quantization.
|
||||
|
||||
Flow:
|
||||
1. Resolve checkpoint_dir (local path or HuggingFace Hub model ID)
|
||||
2. Load config via DiffusionModelConfig.from_pretrained()
|
||||
3. Create pipeline via AutoPipeline.from_config() with MetaInit
|
||||
4. Load transformer weights via pipeline.load_weights()
|
||||
5. Load auxiliary components (VAE, text_encoder) via diffusers
|
||||
6. Call pipeline.post_load_weights()
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Local path or HF Hub model ID (uses args.checkpoint_path if not provided)
|
||||
|
||||
Returns:
|
||||
Loaded pipeline (WanPipeline, FluxPipeline, etc.) - type auto-detected
|
||||
"""
|
||||
# Resolve checkpoint_dir
|
||||
checkpoint_dir = checkpoint_dir or (self.args.checkpoint_path if self.args else None)
|
||||
if not checkpoint_dir:
|
||||
raise ValueError("checkpoint_dir must be provided or set in DiffusionArgs")
|
||||
checkpoint_dir = self._resolve_checkpoint_dir(str(checkpoint_dir))
|
||||
|
||||
# Get loading options from args
|
||||
skip_components = self.args.skip_components if self.args else []
|
||||
|
||||
# =====================================================================
|
||||
# STEP 1: Load Config (includes quant config parsing)
|
||||
# Merge pretrained checkpoint config with user-provided DiffusionArgs
|
||||
# =====================================================================
|
||||
logger.info(f"Loading config from {checkpoint_dir}")
|
||||
config = DiffusionModelConfig.from_pretrained(
|
||||
checkpoint_dir,
|
||||
args=self.args,
|
||||
mapping=self.mapping,
|
||||
)
|
||||
|
||||
# Log quantization settings
|
||||
if config.quant_config and config.quant_config.quant_algo:
|
||||
logger.info(f"Quantization: {config.quant_config.quant_algo.name}")
|
||||
logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}")
|
||||
|
||||
# =====================================================================
|
||||
# STEP 2: Create Pipeline with MetaInit
|
||||
# Pipeline type is auto-detected from model_index.json
|
||||
# - Meta tensors (no GPU memory until materialization)
|
||||
# - If quant_config specifies FP8, Linear layers have FP8 weight buffers
|
||||
# =====================================================================
|
||||
logger.info("Creating pipeline with MetaInitMode")
|
||||
with MetaInitMode():
|
||||
pipeline = AutoPipeline.from_config(config, checkpoint_dir)
|
||||
|
||||
# Convert meta tensors to CUDA tensors
|
||||
self._materialize_meta_tensors(pipeline)
|
||||
pipeline.to(self.device)
|
||||
|
||||
# =====================================================================
|
||||
# STEP 3: Load Transformer Weights
|
||||
# If dynamic_weight_quant=True:
|
||||
# - BF16 checkpoint weights are loaded
|
||||
# - Quantized on-the-fly to FP8/NVFP4 by DynamicLinearWeightLoader
|
||||
# - Copied into model's quantized buffers
|
||||
# =====================================================================
|
||||
if pipeline.transformer is None:
|
||||
raise ValueError("Pipeline has no transformer component")
|
||||
|
||||
transformer_components = getattr(pipeline, "transformer_components", ["transformer"])
|
||||
logger.info(f"Transformer components: {transformer_components}")
|
||||
|
||||
transformer_path = os.path.join(checkpoint_dir, PipelineComponent.TRANSFORMER)
|
||||
if not os.path.exists(transformer_path):
|
||||
raise FileNotFoundError(
|
||||
f"Transformer path does not exist: {transformer_path}. "
|
||||
f"Checkpoint directory must contain a 'transformer' subdirectory."
|
||||
)
|
||||
|
||||
weight_loader = WeightLoader(components=transformer_components)
|
||||
# TODO: accelerate the cpu loading w/ multiprocessing
|
||||
weights = weight_loader.load_weights(checkpoint_dir, self.mapping)
|
||||
|
||||
# Load weights into pipeline
|
||||
pipeline.load_weights(weights)
|
||||
|
||||
# =====================================================================
|
||||
# STEP 4: Load Standard Components (VAE, TextEncoder via diffusers)
|
||||
# These are NOT quantized - loaded as-is from checkpoint
|
||||
# =====================================================================
|
||||
pipeline.load_standard_components(checkpoint_dir, self.device, skip_components)
|
||||
|
||||
# =====================================================================
|
||||
# STEP 5: Post-load Hooks (TeaCache setup, etc.)
|
||||
# =====================================================================
|
||||
if hasattr(pipeline, "post_load_weights"):
|
||||
pipeline.post_load_weights()
|
||||
|
||||
logger.info(f"Pipeline loaded: {pipeline.__class__.__name__}")
|
||||
return pipeline
|
||||
|
||||
def _materialize_meta_tensors(self, module: torch.nn.Module) -> None:
|
||||
"""
|
||||
Convert meta tensors to CUDA tensors.
|
||||
|
||||
Meta tensors are placeholders that don't allocate GPU memory.
|
||||
After model structure is defined, we materialize them to real tensors.
|
||||
"""
|
||||
memo = {}
|
||||
|
||||
def init_meta_tensor(t: torch.Tensor) -> torch.Tensor:
|
||||
if t.device != torch.device("meta"):
|
||||
return t
|
||||
if t not in memo:
|
||||
memo[t] = torch.empty_like(t, device="cuda")
|
||||
return memo[t]
|
||||
|
||||
module._apply(init_meta_tensor)
|
||||
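`_materialize_meta_tensors` works because meta tensors carry shape and dtype but no storage. The PR materializes them through `module._apply` with a `memo` dict so shared tensors stay shared; the standalone sketch below shows the same idea with PyTorch's built-in `Module.to_empty`, purely for illustration on a CUDA machine.

```python
import torch

# Build a layer on the meta device: no memory is allocated for its parameters.
layer = torch.nn.Linear(4096, 4096, device="meta")
assert layer.weight.is_meta

# Materialize: replace each meta tensor with an uninitialized CUDA tensor of the same shape/dtype.
layer = layer.to_empty(device="cuda")
print(layer.weight.device, layer.weight.shape)  # cuda:0 torch.Size([4096, 4096])
```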
tensorrt_llm/_torch/visual_gen/pipeline_registry.py (new file, 94 lines)
@@ -0,0 +1,94 @@
"""Pipeline registry for unified config flow.

Follows: DiffusionArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline

All pipelines (Wan, Flux2, LTX2) register via @register_pipeline decorator.
"""

import json
import os
from typing import TYPE_CHECKING, Dict, Type

from tensorrt_llm.logger import logger

if TYPE_CHECKING:
    from .config import DiffusionModelConfig
    from .pipeline import BasePipeline

# Global registry: pipeline_name -> pipeline_class
PIPELINE_REGISTRY: Dict[str, Type["BasePipeline"]] = {}


def register_pipeline(name: str):
    """Register a pipeline class for AutoPipeline.

    Usage:
        @register_pipeline("WanPipeline")
        class WanPipeline(BasePipeline):
            ...
    """

    def decorator(cls: Type["BasePipeline"]) -> Type["BasePipeline"]:
        PIPELINE_REGISTRY[name] = cls
        logger.debug(f"Registered pipeline: {name} -> {cls.__name__}")
        return cls

    return decorator


class AutoPipeline:
    """Factory for creating pipelines from config."""

    @staticmethod
    def from_config(
        config: "DiffusionModelConfig",
        checkpoint_dir: str,
    ) -> "BasePipeline":
        """
        Create pipeline instance from DiffusionModelConfig.
        """
        # Detect pipeline type from model_index.json
        pipeline_type = AutoPipeline._detect_from_checkpoint(checkpoint_dir)

        if pipeline_type not in PIPELINE_REGISTRY:
            raise ValueError(
                f"Unknown pipeline: '{pipeline_type}'. "
                f"Available: {list(PIPELINE_REGISTRY.keys())}\n"
                f"Checkpoint: {checkpoint_dir}"
            )

        pipeline_class = PIPELINE_REGISTRY[pipeline_type]
        logger.info(f"AutoPipeline: Creating {pipeline_class.__name__} from {checkpoint_dir}")

        # Instantiate pipeline with DiffusionModelConfig
        return pipeline_class(config)

    @staticmethod
    def _detect_from_checkpoint(checkpoint_dir: str) -> str:
        """Detect pipeline type."""
        index_path = os.path.join(checkpoint_dir, "model_index.json")

        if os.path.exists(index_path):
            with open(index_path) as f:
                index = json.load(f)

            class_name = index.get("_class_name", "")

            if class_name in PIPELINE_REGISTRY:
                return class_name

            if "ImageToVideo" in class_name or "I2V" in class_name:
                if "Wan" in class_name:
                    return "WanImageToVideoPipeline"
            # Generic Wan (T2V)
            if "Wan" in class_name:
                return "WanPipeline"
            if "Flux" in class_name:
                return "FluxPipeline"
            if "LTX" in class_name or "Ltx" in class_name:
                return "LTX2Pipeline"

        raise ValueError(
            f"Cannot detect pipeline type for {checkpoint_dir}\n"
            f"Expected model_index.json with '_class_name' field at: {index_path}"
        )
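A hedged sketch of how the decorator and the factory interact; `MyVideoPipeline` is a made-up class used only to illustrate the lookup path, and a real pipeline would subclass `BasePipeline`.

```python
# Hypothetical registration example.
from tensorrt_llm._torch.visual_gen.pipeline_registry import AutoPipeline, register_pipeline

@register_pipeline("MyVideoPipeline")
class MyVideoPipeline:
    def __init__(self, config):
        self.config = config

# AutoPipeline.from_config() reads model_index.json under checkpoint_dir, looks up its
# "_class_name" (here assumed to be "MyVideoPipeline") in PIPELINE_REGISTRY, and
# instantiates the matching class with the DiffusionModelConfig:
# pipeline = AutoPipeline.from_config(config, "/path/to/checkpoint")
```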
tensorrt_llm/_torch/visual_gen/quantization/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Quantization support for diffusion models.
"""

from .loader import DynamicLinearWeightLoader
from .ops import quantize_fp8_blockwise, quantize_fp8_per_tensor

__all__ = [
    "DynamicLinearWeightLoader",
    "quantize_fp8_per_tensor",
    "quantize_fp8_blockwise",
]
tensorrt_llm/_torch/visual_gen/quantization/loader.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Dynamic weight quantization loader for Linear modules.
|
||||
|
||||
Wraps Linear.load_weights() to perform dynamic quantization before loading to device.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
|
||||
from tensorrt_llm._torch.visual_gen.quantization.ops import (
|
||||
quantize_fp8_blockwise,
|
||||
quantize_fp8_per_tensor,
|
||||
)
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
|
||||
class DynamicLinearWeightLoader:
|
||||
"""
|
||||
Dynamic weight quantization loader for Linear modules.
|
||||
|
||||
Wraps Linear.load_weights() to perform dynamic (load-time) quantization
|
||||
from BF16/FP16 to FP8 before loading weights to device.
|
||||
|
||||
Example:
|
||||
params_map = {'qkv_proj': ['to_q', 'to_k', 'to_v']}
|
||||
loader = DynamicLinearWeightLoader(model_config, params_map=params_map)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
weight_dicts = loader.get_linear_weights(module, name, weights)
|
||||
loader.load_linear_weights(module, name, weight_dicts)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: DiffusionModelConfig,
|
||||
params_map: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.quant_config = model_config.quant_config
|
||||
self.quant_config_dict = model_config.quant_config_dict
|
||||
self.dynamic_weight_quant = model_config.dynamic_weight_quant
|
||||
self.params_map = params_map or {}
|
||||
|
||||
# =========================================================================
|
||||
# Weight gathering methods
|
||||
# =========================================================================
|
||||
|
||||
def get_linear_weights(
|
||||
self,
|
||||
module: Linear,
|
||||
full_name: str,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""Get weights for a Linear module, auto-detecting fused weights."""
|
||||
weights_config = getattr(module, "weights_loading_config", None)
|
||||
if weights_config is not None:
|
||||
weight_mode = getattr(weights_config, "weight_mode", None)
|
||||
if weight_mode == WeightMode.FUSED_QKV_LINEAR:
|
||||
fused_names = self._get_fused_names(full_name)
|
||||
return self._get_fused_weights(full_name, weights, fused_names)
|
||||
|
||||
return self._get_vanilla_weights(full_name, weights)
|
||||
|
||||
def filter_weights(
|
||||
self, prefix: str, weights: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Filter weights by prefix and strip the prefix.
|
||||
|
||||
Example:
|
||||
prefix = 'blocks.0.attn1.to_q'
|
||||
weights = {'blocks.0.attn1.to_q.weight': ..., 'blocks.0.attn1.to_q.bias': ...}
|
||||
Returns: {'weight': ..., 'bias': ...}
|
||||
"""
|
||||
result = {}
|
||||
prefix_dot = prefix + "."
|
||||
for k, v in weights.items():
|
||||
if k.startswith(prefix_dot):
|
||||
result[k[len(prefix_dot) :]] = v
|
||||
return result
|
||||
|
||||
def _get_fused_names(self, full_name: str) -> List[str]:
|
||||
"""Get checkpoint names for a fused module from params_map."""
|
||||
for suffix, names in self.params_map.items():
|
||||
if full_name.endswith(suffix):
|
||||
return names
|
||||
raise ValueError(
|
||||
f"No params_map entry for fused module '{full_name}'. "
|
||||
f"Add mapping like {{'qkv_proj': ['to_q', 'to_k', 'to_v']}} to params_map."
|
||||
)
|
||||
|
||||
def _get_fused_weights(
|
||||
self,
|
||||
full_name: str,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
fused_names: List[str],
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""Get weights for a fused module from checkpoint."""
|
||||
parent_path = ".".join(full_name.split(".")[:-1])
|
||||
module_weights = []
|
||||
for ckpt_name in fused_names:
|
||||
ckpt_path = f"{parent_path}.{ckpt_name}" if parent_path else ckpt_name
|
||||
filtered = self.filter_weights(ckpt_path, weights)
|
||||
module_weights.append(filtered)
|
||||
return module_weights
|
||||
|
||||
def _get_vanilla_weights(
|
||||
self,
|
||||
full_name: str,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""Get weights for a standard (non-fused) Linear module."""
|
||||
fw = self.filter_weights(full_name, weights)
|
||||
return [fw] if fw else []
|
||||
|
||||
# =========================================================================
|
||||
# Quantization methods
|
||||
# =========================================================================
|
||||
|
||||
def _get_quant_algo_for_layer(self, name: str) -> Optional[QuantAlgo]:
|
||||
"""Get quantization algorithm for a specific layer."""
|
||||
if self.quant_config_dict is not None:
|
||||
layer_config = self.quant_config_dict.get(name)
|
||||
if layer_config is not None:
|
||||
return layer_config.quant_algo
|
||||
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.quant_algo
|
||||
|
||||
return None
|
||||
|
||||
def _should_dynamic_quantize(
|
||||
self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
|
||||
) -> bool:
|
||||
"""Decide if weight should be dynamically quantized at load time."""
|
||||
if not self.dynamic_weight_quant or quant_algo is None:
|
||||
return False
|
||||
|
||||
# Check if module is excluded
|
||||
if self.quant_config is not None:
|
||||
if self.quant_config.is_module_excluded_from_quantization(name):
|
||||
return False
|
||||
|
||||
weight = weight_dict.get("weight")
|
||||
if weight is None:
|
||||
return False
|
||||
|
||||
# For FP8 algorithms: quantize if weight is high precision
|
||||
if quant_algo in (QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES):
|
||||
if weight.dtype == torch.float8_e4m3fn and "weight_scale" in weight_dict:
|
||||
return False # Already quantized
|
||||
return weight.dtype in (torch.bfloat16, torch.float16, torch.float32)
|
||||
|
||||
return False
|
||||
|
||||
def _maybe_dynamic_quantize(
|
||||
self, weight_dict: Dict[str, torch.Tensor], quant_algo: Optional[QuantAlgo], name: str
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Conditionally quantize weight at load time on GPU."""
|
||||
if not self._should_dynamic_quantize(weight_dict, quant_algo, name):
|
||||
return weight_dict
|
||||
|
||||
weight = weight_dict["weight"]
|
||||
|
||||
# Move to GPU only if needed
|
||||
if weight.device.type != "cuda":
|
||||
weight = weight.cuda()
|
||||
|
||||
if quant_algo == QuantAlgo.FP8:
|
||||
qweight, scale = quantize_fp8_per_tensor(weight)
|
||||
elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
|
||||
block_size = self.quant_config.group_size if self.quant_config else 128
|
||||
qweight, scale = quantize_fp8_blockwise(weight, block_size=block_size)
|
||||
else:
|
||||
return weight_dict
|
||||
|
||||
return {**weight_dict, "weight": qweight, "weight_scale": scale}
|
||||
|
||||
def load_linear_weights(
|
||||
self, module: Linear, name: str, weight_dicts: List[Dict[str, torch.Tensor]]
|
||||
) -> None:
|
||||
"""Load weights into Linear module with optional quantization."""
|
||||
module_quant_config = getattr(module, "quant_config", None)
|
||||
if module_quant_config is not None:
|
||||
quant_algo = module_quant_config.quant_algo
|
||||
else:
|
||||
quant_algo = self._get_quant_algo_for_layer(name)
|
||||
|
||||
quantized_weight_dicts = [
|
||||
self._maybe_dynamic_quantize(wd, quant_algo, name) for wd in weight_dicts
|
||||
]
|
||||
|
||||
module.load_weights(quantized_weight_dicts)
|
||||
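To make the fused-weight path above concrete, here is a small illustration of what `filter_weights` / `_get_fused_weights` produce for a fused `qkv_proj` module; the tensors and layer names are dummies.

```python
import torch

# Checkpoint keys as they appear in a diffusers-style state dict (dummy tensors).
weights = {
    "blocks.0.attn1.to_q.weight": torch.randn(8, 8),
    "blocks.0.attn1.to_k.weight": torch.randn(8, 8),
    "blocks.0.attn1.to_v.weight": torch.randn(8, 8),
}

# For a fused module named "blocks.0.attn1.qkv_proj", _get_fused_names() returns
# ["to_q", "to_k", "to_v"] from params_map, and _get_fused_weights() builds one
# filtered dict per source tensor, each equivalent to:
prefix = "blocks.0.attn1.to_q."
filtered = {k[len(prefix):]: v for k, v in weights.items() if k.startswith(prefix)}
print(list(filtered))  # ['weight'] -> passed to Linear.load_weights() as one fused shard
```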
tensorrt_llm/_torch/visual_gen/quantization/ops.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""
Quantization operations for diffusion models.

Provides on-the-fly quantization functions for dynamic (load-time) quantization.
"""

from typing import Tuple

import torch

# FP8 E4M3 max value
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def quantize_fp8_per_tensor(weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize weight to FP8 E4M3 with per-tensor scale.

    Uses torch.ops.tensorrt_llm.quantize_e4m3_per_tensor CUDA kernel.

    Args:
        weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)

    Returns:
        Tuple of:
            - qweight: Quantized weight (FP8 E4M3), same shape as input
            - weight_scale: Dequantization scale (FP32), shape (1, 1)
    """
    qweight, scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(weight)
    # Ensure scale is float32 and has shape (1, 1) for consistency
    return qweight, scale.to(torch.float32)


def quantize_fp8_blockwise(
    weight: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize weight to FP8 E4M3 with 128x128 blockwise scales.

    This function converts BF16/FP16/FP32 weights to FP8 E4M3 format using
    per-block scale factors. The weight is divided into blocks of size
    (block_size, block_size) and each block has its own scale.

    Args:
        weight: Input weight tensor (BF16/FP16/FP32), shape (out_features, in_features)
        block_size: Block size for blockwise quantization (default: 128)

    Returns:
        Tuple of:
            - qweight: Quantized weight (FP8 E4M3), shape (out_features, in_features)
            - block_scales: Block-wise dequantization scales (FP32),
              shape (num_blocks_out, num_blocks_in)

    Note:
        - If dimensions are not divisible by block_size, the last block may be smaller
        - block_scales are dequantization scales (multiply to get back original scale)
        - This uses 128x128 block scaling compatible with Linear module's FP8_BLOCK_SCALES
    """
    out_features, in_features = weight.shape
    weight_fp32 = weight.float()

    # Calculate number of blocks
    num_blocks_out = (out_features + block_size - 1) // block_size
    num_blocks_in = (in_features + block_size - 1) // block_size

    # Initialize outputs
    qweight = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
    block_scales = torch.empty(
        (num_blocks_out, num_blocks_in), dtype=torch.float32, device=weight.device
    )

    # Quantize each block
    for i in range(num_blocks_out):
        row_start = i * block_size
        row_end = min((i + 1) * block_size, out_features)

        for j in range(num_blocks_in):
            col_start = j * block_size
            col_end = min((j + 1) * block_size, in_features)

            # Extract block
            block = weight_fp32[row_start:row_end, col_start:col_end]

            # Compute block scale
            max_val = block.abs().max()
            scale = (
                max_val / FP8_E4M3_MAX if max_val > 0 else torch.tensor(1.0, device=weight.device)
            )

            # Quantize block
            inv_scale = scale.reciprocal() if scale > 0 else torch.tensor(1.0, device=weight.device)
            qblock = (block * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)

            # Store results
            qweight[row_start:row_end, col_start:col_end] = qblock
            block_scales[i, j] = scale.to(torch.float32)

    return qweight, block_scales
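A quick sanity-check sketch for `quantize_fp8_blockwise`: dequantizing with the returned block scales should approximately reproduce the original weight. This assumes a PyTorch build with `float8_e4m3fn` support; the shapes are arbitrary.

```python
import torch
from tensorrt_llm._torch.visual_gen.quantization.ops import quantize_fp8_blockwise

w = torch.randn(256, 384, dtype=torch.bfloat16)
qw, scales = quantize_fp8_blockwise(w, block_size=128)   # scales has shape (2, 3)

# Dequantize: expand each block scale to element granularity and multiply.
expanded = scales.repeat_interleave(128, dim=0)[:256].repeat_interleave(128, dim=1)[:, :384]
w_hat = qw.float() * expanded
print((w.float() - w_hat).abs().max())  # small, bounded by FP8 E4M3 rounding error
```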
tensorrt_llm/_torch/visual_gen/teacache.py (new file, 409 lines)
@@ -0,0 +1,409 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# =============================================================================
|
||||
# Core Data Structures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
"""Context returned by model extractors for TeaCache.
|
||||
|
||||
Attributes:
|
||||
modulated_input: Timestep embedding used for cache distance calculation
|
||||
hidden_states: Input hidden states for the transformer
|
||||
encoder_hidden_states: Text/prompt embeddings
|
||||
run_transformer_blocks: Callable that executes the transformer forward pass
|
||||
postprocess: Callable that formats the output to the expected return type
|
||||
"""
|
||||
|
||||
modulated_input: torch.Tensor
|
||||
hidden_states: torch.Tensor
|
||||
encoder_hidden_states: Any = None
|
||||
run_transformer_blocks: Callable = None
|
||||
postprocess: Callable = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Extractor Registry
|
||||
# =============================================================================
|
||||
|
||||
_EXTRACTORS = {}
|
||||
|
||||
|
||||
def register_extractor(model_name, extractor_fn):
|
||||
"""Register an extractor function for a model class."""
|
||||
_EXTRACTORS[model_name] = extractor_fn
|
||||
|
||||
|
||||
def get_extractor(model_type):
|
||||
"""Get the registered extractor for a model type."""
|
||||
if model_type not in _EXTRACTORS:
|
||||
raise ValueError(
|
||||
f"TeaCache: Unknown model '{model_type}'. Available: {list(_EXTRACTORS.keys())}"
|
||||
)
|
||||
return _EXTRACTORS[model_type]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config-Based Extractor System
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractorConfig:
|
||||
"""Configuration for model-specific TeaCache extractors.
|
||||
|
||||
Only the timestep embedding logic is model-specific; all other logic is handled generically.
|
||||
|
||||
Attributes:
|
||||
model_class_name: Model class name (e.g., "LTX2VideoTransformer3DModel")
|
||||
timestep_embed_fn: Callable(module, timestep, guidance=None) -> Tensor
|
||||
timestep_param_name: Parameter name for timestep in forward() (default: "timestep")
|
||||
guidance_param_name: Parameter name for guidance if used (default: None)
|
||||
forward_params: List of parameter names (None = auto-introspect from forward signature)
|
||||
return_dict_default: Default value for return_dict parameter (default: True)
|
||||
output_model_class: Output class name for return type (default: "Transformer2DModelOutput")
|
||||
"""
|
||||
|
||||
model_class_name: str
|
||||
timestep_embed_fn: Callable
|
||||
timestep_param_name: str = "timestep"
|
||||
guidance_param_name: Optional[str] = None
|
||||
forward_params: Optional[List[str]] = None
|
||||
return_dict_default: bool = True
|
||||
output_model_class: str = "Transformer2DModelOutput"
|
||||
|
||||
|
||||
class GenericExtractor:
|
||||
"""Handles common TeaCache logic for all diffusion models.
|
||||
|
||||
Extracts forward() arguments, creates run_blocks and postprocess callbacks,
|
||||
and delegates only timestep embedding computation to model-specific logic.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ExtractorConfig):
|
||||
self.config = config
|
||||
|
||||
def _extract_forward_args(self, module: torch.nn.Module, *args, **kwargs) -> Dict:
|
||||
"""Extract and normalize forward() arguments from *args and **kwargs."""
|
||||
# Get parameter names (auto-introspect or use config)
|
||||
if self.config.forward_params is not None:
|
||||
param_names = self.config.forward_params
|
||||
else:
|
||||
# Auto-introspect forward signature
|
||||
try:
|
||||
sig = inspect.signature(module._original_forward)
|
||||
param_names = [p for p in sig.parameters if p not in ("self", "args", "kwargs")]
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not introspect forward signature: {e}")
|
||||
param_names = []
|
||||
|
||||
# Map positional args to parameter names
|
||||
extracted = {param_names[i]: arg for i, arg in enumerate(args) if i < len(param_names)}
|
||||
|
||||
# Merge kwargs (kwargs take precedence)
|
||||
extracted.update(kwargs)
|
||||
return extracted
|
||||
|
||||
def _compute_timestep_embedding(self, module: torch.nn.Module, params: Dict) -> torch.Tensor:
|
||||
"""Compute timestep embedding using configured callable."""
|
||||
timestep = params.get(self.config.timestep_param_name)
|
||||
if timestep is None:
|
||||
raise ValueError(f"Missing required parameter: {self.config.timestep_param_name}")
|
||||
|
||||
# Flatten timestep if needed (common pattern)
|
||||
timestep_flat = timestep.flatten() if timestep.ndim == 2 else timestep
|
||||
guidance = (
|
||||
params.get(self.config.guidance_param_name) if self.config.guidance_param_name else None
|
||||
)
|
||||
|
||||
# Call configured timestep embedding function
|
||||
try:
|
||||
return self.config.timestep_embed_fn(module, timestep_flat, guidance)
|
||||
except Exception as e:
|
||||
logger.error(f"Timestep embedder failed: {e}")
|
||||
# Last resort: use timestep as-is
|
||||
logger.warning("Using timestep fallback")
|
||||
return timestep_flat.unsqueeze(-1) if timestep_flat.ndim == 1 else timestep_flat
|
||||
|
||||
def __call__(self, module: torch.nn.Module, *args, **kwargs) -> CacheContext:
|
||||
"""Main extractor logic - called by TeaCacheHook.
|
||||
|
||||
Extracts forward arguments, computes timestep embedding, and creates callbacks
|
||||
for running the transformer and post-processing the output.
|
||||
"""
|
||||
# Extract forward arguments from positional and keyword args
|
||||
params = self._extract_forward_args(module, *args, **kwargs)
|
||||
|
||||
# Compute timestep embedding (used for cache distance calculation)
|
||||
t_emb = self._compute_timestep_embedding(module, params)
|
||||
return_dict = params.get("return_dict", self.config.return_dict_default)
|
||||
|
||||
def run_blocks():
|
||||
"""Execute the full transformer forward pass with original parameters."""
|
||||
ret = module._original_forward(**params)
|
||||
# Normalize output to tuple format
|
||||
if return_dict and not isinstance(ret, tuple):
|
||||
sample = ret.sample if hasattr(ret, "sample") else ret
|
||||
return (sample,) if not isinstance(sample, tuple) else sample
|
||||
return ret if isinstance(ret, tuple) else (ret,)
|
||||
|
||||
def postprocess(output):
|
||||
"""Convert cached/computed output back to expected return format."""
|
||||
if return_dict:
|
||||
if isinstance(output, tuple):
|
||||
return output
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
# For return_dict=False, unwrap single-element tuple to raw tensor
|
||||
if isinstance(output, tuple) and len(output) == 1:
|
||||
return output[0]
|
||||
# Return raw tensor as-is (TeaCacheHook always passes tensors to postprocess)
|
||||
return output
|
||||
|
||||
return CacheContext(
|
||||
modulated_input=t_emb,
|
||||
hidden_states=params.get("hidden_states"),
|
||||
encoder_hidden_states=params.get("encoder_hidden_states"),
|
||||
run_transformer_blocks=run_blocks,
|
||||
postprocess=postprocess,
|
||||
)
|
||||
|
||||
|
||||
def register_extractor_from_config(config: ExtractorConfig):
|
||||
"""Register a TeaCache extractor for a model. Call this in pipeline's load() method.
|
||||
|
||||
Example:
|
||||
register_extractor_from_config(ExtractorConfig(
|
||||
model_class_name="LTX2VideoTransformer3DModel",
|
||||
timestep_embed_fn=self._compute_ltx2_timestep_embedding,
|
||||
))
|
||||
"""
|
||||
extractor = GenericExtractor(config)
|
||||
register_extractor(config.model_class_name, extractor)
|
||||
logger.debug(f"Registered TeaCache extractor for {config.model_class_name}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TeaCache Runtime (caching hook and lifecycle management)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TeaCacheHook:
|
||||
"""Caches transformer blocks when timestep embeddings change slowly.
|
||||
|
||||
The hook monitors the relative change in timestep embeddings between steps.
|
||||
When the change is small (below threshold), it reuses the cached residual
|
||||
from the previous step instead of running the full transformer.
|
||||
|
||||
Separate cache states are maintained for conditional and unconditional branches
|
||||
when using Classifier-Free Guidance (CFG).
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
# Polynomial function to rescale embedding distances
|
||||
self.rescale_func = np.poly1d(config.coefficients)
|
||||
self.extractor_fn = None
|
||||
|
||||
# Separate cache state for conditional (pos) and unconditional (neg) branches
|
||||
self.state_pos = self._new_state()
|
||||
self.state_neg = self._new_state()
|
||||
self.stats = {"total": 0, "cached": 0}
|
||||
|
||||
def _new_state(self):
|
||||
return {"cnt": 0, "acc_dist": 0.0, "prev_input": None, "prev_residual": None}
|
||||
|
||||
def initialize(self, module):
|
||||
self.extractor_fn = get_extractor(module.__class__.__name__)
|
||||
|
||||
def reset_state(self):
|
||||
self.state_pos = self._new_state()
|
||||
self.state_neg = self._new_state()
|
||||
self.stats = {"total": 0, "cached": 0}
|
||||
|
||||
def get_stats(self):
|
||||
total = max(self.stats["total"], 1)
|
||||
cached = self.stats["cached"]
|
||||
return {
|
||||
"hit_rate": cached / total,
|
||||
"total": total,
|
||||
"cached": cached,
|
||||
# Backward compatibility
|
||||
"total_steps": total,
|
||||
"cached_steps": cached,
|
||||
"compute_steps": total - cached,
|
||||
}
|
||||
|
||||
def __call__(self, module, *args, **kwargs):
|
||||
"""Main hook called during transformer forward pass.
|
||||
|
||||
Decides whether to run the full transformer or reuse cached residual
|
||||
based on timestep embedding distance.
|
||||
"""
|
||||
# Extract context (timestep embedding, hidden states, callbacks)
|
||||
ctx = self.extractor_fn(module, *args, **kwargs)
|
||||
|
||||
# Select cache state (for CFG: separate tracking for conditional/unconditional)
|
||||
cache_branch = getattr(module, "_cache_branch", None)
|
||||
state = self.state_neg if cache_branch == "uncond" else self.state_pos
|
||||
|
||||
# Decide: compute transformer or use cache?
|
||||
should_compute = self._should_compute(state, ctx.modulated_input)
|
||||
self.stats["total"] += 1
|
||||
|
||||
if not should_compute and state["prev_residual"] is not None:
|
||||
# Cache hit: Add cached residual to skip transformer computation
|
||||
logger.debug(f"TeaCache: SKIP step {state['cnt']}")
|
||||
# For I2V: output might have fewer channels than input
|
||||
# Apply residual only to the latent channels
|
||||
if ctx.hidden_states.shape[1] != state["prev_residual"].shape[1]:
|
||||
# Extract latent channels (match output channels)
|
||||
num_output_channels = state["prev_residual"].shape[1]
|
||||
latent_channels = ctx.hidden_states[:, :num_output_channels]
|
||||
output = latent_channels + state["prev_residual"]
|
||||
else:
|
||||
output = ctx.hidden_states + state["prev_residual"]
|
||||
self.stats["cached"] += 1
|
||||
else:
|
||||
# Cache miss: Run full transformer and cache the residual
|
||||
outputs = ctx.run_transformer_blocks()
|
||||
output = outputs[0] if isinstance(outputs, tuple) else outputs
|
||||
|
||||
# Store residual (output - input) for next potential cache hit
|
||||
# For I2V: output may have fewer channels than input
|
||||
# Compute residual only on the latent channels
|
||||
if ctx.hidden_states.shape[1] != output.shape[1]:
|
||||
# Extract latent channels (match output channels)
|
||||
num_output_channels = output.shape[1]
|
||||
latent_channels = ctx.hidden_states[:, :num_output_channels]
|
||||
state["prev_residual"] = (output - latent_channels).detach()
|
||||
else:
|
||||
original = ctx.hidden_states.clone()
|
||||
state["prev_residual"] = (output - original).detach()
|
||||
|
||||
# Update state for next iteration
|
||||
state["prev_input"] = ctx.modulated_input.detach()
|
||||
state["cnt"] += 1
|
||||
|
||||
return ctx.postprocess(output)
|
||||
|
||||
def _should_compute(self, state, modulated_inp):
|
||||
"""Decide whether to compute transformer or use cached result.
|
||||
|
||||
Returns True to compute, False to use cache.
|
||||
"""
|
||||
# Warmup: Always compute first few steps to build stable cache
|
||||
if self.config.ret_steps and state["cnt"] < self.config.ret_steps:
|
||||
state["acc_dist"] = 0.0
|
||||
return True
|
||||
|
||||
# Cooldown: Always compute last few steps for quality
|
||||
if self.config.cutoff_steps and state["cnt"] >= self.config.cutoff_steps:
|
||||
return True
|
||||
|
||||
# First step: no previous input to compare
|
||||
if state["prev_input"] is None:
|
||||
return True
|
||||
|
||||
# Compute relative change in timestep embedding
|
||||
curr, prev = modulated_inp, state["prev_input"]
|
||||
|
||||
# For CFG (batch_size > 1), only compare conditional branch
|
||||
# Both branches move similarly, so one comparison is sufficient
|
||||
if modulated_inp.shape[0] > 1:
|
||||
curr, prev = modulated_inp.chunk(2)[1], prev.chunk(2)[1]
|
||||
|
||||
# Calculate relative L1 distance (normalized by magnitude)
|
||||
rel_dist = ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).cpu().item()
|
||||
|
||||
# Apply polynomial rescaling to adjust sensitivity
|
||||
# Accumulate distance (capped at 2x threshold to prevent overflow)
|
||||
rescaled = float(self.rescale_func(rel_dist))
|
||||
state["acc_dist"] = min(
|
||||
state["acc_dist"] + abs(rescaled), self.config.teacache_thresh * 2.0
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"TeaCache: step {state['cnt']} | dist {rel_dist:.2e} | acc {state['acc_dist']:.4f}"
|
||||
)
|
||||
|
||||
# Cache decision based on accumulated distance
|
||||
if state["acc_dist"] < self.config.teacache_thresh:
|
||||
# Below threshold: use cache, apply decay to distance
|
||||
state["acc_dist"] *= 0.95
|
||||
return False
|
||||
else:
|
||||
# Above threshold: compute, reset accumulated distance
|
||||
state["acc_dist"] = 0.0
|
||||
return True
|
||||
|
||||
|
||||
class TeaCacheBackend:
|
||||
"""Manages TeaCache lifecycle."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.hook = None
|
||||
|
||||
def enable(self, module):
|
||||
if self.hook is None:
|
||||
logger.info(f"TeaCache: Enabling for {module.__class__.__name__}")
|
||||
self.hook = TeaCacheHook(self.config)
|
||||
self.hook.initialize(module)
|
||||
module._original_forward = module.forward
|
||||
module.forward = lambda *args, **kwargs: self.hook(module, *args, **kwargs)
|
||||
|
||||
def disable(self, module):
|
||||
if self.hook and hasattr(module, "_original_forward"):
|
||||
module.forward = module._original_forward
|
||||
self.hook = None
|
||||
|
||||
def refresh(self, num_inference_steps):
|
||||
"""Reset TeaCache state for a new generation.
|
||||
|
||||
Sets warmup/cutoff steps based on total inference steps:
|
||||
- Warmup steps: Always compute to build stable cache
|
||||
- Cutoff steps: Always compute for quality at the end
|
||||
- Middle steps: Use caching based on distance threshold
|
||||
|
||||
Args:
|
||||
num_inference_steps: Total number of denoising steps
|
||||
"""
|
||||
if not self.hook:
|
||||
return
|
||||
|
||||
# Reset cache state (clears previous residuals and counters)
|
||||
self.hook.reset_state()
|
||||
|
||||
# Configure warmup and cutoff based on mode
|
||||
if self.config.use_ret_steps:
|
||||
# Aggressive warmup: 5 steps to stabilize cache
|
||||
self.config.ret_steps = 5
|
||||
self.config.cutoff_steps = num_inference_steps # No cutoff (cache until end)
|
||||
else:
|
||||
# Minimal warmup: 1 step
|
||||
self.config.ret_steps = 1
|
||||
self.config.cutoff_steps = num_inference_steps - 2 # Compute last 2 steps
|
||||
|
||||
self.config.num_steps = num_inference_steps
|
||||
|
||||
logger.info(
|
||||
f"TeaCache: {num_inference_steps} steps | "
|
||||
f"warmup: {self.config.ret_steps}, cutoff: {self.config.cutoff_steps}, "
|
||||
f"thresh: {self.config.teacache_thresh}"
|
||||
)
|
||||
|
||||
def is_enabled(self):
|
||||
return self.hook is not None
|
||||
|
||||
def get_stats(self):
|
||||
return self.hook.get_stats() if self.hook else {}
|
||||
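A condensed restatement of the caching rule implemented by `_should_compute` above, stripped of warmup/cutoff handling and CFG branch selection; the threshold and polynomial coefficients below are illustrative stand-ins, not the tuned values shipped with any model.

```python
import numpy as np
import torch

teacache_thresh = 0.08                    # illustrative threshold
rescale = np.poly1d([1.0, 0.0])           # identity polynomial, stands in for tuned coefficients
acc_dist, prev = 0.0, None

def should_compute(t_emb: torch.Tensor) -> bool:
    """Return True to run the transformer, False to reuse the cached residual."""
    global acc_dist, prev
    if prev is None:
        prev = t_emb.detach()
        return True
    # Relative L1 distance between consecutive timestep embeddings.
    rel = ((t_emb - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()
    acc_dist = min(acc_dist + abs(float(rescale(rel))), 2.0 * teacache_thresh)
    prev = t_emb.detach()
    if acc_dist < teacache_thresh:
        acc_dist *= 0.95                  # decay while skipping
        return False
    acc_dist = 0.0                        # reset after a real compute
    return True
```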
tensorrt_llm/_torch/visual_gen/utils.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Utility functions for visual generation pipelines."""

import torch


@torch.compile
def postprocess_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor:
    """Post-process video tensor from VAE decoder output to final format.

    This is a more efficient implementation than using VideoProcessor for single-batch cases,
    as it avoids loop overhead and processes the entire batch with vectorized operations.

    Args:
        video: Video tensor in (B, C, T, H, W) format from VAE decoder
        remove_batch_dim: Whether to remove batch dimension. Default True for typical
            single-batch video generation.

    Returns:
        Post-processed video tensor:
            - If remove_batch_dim=True: (T, H, W, C) uint8 tensor
            - If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor

    Note:
        Assumes video values are in [-1, 1] range (standard VAE decoder output).
    """
    # Convert to (B, T, H, W, C) format
    video = video.permute(0, 2, 3, 4, 1)  # (B, C, T, H, W) -> (B, T, H, W, C)

    # Normalize to [0, 1] range
    video = (video / 2 + 0.5).clamp(0, 1)

    # Convert to uint8
    video = (video * 255).round().to(torch.uint8)

    # Remove batch dimension if requested
    if remove_batch_dim:
        video = video[0]  # (B, T, H, W, C) -> (T, H, W, C)

    return video
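A minimal usage sketch for `postprocess_video_tensor`; the input is a fake VAE output, and writing the resulting frames out (for example with the `av` package used by the scripts under `examples/visual_gen`) is left to the caller.

```python
import torch
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor

decoded = torch.rand(1, 3, 5, 64, 64) * 2 - 1   # fake VAE output in [-1, 1], (B, C, T, H, W)
frames = postprocess_video_tensor(decoded)       # (5, 64, 64, 3), dtype=torch.uint8
print(frames.shape, frames.dtype)
```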
tensorrt_llm/commands/serve.py
@@ -19,11 +19,13 @@ from tensorrt_llm import LLM as PyTorchLLM
|
||||
from tensorrt_llm import MultimodalEncoder
|
||||
from tensorrt_llm._tensorrt_engine import LLM
|
||||
from tensorrt_llm._utils import mpi_rank
|
||||
from tensorrt_llm.commands.utils import (get_is_diffusion_model,
|
||||
get_visual_gen_model_type)
|
||||
from tensorrt_llm.executor.utils import LlmLauncherEnvs
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
|
||||
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
|
||||
DynamicBatchConfig, KvCacheConfig,
|
||||
SchedulerConfig)
|
||||
SchedulerConfig, VisualGen)
|
||||
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
|
||||
MetadataServerConfig, ServerRole,
|
||||
extract_disagg_cluster_config,
|
||||
@ -217,7 +219,7 @@ def launch_server(
|
||||
f"{backend} is not a known backend, check help for available options.",
|
||||
param_hint="backend")
|
||||
|
||||
server = OpenAIServer(llm=llm,
|
||||
server = OpenAIServer(generator=llm,
|
||||
model=model,
|
||||
tool_parser=tool_parser,
|
||||
server_role=server_role,
|
||||
@ -361,7 +363,7 @@ def launch_mm_encoder_server(
|
||||
encoder_args.pop("build_config")
|
||||
mm_encoder = MultimodalEncoder(**encoder_args)
|
||||
|
||||
server = OpenAIServer(llm=mm_encoder,
|
||||
server = OpenAIServer(generator=mm_encoder,
|
||||
model=model,
|
||||
server_role=ServerRole.MM_ENCODER,
|
||||
metadata_server_cfg=metadata_server_cfg,
|
||||
@ -369,6 +371,45 @@ def launch_mm_encoder_server(
|
||||
asyncio.run(server(host, port))
|
||||
|
||||
|
||||
def launch_visual_gen_server(
|
||||
host: str,
|
||||
port: int,
|
||||
visual_gen_config: dict,
|
||||
metadata_server_cfg: Optional[MetadataServerConfig] = None,
|
||||
):
|
||||
"""Launch a VISUAL_GEN model server for image/video generation.
|
||||
|
||||
Args:
|
||||
host: Server hostname.
|
||||
port: Server port.
|
||||
visual_gen_config: Arguments for VISUAL_GEN model initialization.
|
||||
metadata_server_cfg: Optional metadata server configuration.
|
||||
"""
|
||||
model = visual_gen_config["model"]
|
||||
logger.info(f"Initializing VisualGen ({model})")
|
||||
|
||||
n_workers = 1
|
||||
parallel_config = visual_gen_config.get("parallel", {})
|
||||
if parallel_config:
|
||||
n_workers = parallel_config.get(
|
||||
"dit_cfg_size", 1) * parallel_config.get("dit_ulysses_size", 1)
|
||||
logger.info(f"World size: {n_workers}")
|
||||
logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}")
|
||||
logger.info(
|
||||
f"Ulysses size: {parallel_config.get('dit_ulysses_size', 1)}")
|
||||
|
||||
visual_gen_model = VisualGen(model_path=model,
|
||||
n_workers=n_workers,
|
||||
diffusion_config=visual_gen_config)
|
||||
|
||||
server = OpenAIServer(generator=visual_gen_model,
|
||||
model=model,
|
||||
server_role=ServerRole.VISUAL_GEN,
|
||||
metadata_server_cfg=metadata_server_cfg,
|
||||
tool_parser=None)
|
||||
asyncio.run(server(host, port))
|
||||
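To make the world-size computation above concrete, this is a sketch of the dictionary that reaches `launch_visual_gen_server` when the server is started with a YAML file via `--extra_visual_gen_options`. The key names mirror the code above; the model path and parallel sizes are placeholders, and `trtllm-serve` as the entry point is an assumption based on the existing CLI.

```python
# Hypothetical sketch: roughly what `trtllm-serve <model> --extra_visual_gen_options cfg.yaml`
# assembles and hands to launch_visual_gen_server().
visual_gen_config = {
    "model": "/llm-models/Wan2.1-T2V-1.3B-Diffusers",
    "model_type": "wan2",
    "parallel": {"dit_cfg_size": 2, "dit_ulysses_size": 2},  # world size = 2 * 2 = 4 ranks
}
# launch_visual_gen_server("0.0.0.0", 8000, visual_gen_config)
```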
|
||||
|
||||
class ChoiceWithAlias(click.Choice):
|
||||
|
||||
def __init__(self,
|
||||
@ -600,6 +641,12 @@ class ChoiceWithAlias(click.Choice):
|
||||
default=False,
|
||||
help="Run gRPC server instead of OpenAI HTTP server. "
|
||||
"gRPC server accepts pre-tokenized requests and returns raw token IDs.")
|
||||
@click.option("--extra_visual_gen_options",
|
||||
type=str,
|
||||
default=None,
|
||||
help=help_info_with_stability_tag(
|
||||
"Path to a YAML file with extra VISUAL_GEN model options.",
|
||||
"prototype"))
|
||||
def serve(
|
||||
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
|
||||
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
|
||||
@ -616,8 +663,8 @@ def serve(
|
||||
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
|
||||
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
|
||||
custom_module_dirs: list[Path], chat_template: Optional[str],
|
||||
grpc: bool):
|
||||
"""Running an OpenAI API compatible server (or gRPC server with --grpc flag)
|
||||
grpc: bool, extra_visual_gen_options: Optional[str]):
|
||||
"""Running an OpenAI API compatible server
|
||||
|
||||
MODEL: model name | HF checkpoint path | TensorRT engine path
|
||||
"""
|
||||
@ -630,93 +677,120 @@ def serve(
|
||||
logger.error(
|
||||
f"Failed to import custom module from {custom_module_dir}: {e}")
|
||||
raise e
|
||||
llm_args, _ = get_llm_args(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
custom_tokenizer=custom_tokenizer,
|
||||
backend=backend,
|
||||
max_beam_width=max_beam_width,
|
||||
max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
context_parallel_size=context_parallel_size,
|
||||
moe_expert_parallel_size=moe_expert_parallel_size,
|
||||
moe_cluster_parallel_size=moe_cluster_parallel_size,
|
||||
gpus_per_node=gpus_per_node,
|
||||
free_gpu_memory_fraction=free_gpu_memory_fraction,
|
||||
num_postprocess_workers=num_postprocess_workers,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
reasoning_parser=reasoning_parser,
|
||||
fail_fast_on_attention_window_too_large=
|
||||
fail_fast_on_attention_window_too_large,
|
||||
otlp_traces_endpoint=otlp_traces_endpoint,
|
||||
enable_chunked_prefill=enable_chunked_prefill)
|
||||
|
||||
llm_args_extra_dict = {}
|
||||
if extra_llm_api_options is not None:
|
||||
with open(extra_llm_api_options, 'r') as f:
|
||||
llm_args_extra_dict = yaml.safe_load(f)
|
||||
llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict)
|
||||
def _serve_llm():
|
||||
nonlocal server_role
|
||||
llm_args, _ = get_llm_args(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
custom_tokenizer=custom_tokenizer,
|
||||
backend=backend,
|
||||
max_beam_width=max_beam_width,
|
||||
max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
context_parallel_size=context_parallel_size,
|
||||
moe_expert_parallel_size=moe_expert_parallel_size,
|
||||
moe_cluster_parallel_size=moe_cluster_parallel_size,
|
||||
gpus_per_node=gpus_per_node,
|
||||
free_gpu_memory_fraction=free_gpu_memory_fraction,
|
||||
num_postprocess_workers=num_postprocess_workers,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
reasoning_parser=reasoning_parser,
|
||||
fail_fast_on_attention_window_too_large=
|
||||
fail_fast_on_attention_window_too_large,
|
||||
otlp_traces_endpoint=otlp_traces_endpoint,
|
||||
enable_chunked_prefill=enable_chunked_prefill)
|
||||
|
||||
metadata_server_cfg = parse_metadata_server_config_file(
|
||||
metadata_server_config_file)
|
||||
llm_args_extra_dict = {}
|
||||
if extra_llm_api_options is not None:
|
||||
with open(extra_llm_api_options, 'r') as f:
|
||||
llm_args_extra_dict = yaml.safe_load(f)
|
||||
llm_args = update_llm_args_with_extra_dict(llm_args,
|
||||
llm_args_extra_dict)
|
||||
|
||||
# Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
|
||||
# but disagg_cluster_uri takes precedence over cluster uri in config file
|
||||
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
|
||||
if disagg_cluster_config:
|
||||
disagg_cluster_config = extract_disagg_cluster_config(
|
||||
disagg_cluster_config, disagg_cluster_uri)
|
||||
elif disagg_cluster_uri:
|
||||
disagg_cluster_config = DisaggClusterConfig(
|
||||
cluster_uri=disagg_cluster_uri)
|
||||
metadata_server_cfg = parse_metadata_server_config_file(
|
||||
metadata_server_config_file)
|
||||
|
||||
if metadata_server_cfg is not None or disagg_cluster_config is not None:
|
||||
assert (
|
||||
server_role is not None
|
||||
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
|
||||
try:
|
||||
server_role = ServerRole[server_role.upper()]
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid server role: {server_role}. " \
|
||||
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
|
||||
# Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
|
||||
# but disagg_cluster_uri takes precedence over cluster uri in config file
|
||||
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
|
||||
if disagg_cluster_config:
|
||||
disagg_cluster_config = extract_disagg_cluster_config(
|
||||
disagg_cluster_config, disagg_cluster_uri)
|
||||
elif disagg_cluster_uri:
|
||||
disagg_cluster_config = DisaggClusterConfig(
|
||||
cluster_uri=disagg_cluster_uri)
|
||||
|
||||
# Parse media_io_kwargs from JSON string to dict if provided
|
||||
parsed_media_io_kwargs = None
|
||||
if media_io_kwargs is not None:
|
||||
try:
|
||||
parsed_media_io_kwargs = json.loads(media_io_kwargs)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
|
||||
if metadata_server_cfg is not None or disagg_cluster_config is not None:
|
||||
assert (
|
||||
server_role is not None
|
||||
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
|
||||
try:
|
||||
server_role = ServerRole[server_role.upper()]
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid server role: {server_role}. " \
|
||||
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
|
||||
# Parse media_io_kwargs from JSON string to dict if provided
|
||||
parsed_media_io_kwargs = None
|
||||
if media_io_kwargs is not None:
|
||||
try:
|
||||
parsed_media_io_kwargs = json.loads(media_io_kwargs)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON for media_io_kwargs: {e}")
|
||||
|
||||
multimodal_server_config = MultimodalServerConfig(
|
||||
media_io_kwargs=parsed_media_io_kwargs)
|
||||
multimodal_server_config = MultimodalServerConfig(
|
||||
media_io_kwargs=parsed_media_io_kwargs)
|
||||
|
||||
if grpc:
|
||||
# gRPC mode: launch gRPC server instead of OpenAI HTTP server
|
||||
# Check for unsupported arguments that are silently ignored in gRPC mode
|
||||
unsupported_args = {
|
||||
"tool_parser": tool_parser,
|
||||
"chat_template": chat_template,
|
||||
"metadata_server_config_file": metadata_server_config_file,
|
||||
"server_role": server_role,
|
||||
"disagg_cluster_config": disagg_cluster_config,
|
||||
if grpc:
|
||||
# gRPC mode: launch gRPC server instead of OpenAI HTTP server
|
||||
# Check for unsupported arguments that are silently ignored in gRPC mode
|
||||
unsupported_args = {
|
||||
"tool_parser": tool_parser,
|
||||
"chat_template": chat_template,
|
||||
"metadata_server_config_file": metadata_server_config_file,
|
||||
"server_role": server_role,
|
||||
"disagg_cluster_config": disagg_cluster_config,
|
||||
}
|
||||
for name, value in unsupported_args.items():
|
||||
if value is not None:
|
||||
raise ValueError(
|
||||
f"Argument '{name}' is not supported when running in gRPC mode. "
|
||||
f"The gRPC server is designed for use with external routers that handle "
|
||||
f"these features (e.g., tool parsing, chat templates).")
|
||||
launch_grpc_server(host, port, llm_args)
|
||||
else:
|
||||
# Default: launch OpenAI HTTP server
|
||||
launch_server(host, port, llm_args, tool_parser, chat_template,
|
||||
metadata_server_cfg, server_role,
|
||||
disagg_cluster_config, multimodal_server_config)
|
||||
|
||||
def _serve_visual_gen():
|
||||
visual_gen_config = {
|
||||
"model": model,
|
||||
"model_type": get_visual_gen_model_type(model),
|
||||
}
|
||||
for name, value in unsupported_args.items():
|
||||
if value is not None:
|
||||
raise ValueError(
|
||||
f"Argument '{name}' is not supported when running in gRPC mode. "
|
||||
f"The gRPC server is designed for use with external routers that handle "
|
||||
f"these features (e.g., tool parsing, chat templates).")
|
||||
launch_grpc_server(host, port, llm_args)
|
||||
|
||||
visual_gen_extra_args = {}
|
||||
if extra_visual_gen_options is not None:
|
||||
with open(extra_visual_gen_options, 'r') as f:
|
||||
visual_gen_extra_args = yaml.safe_load(f)
|
||||
|
||||
visual_gen_config.update(visual_gen_extra_args)
|
||||
|
||||
metadata_server_cfg = parse_metadata_server_config_file(
|
||||
metadata_server_config_file)
|
||||
|
||||
launch_visual_gen_server(host, port, visual_gen_config,
|
||||
metadata_server_cfg)
|
||||
|
||||
if get_is_diffusion_model(model):
|
||||
_serve_visual_gen()
|
||||
else:
|
||||
# Default: launch OpenAI HTTP server
|
||||
launch_server(host, port, llm_args, tool_parser, chat_template,
|
||||
metadata_server_cfg, server_role, disagg_cluster_config,
|
||||
multimodal_server_config)
|
||||
_serve_llm()
|
||||
|
||||
|
||||
@click.command("mm_embedding_serve")
|
||||
|
||||
tensorrt_llm/commands/utils.py (new file, 132 lines)
@@ -0,0 +1,132 @@
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/030496eb06472f76fcb11de53d93f10cefb4604f/python/sglang/cli/utils.py#L27
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from tensorrt_llm.llmapi.utils import download_hf_partial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _maybe_download_model(
|
||||
model_name_or_path: str, local_dir: str | None = None, download: bool = True
|
||||
) -> str:
|
||||
"""Resolve a model path. If it's a local directory, return it.
|
||||
|
||||
If it's a Hugging Face Hub ID, download only the config file
|
||||
(`model_index.json` or `config.json`) and return its directory.
|
||||
|
||||
Args:
|
||||
model_name_or_path: Local path or Hugging Face Hub model ID
|
||||
local_dir: Local directory to save the downloaded file (if any)
|
||||
download: Whether to download from Hugging Face Hub when needed
|
||||
|
||||
Returns:
|
||||
Local directory path that contains the downloaded config file, or the original local directory.
|
||||
"""
|
||||
if os.path.exists(model_name_or_path):
|
||||
logger.info("Model already exists locally")
|
||||
return model_name_or_path
|
||||
|
||||
if not download:
|
||||
return model_name_or_path
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Downloading model_index.json from HF Hub for %s...",
|
||||
model_name_or_path,
|
||||
)
|
||||
file_path = download_hf_partial(
|
||||
model=model_name_or_path,
|
||||
allow_patterns=["model_index.json", "config.json"],
|
||||
)
|
||||
logger.info("Downloaded to %s", file_path)
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
(
|
||||
"Could not find model locally at %s and failed to download "
|
||||
"model_index.json/config.json from HF Hub: %s"
|
||||
)
|
||||
% (model_name_or_path, e)
|
||||
) from e
|
||||
|
||||
|
||||
# Copied and adapted from hf_diffusers_utils.py
|
||||
def is_diffusers_model_path(model_path: str) -> bool:
|
||||
"""Verify if the model directory contains a valid diffusers configuration.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model directory
|
||||
|
||||
Returns:
True if the directory contains a valid diffusers pipeline config, False otherwise.
|
||||
"""
|
||||
# Prefer model_index.json which indicates a diffusers pipeline
|
||||
config_path = os.path.join(model_path, "model_index.json")
|
||||
if not os.path.exists(config_path):
|
||||
return False
|
||||
|
||||
# Load the config
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Verify diffusers version exists
|
||||
if "_diffusers_version" not in config:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_is_diffusion_model(model_path: str):
|
||||
model_path = _maybe_download_model(model_path)
|
||||
is_diffusion_model = is_diffusers_model_path(model_path)
|
||||
if is_diffusion_model:
|
||||
logger.info("Diffusion model detected")
|
||||
return is_diffusion_model
|
||||
|
||||
|
||||
def get_model_path(extra_argv):
|
||||
# Find the model_path argument
|
||||
model_path = None
|
||||
for i, arg in enumerate(extra_argv):
|
||||
if arg == "--model-path":
|
||||
if i + 1 < len(extra_argv):
|
||||
model_path = extra_argv[i + 1]
|
||||
break
|
||||
elif arg.startswith("--model-path="):
|
||||
model_path = arg.split("=", 1)[1]
|
||||
break
|
||||
|
||||
if model_path is None:
|
||||
# Fallback for --help or other cases where model-path is not provided
|
||||
if any(h in extra_argv for h in ["-h", "--help"]):
|
||||
raise Exception(
|
||||
"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
|
||||
"This command can launch either a standard language model server or a diffusion model server.\n"
|
||||
"The server type is determined by the model path.\n"
|
||||
"For specific arguments, please provide a model_path."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"Error: --model-path is required. Please provide the path to the model."
|
||||
)
|
||||
return model_path
|
||||
|
||||
|
||||
VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE = {
|
||||
"FLUX.2": "flux2",
|
||||
"LTX-2": "ltx2",
|
||||
"Wan2": "wan2",
|
||||
}
|
||||
|
||||
|
||||
def get_visual_gen_model_type(model_path: str):
|
||||
for partial_model_name, model_type in VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.items():
|
||||
if partial_model_name.lower() in model_path.lower():
|
||||
return model_type
|
||||
|
||||
raise ValueError(
f"Unknown VISUAL_GEN model type for model path: {model_path}, "
f"available models: {list(VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.keys())}"
)
|
||||
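Since the lookup is a plain case-insensitive substring match on the path, the expected behavior can be sketched as below (paths are illustrative; the module path `tensorrt_llm.commands.utils` is taken from this diff):

```python
# Illustrative only: paths are made up; matching is case-insensitive substring.
from tensorrt_llm.commands.utils import get_visual_gen_model_type

assert get_visual_gen_model_type("/llm-models/Wan2.1-T2V-1.3B-Diffusers") == "wan2"
assert get_visual_gen_model_type("/llm-models/FLUX.2-dev") == "flux2"
assert get_visual_gen_model_type("/llm-models/LTX-2") == "ltx2"
# Any other path raises ValueError listing the supported name fragments.
```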
@ -83,8 +83,7 @@ class ZeroMqQueue:
|
||||
"Server and client should not receive HMAC key when encryption is disabled"
|
||||
)
|
||||
|
||||
if (socket_type == zmq.PAIR and self.is_server
|
||||
) or socket_type == zmq.PULL or socket_type == zmq.ROUTER:
|
||||
if self.should_bind_socket():
|
||||
self.socket.bind(
|
||||
self.address_endpoint
|
||||
) # Binds to the address and occupy a port immediately
|
||||
@ -101,6 +100,31 @@ class ZeroMqQueue:
|
||||
|
||||
self.address = (self.address_endpoint, self.hmac_key)
|
||||
|
||||
def should_bind_socket(self) -> bool:
|
||||
"""
|
||||
Determine if socket should bind vs connect based on type and role.
|
||||
|
||||
ZMQ binding conventions:
|
||||
- PAIR: server binds, client connects (1-to-1 bidirectional)
|
||||
- PULL: server binds to receive from multiple PUSH sockets
|
||||
- PUSH: server binds when acting as message source
|
||||
- ROUTER: always binds to handle multiple clients
|
||||
|
||||
Returns:
|
||||
True if socket should bind, False if it should connect
|
||||
"""
|
||||
# Server binds for PAIR, PULL, PUSH patterns
|
||||
if self.is_server and self.socket_type in (zmq.PAIR, zmq.PULL,
|
||||
zmq.PUSH):
|
||||
return True
|
||||
|
||||
# ROUTER always binds (multi-client pattern)
|
||||
if self.socket_type == zmq.ROUTER:
|
||||
return True
|
||||
|
||||
# Client connects for all other cases
|
||||
return False
|
||||
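For readers unfamiliar with the convention documented above, here is a minimal plain-pyzmq sketch (not TensorRT-LLM code) of the PUSH/PULL pattern this helper encodes: the server side binds, the worker side connects, which is how `DiffusionRemoteClient` wires its request queue later in this PR.

```python
# Illustrative pyzmq sketch of the bind/connect convention:
# request path is PUSH (server binds) -> PULL (worker connects).
import zmq

ctx = zmq.Context.instance()

# Server side: bind the PUSH socket that fans requests out to workers.
requests_out = ctx.socket(zmq.PUSH)
requests_out.bind("tcp://0.0.0.0:5555")

# Worker side: connect a PULL socket to receive requests.
requests_in = ctx.socket(zmq.PULL)
requests_in.connect("tcp://127.0.0.1:5555")

requests_out.send_pyobj({"request_id": 0})
print(requests_in.recv_pyobj())
```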
|
||||
def setup_lazily(self):
|
||||
# Early return if setup is already done
|
||||
if self._setup_done:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapt from
|
||||
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Sequence, Union
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
@ -85,3 +85,80 @@ def prompt_inputs(inputs: PromptInputs, ) -> Union[TextPrompt, TokensPrompt]:
|
||||
f"Invalid type of inputs for llm.generate: {type(inputs)}")
|
||||
|
||||
return prompt_inputs
|
||||
|
||||
|
||||
class VisualGenTextPrompt(TypedDict):
|
||||
prompt: str
|
||||
negative_prompt: NotRequired[str]
|
||||
|
||||
|
||||
class VisualGenTokensPrompt(TypedDict):
|
||||
prompt_token_ids: List[int]
|
||||
negative_prompt_token_ids: NotRequired[List[int]]
|
||||
|
||||
|
||||
VisualGenPromptInputs = Union[
|
||||
str,
|
||||
List[int],
|
||||
VisualGenTextPrompt,
|
||||
VisualGenTokensPrompt,
|
||||
]
|
||||
|
||||
VisualGenInputs = Union[
|
||||
VisualGenPromptInputs,
|
||||
Sequence[VisualGenPromptInputs],
|
||||
]
|
||||
|
||||
|
||||
def visual_gen_inputs(
|
||||
inputs: "VisualGenPromptInputs",
|
||||
) -> Union["VisualGenTextPrompt", "VisualGenTokensPrompt"]:
|
||||
# str -> text prompt
|
||||
if isinstance(inputs, str):
|
||||
return VisualGenTextPrompt(prompt=inputs)
|
||||
|
||||
# list[int] -> token prompt
|
||||
if isinstance(inputs, list):
|
||||
if len(inputs) == 0:
|
||||
raise ValueError("`inputs` token list cannot be empty.")
|
||||
if not all(isinstance(t, int) for t in inputs):
|
||||
raise TypeError(
|
||||
"`inputs` list must contain only ints when used as token IDs.")
|
||||
return VisualGenTokensPrompt(prompt_token_ids=inputs)
|
||||
|
||||
# dict form
|
||||
if isinstance(inputs, dict):
|
||||
has_prompt = "prompt" in inputs
|
||||
has_prompt_token_ids = "prompt_token_ids" in inputs
|
||||
|
||||
if has_prompt == has_prompt_token_ids:
|
||||
raise ValueError(
|
||||
"VisualGen prompt dict must contain exactly one of "
|
||||
"`prompt` or `prompt_token_ids`.")
|
||||
|
||||
if has_prompt:
|
||||
prompt = inputs.get("prompt")
|
||||
if not isinstance(prompt, str) or prompt == "":
|
||||
raise TypeError("`prompt` must be a non-empty string.")
|
||||
if "negative_prompt" in inputs and not isinstance(
|
||||
inputs["negative_prompt"], str):
|
||||
raise TypeError("`negative_prompt` must be a string.")
|
||||
return inputs # VisualGenTextPrompt
|
||||
|
||||
token_ids = inputs.get("prompt_token_ids")
|
||||
if not isinstance(token_ids, list) or len(token_ids) == 0:
|
||||
raise TypeError("`prompt_token_ids` must be a non-empty list[int].")
|
||||
if not all(isinstance(t, int) for t in token_ids):
|
||||
raise TypeError("`prompt_token_ids` must contain only ints.")
|
||||
if "negative_prompt_token_ids" in inputs:
|
||||
neg_ids = inputs["negative_prompt_token_ids"]
|
||||
if not isinstance(neg_ids, list) or not all(
|
||||
isinstance(t, int) for t in neg_ids):
|
||||
raise TypeError(
|
||||
"`negative_prompt_token_ids` must be a list[int].")
|
||||
return inputs # VisualGenTokensPrompt
|
||||
|
||||
raise TypeError(
|
||||
"Invalid `inputs` for VisualGen.generate. "
|
||||
"Expected one of: str, list[int], VisualGenTextPrompt, VisualGenTokensPrompt."
|
||||
)
|
||||
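As a quick reference, these are the forms accepted by the normalizer above (token IDs below are placeholders rather than real tokenizer output; the module path follows the file shown in this hunk):

```python
from tensorrt_llm.inputs.data import visual_gen_inputs

visual_gen_inputs("A cute cat playing piano")                      # -> VisualGenTextPrompt
visual_gen_inputs({"prompt": "A cute cat playing piano",
                   "negative_prompt": "blurry, low quality"})      # -> VisualGenTextPrompt
visual_gen_inputs([101, 2023, 2003, 102])                          # -> VisualGenTokensPrompt
visual_gen_inputs({"prompt_token_ids": [101, 2023, 2003, 102]})    # -> VisualGenTokensPrompt
```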
|
||||
@ -22,10 +22,13 @@ from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
|
||||
QuantConfig)
|
||||
from .mm_encoder import MultimodalEncoder
|
||||
from .mpi_session import MpiCommSession
|
||||
from .visual_gen import VisualGen, VisualGenParams
|
||||
|
||||
__all__ = [
|
||||
'LLM',
|
||||
'AsyncLLM',
|
||||
'VisualGen',
|
||||
'VisualGenParams',
|
||||
'MultimodalEncoder',
|
||||
'CompletionOutput',
|
||||
'RequestOutput',
|
||||
|
||||
@ -24,6 +24,7 @@ class ServerRole(IntEnum):
|
||||
CONTEXT = 0
|
||||
GENERATION = 1
|
||||
MM_ENCODER = 2
|
||||
VISUAL_GEN = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -236,18 +236,34 @@ def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
|
||||
return Path(hf_folder)
|
||||
|
||||
|
||||
def download_hf_pretrained_config(model: str,
|
||||
revision: Optional[str] = None) -> Path:
|
||||
def download_hf_partial(model: str,
|
||||
allow_patterns: List[str],
|
||||
revision: Optional[str] = None) -> Path:
|
||||
"""Download a partial model from HuggingFace.
|
||||
|
||||
Args:
model: The model name or path.
allow_patterns: Glob patterns selecting which files to download (e.g., ["config.json"]).
revision: The revision to use for the model.
|
||||
|
||||
Returns:
|
||||
The path to the downloaded model.
|
||||
"""
|
||||
with get_file_lock(model):
|
||||
hf_folder = snapshot_download(
|
||||
model,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
revision=revision,
|
||||
allow_patterns=["config.json"],
|
||||
allow_patterns=allow_patterns,
|
||||
tqdm_class=DisabledTqdm)
|
||||
return Path(hf_folder)
|
||||
|
||||
|
||||
def download_hf_pretrained_config(model: str,
|
||||
revision: Optional[str] = None) -> Path:
|
||||
return download_hf_partial(model, ["config.json"], revision)
|
||||
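A hedged usage sketch (the Hub ID is just an example): fetch only the small JSON configs needed to classify a checkpoint, without pulling its weights.

```python
from tensorrt_llm.llmapi.utils import download_hf_partial

local_dir = download_hf_partial(
    model="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",        # example Hub ID
    allow_patterns=["model_index.json", "config.json"],
)
print(local_dir)  # local snapshot directory containing only the matched files
```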
|
||||
|
||||
def append_docstring(docstring: str):
|
||||
''' A decorator to append a docstring to a function. '''
|
||||
|
||||
|
||||
544
tensorrt_llm/llmapi/visual_gen.py
Normal file
@ -0,0 +1,544 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import zmq
|
||||
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
|
||||
from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
|
||||
from tensorrt_llm._torch.visual_gen.output import MediaOutput
|
||||
|
||||
__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
|
||||
from tensorrt_llm.executor.ipc import ZeroMqQueue
|
||||
from tensorrt_llm.inputs.data import VisualGenInputs
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# Timeouts (seconds)
|
||||
POLL_TIMEOUT = 0.01
|
||||
AWAIT_TIMEOUT = 0.05
|
||||
THREAD_TIMEOUT = 5.0
|
||||
WORKER_TIMEOUT = 2.0
|
||||
READY_TIMEOUT = 1200 # 20 minutes for large models (Wan 2.2 with transformer_2)
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_ip_address() -> str:
|
||||
"""Get local IP address."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.connect(("10.255.255.255", 1))
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
return "127.0.0.1"
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
class DiffusionRemoteClient:
|
||||
"""Client proxy for remote DiffusionExecutor in worker processes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
n_workers: int = 1,
|
||||
diffusion_config: Optional[dict] = None,
|
||||
):
|
||||
self.model_path = str(model_path)
|
||||
self.n_workers = n_workers
|
||||
self.diffusion_config = diffusion_config
|
||||
|
||||
# Setup distributed env
|
||||
self.master_addr = "127.0.0.1"
|
||||
self.master_port = find_free_port()
|
||||
|
||||
# Setup IPC addresses
|
||||
self.host_ip = get_ip_address()
|
||||
req_port, resp_port = find_free_port(), find_free_port()
|
||||
|
||||
self.request_queue_addr = f"tcp://0.0.0.0:{req_port}"
|
||||
self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}"
|
||||
self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
|
||||
self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"
|
||||
|
||||
# IPC setup
|
||||
self.requests_ipc = None
|
||||
self.responses_ipc = None
|
||||
self.pending_requests = queue.Queue()
|
||||
self.completed_responses: Dict[int, DiffusionResponse] = {}
|
||||
|
||||
# We'll create asyncio primitives in the background thread's event loop
|
||||
self._event_loop = None
|
||||
self.response_event = None
|
||||
self.lock = None
|
||||
self.shutdown_event = threading.Event()
|
||||
self.event_loop_ready = threading.Event()
|
||||
|
||||
# Start background thread (it will create its own event loop)
|
||||
self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True)
|
||||
self.background_thread.start()
|
||||
|
||||
# Wait for the background thread to initialize the event loop
|
||||
self.event_loop_ready.wait()
|
||||
|
||||
# Launch workers
|
||||
logger.info(f"DiffusionClient: Launching {n_workers} workers")
|
||||
ctx = mp.get_context("spawn")
|
||||
self.worker_processes = []
|
||||
for rank in range(n_workers):
|
||||
p = ctx.Process(
|
||||
target=run_diffusion_worker,
|
||||
kwargs={
|
||||
"rank": rank,
|
||||
"world_size": n_workers,
|
||||
"master_addr": self.master_addr,
|
||||
"master_port": self.master_port,
|
||||
"model_path": self.model_path,
|
||||
"request_queue_addr": self.req_addr_connect,
|
||||
"response_queue_addr": self.resp_addr_connect,
|
||||
"diffusion_config": self.diffusion_config,
|
||||
},
|
||||
)
|
||||
p.start()
|
||||
self.worker_processes.append(p)
|
||||
|
||||
self._wait_ready()
|
||||
|
||||
@staticmethod
|
||||
def _close_socket(ipc_queue):
|
||||
if ipc_queue and ipc_queue.socket:
|
||||
ipc_queue.socket.setsockopt(zmq.LINGER, 0)
|
||||
ipc_queue.close()
|
||||
|
||||
def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]:
|
||||
"""Enqueue requests and return their IDs."""
|
||||
req_ids = []
|
||||
for req in requests:
|
||||
self.pending_requests.put(req)
|
||||
req_ids.append(req.request_id)
|
||||
return req_ids
|
||||
|
||||
async def await_responses(
|
||||
self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
|
||||
) -> Union[DiffusionResponse, List[DiffusionResponse]]:
|
||||
"""Wait for responses by request IDs.
|
||||
|
||||
Args:
|
||||
request_ids: Single request ID or list of request IDs to wait for
|
||||
timeout: Maximum total wait time in seconds (None = wait indefinitely)
|
||||
|
||||
Returns:
|
||||
Single response or list of responses (None if request timed out)
|
||||
"""
|
||||
is_single = isinstance(request_ids, int)
|
||||
ids = [request_ids] if is_single else request_ids
|
||||
|
||||
start_time = time.time()
|
||||
results = {}
|
||||
|
||||
while len(results) < len(ids):
|
||||
async with self.lock:
|
||||
for req_id in ids:
|
||||
if req_id in self.completed_responses:
|
||||
results[req_id] = self.completed_responses.pop(req_id)
|
||||
|
||||
# All responses collected
|
||||
if len(results) == len(ids):
|
||||
break
|
||||
|
||||
# Check if overall timeout exceeded
|
||||
if timeout is not None:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed >= timeout:
|
||||
break
|
||||
# Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter
|
||||
wait_time = min(timeout - elapsed, AWAIT_TIMEOUT)
|
||||
else:
|
||||
wait_time = AWAIT_TIMEOUT
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.response_event.wait(), timeout=wait_time)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
self.response_event.clear()
|
||||
|
||||
out = [results.get(rid) for rid in ids]
|
||||
return out[0] if is_single else out
|
||||
|
||||
def await_responses_sync(
|
||||
self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
|
||||
) -> Union[DiffusionResponse, List[DiffusionResponse]]:
|
||||
"""Sync wrapper to await responses from the main thread."""
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.await_responses(request_ids, timeout), self._event_loop
|
||||
)
|
||||
return future.result(timeout=timeout if timeout else None)
|
||||
|
||||
def _init_ipc(self) -> bool:
|
||||
"""Initialize IPC queues."""
|
||||
try:
|
||||
logger.info("DiffusionClient: Initializing IPC")
|
||||
self.requests_ipc = ZeroMqQueue(
|
||||
(self.request_queue_addr, None),
|
||||
is_server=True,
|
||||
socket_type=zmq.PUSH,
|
||||
use_hmac_encryption=False,
|
||||
)
|
||||
self.responses_ipc = ZeroMqQueue(
|
||||
(self.response_queue_addr, None),
|
||||
is_server=True,
|
||||
socket_type=zmq.PULL,
|
||||
use_hmac_encryption=False,
|
||||
)
|
||||
logger.info("DiffusionClient: IPC ready")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"DiffusionClient: IPC init failed: {e}")
|
||||
return False
|
||||
|
||||
def _send_shutdown(self):
|
||||
"""Send shutdown signal."""
|
||||
logger.info("DiffusionClient: Sending shutdown signal")
|
||||
if self.requests_ipc:
|
||||
self.requests_ipc.put(None)
|
||||
self._close_socket(self.requests_ipc)
|
||||
|
||||
def _process_requests(self):
|
||||
"""Process pending requests."""
|
||||
try:
|
||||
req = self.pending_requests.get(timeout=POLL_TIMEOUT)
|
||||
if req is None:
|
||||
self._send_shutdown()
|
||||
self.shutdown_event.set()
|
||||
return
|
||||
|
||||
logger.info(f"DiffusionClient: Sending request {req.request_id}")
|
||||
self.requests_ipc.put(req)
|
||||
except queue.Empty:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"DiffusionClient: Error sending request: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _process_responses(self):
|
||||
"""Poll and process responses."""
|
||||
try:
|
||||
if self.responses_ipc.poll(timeout=POLL_TIMEOUT):
|
||||
response = self.responses_ipc.get()
|
||||
if isinstance(response, DiffusionResponse):
|
||||
if response.request_id == -1:
|
||||
logger.info("DiffusionClient: Received READY signal")
|
||||
|
||||
# Schedule the lock acquisition and event setting in the event loop
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._store_response(response), self._event_loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"DiffusionClient: Error processing response: {e}")
|
||||
|
||||
async def _store_response(self, response: DiffusionResponse):
|
||||
"""Store response in the completed_responses dict (async helper)."""
|
||||
async with self.lock:
|
||||
self.completed_responses[response.request_id] = response
|
||||
self.response_event.set()
|
||||
|
||||
def _cleanup_ipc(self):
|
||||
"""Cleanup IPC."""
|
||||
logger.info("DiffusionClient: Cleaning up IPC")
|
||||
self._close_socket(self.requests_ipc)
|
||||
self._close_socket(self.responses_ipc)
|
||||
|
||||
def _serve_forever_thread(self):
|
||||
"""Background thread wrapper that creates and runs an event loop."""
|
||||
logger.info("DiffusionClient: Background thread started")
|
||||
|
||||
# Create a new event loop for this thread
|
||||
self._event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._event_loop)
|
||||
|
||||
# Create async primitives in this thread's event loop
|
||||
self.response_event = asyncio.Event()
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
# Signal that the event loop is ready
|
||||
self.event_loop_ready.set()
|
||||
|
||||
# Run the async serve_forever
|
||||
try:
|
||||
self._event_loop.run_until_complete(self._serve_forever())
|
||||
finally:
|
||||
self._event_loop.close()
|
||||
logger.info("DiffusionClient: Background thread stopped")
|
||||
|
||||
async def _serve_forever(self):
|
||||
"""Background thread main loop (async version)."""
|
||||
if not self._init_ipc():
|
||||
return
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
self._process_requests()
|
||||
self._process_responses()
|
||||
await asyncio.sleep(0.001) # Yield control to allow other coroutines to run
|
||||
|
||||
self._cleanup_ipc()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown client and workers."""
|
||||
logger.info("DiffusionClient: Shutting down")
|
||||
self.pending_requests.put(None)
|
||||
|
||||
self.background_thread.join(timeout=THREAD_TIMEOUT)
|
||||
if self.background_thread.is_alive():
|
||||
logger.warning("DiffusionClient: Force stopping background thread")
|
||||
self.shutdown_event.set()
|
||||
self.background_thread.join(timeout=1.0)
|
||||
|
||||
# Shutdown workers
|
||||
logger.info("DiffusionClient: Stopping workers")
|
||||
for p in self.worker_processes:
|
||||
p.join(timeout=WORKER_TIMEOUT)
|
||||
if p.is_alive():
|
||||
logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM")
|
||||
p.terminate()
|
||||
p.join(timeout=WORKER_TIMEOUT)
|
||||
if p.is_alive():
|
||||
logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL")
|
||||
p.kill()
|
||||
p.join(timeout=WORKER_TIMEOUT)
|
||||
|
||||
def _wait_ready(self, timeout: float = READY_TIMEOUT):
|
||||
"""Wait for workers to be ready (sync wrapper for async operation)."""
|
||||
logger.info("DiffusionClient: Waiting for workers")
|
||||
|
||||
# Run the async wait in the background thread's event loop
|
||||
future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
async def _wait_ready_async(self, timeout: float = READY_TIMEOUT):
|
||||
"""Wait for workers to be ready (async version)."""
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
async with self.lock:
|
||||
if -1 in self.completed_responses:
|
||||
self.completed_responses.pop(-1)
|
||||
logger.info("DiffusionClient: Workers ready")
|
||||
return
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
raise RuntimeError("DiffusionClient: Timeout waiting for workers")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
self.response_event.clear()
|
||||
|
||||
|
||||
class DiffusionGenerationResult:
|
||||
"""Future-like object for async generation."""
|
||||
|
||||
def __init__(self, request_id: int, executor: DiffusionRemoteClient):
|
||||
self.request_id = request_id
|
||||
self.executor = executor
|
||||
self._result = None
|
||||
self._finished = False
|
||||
self._error = None
|
||||
|
||||
async def result(self, timeout: Optional[float] = None) -> Any:
|
||||
"""Wait for and return result (async version).
|
||||
|
||||
Can be awaited from any async context (e.g., FastAPI background tasks).
|
||||
"""
|
||||
if self._finished:
|
||||
if self._error:
|
||||
raise RuntimeError(self._error)
|
||||
return self._result
|
||||
|
||||
# Use run_coroutine_threadsafe to execute in the background thread's event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.executor.await_responses(self.request_id, timeout=timeout),
|
||||
self.executor._event_loop,
|
||||
)
|
||||
|
||||
# Await the future in the current event loop
|
||||
response = await asyncio.wrap_future(future)
|
||||
|
||||
if response.error_msg:
|
||||
self._error = response.error_msg
|
||||
self._finished = True
|
||||
raise RuntimeError(f"Generation failed: {response.error_msg}")
|
||||
|
||||
self._result = response.output
|
||||
self._finished = True
|
||||
return self._result
|
||||
|
||||
def cancel(self):
|
||||
raise NotImplementedError("Cancel request (not yet implemented).")
|
||||
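A hedged async sketch of how a caller (for example a FastAPI background task) might consume this future; `gen` stands for a `VisualGen` instance defined later in this file and `params` for a `VisualGenParams`:

```python
async def run_one(gen, params):
    # generate_async enqueues the request and returns immediately;
    # result() awaits the worker's DiffusionResponse via the client's event loop.
    future = gen.generate_async("A cute cat playing piano", params)
    media = await future.result(timeout=None)
    return media  # MediaOutput with image/video/audio fields depending on the model
```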
|
||||
|
||||
@dataclass
|
||||
class VisualGenParams:
|
||||
"""Parameters for visual generation.
|
||||
|
||||
Attributes:
|
||||
height: Output height in pixels
|
||||
width: Output width in pixels
|
||||
num_inference_steps: Number of denoising steps
|
||||
guidance_scale: Classifier-free guidance scale
|
||||
max_sequence_length: Maximum sequence length for text encoding
|
||||
seed: Random seed for reproducibility
|
||||
|
||||
# Video-specific parameters
|
||||
num_frames: Number of video frames to generate
|
||||
frame_rate: Frame rate for video output in fps
|
||||
|
||||
# Image-specific parameters
|
||||
num_images_per_prompt: Number of images to generate per prompt (for image models)
|
||||
|
||||
# Advanced parameters
|
||||
guidance_rescale: Guidance rescale factor (for some models)
|
||||
output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images)
|
||||
"""
|
||||
|
||||
height: int = 720
|
||||
width: int = 1280
|
||||
num_inference_steps: int = 50
|
||||
guidance_scale: float = 5.0
|
||||
max_sequence_length: int = 512
|
||||
seed: int = 42
|
||||
|
||||
# Video-specific parameters
|
||||
num_frames: int = 81
|
||||
frame_rate: float = 24.0
|
||||
input_reference: Optional[str] = None
|
||||
|
||||
# Image-specific parameters
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
## Image edit parameters
|
||||
image: Optional[List[str]] = None
|
||||
mask: Optional[str] = None
|
||||
|
||||
# Advanced parameters
|
||||
guidance_rescale: float = 0.0
|
||||
output_type: str = "pt"
|
||||
|
||||
# Wan-specific parameters
|
||||
guidance_scale_2: Optional[float] = None
|
||||
boundary_ratio: Optional[float] = None
|
||||
last_image: Optional[str] = None
|
||||
|
||||
|
||||
class VisualGen:
|
||||
"""High-level API for visual generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
n_workers: int = 1,
|
||||
diffusion_config: Optional[dict] = None,
|
||||
):
|
||||
self.model_path = str(model_path)
|
||||
self.n_workers = n_workers
|
||||
self.diffusion_config = diffusion_config
|
||||
|
||||
self.executor = DiffusionRemoteClient(
|
||||
model_path=self.model_path,
|
||||
n_workers=self.n_workers,
|
||||
diffusion_config=self.diffusion_config,
|
||||
)
|
||||
self.req_counter = 0
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: VisualGenInputs,
|
||||
params: VisualGenParams,
|
||||
) -> MediaOutput:
|
||||
"""Synchronous generation. Blocks until complete.
|
||||
|
||||
Args:
inputs: Prompt input (a str, a list of token IDs, or a VisualGen prompt dict).
params: Generation parameters.
|
||||
|
||||
Returns:
|
||||
MediaOutput: Generated media with model-specific fields populated:
|
||||
- FLUX2: MediaOutput(image=torch.Tensor)
|
||||
- WAN: MediaOutput(video=torch.Tensor)
|
||||
- LTX2: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
|
||||
"""
|
||||
future = self.generate_async(
|
||||
inputs=inputs,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Use the sync wrapper to get result
|
||||
response = self.executor.await_responses_sync(future.request_id, timeout=None)
|
||||
if response.error_msg:
|
||||
raise RuntimeError(f"Generation failed: {response.error_msg}")
|
||||
return response.output
|
||||
|
||||
def generate_async(
|
||||
self,
|
||||
inputs: VisualGenInputs,
|
||||
params: VisualGenParams,
|
||||
) -> DiffusionGenerationResult:
|
||||
"""Async generation. Returns immediately with future-like object.
|
||||
|
||||
Args:
inputs: Prompt input (a str, a list of token IDs, or a VisualGen prompt dict).
params: Generation parameters.
|
||||
|
||||
Returns:
|
||||
DiffusionGenerationResult: Call result() to get output dict.
|
||||
"""
|
||||
req_id = self.req_counter
|
||||
self.req_counter += 1
|
||||
|
||||
if isinstance(inputs, dict):
|
||||
prompt = inputs.get("prompt")
|
||||
negative_prompt = inputs.get("negative_prompt", None)
|
||||
elif isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
negative_prompt = None
|
||||
else:
|
||||
# TODO: Support batch generation
|
||||
raise ValueError(f"Invalid inputs type: {type(inputs)}")
|
||||
|
||||
request = DiffusionRequest(
|
||||
request_id=req_id,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=params.height,
|
||||
width=params.width,
|
||||
num_inference_steps=params.num_inference_steps,
|
||||
guidance_scale=params.guidance_scale,
|
||||
max_sequence_length=params.max_sequence_length,
|
||||
seed=params.seed,
|
||||
num_frames=params.num_frames,
|
||||
frame_rate=params.frame_rate,
|
||||
num_images_per_prompt=params.num_images_per_prompt,
|
||||
guidance_rescale=params.guidance_rescale,
|
||||
output_type=params.output_type,
|
||||
image=params.input_reference,
|
||||
guidance_scale_2=params.guidance_scale_2,
|
||||
boundary_ratio=params.boundary_ratio,
|
||||
last_image=params.last_image,
|
||||
)
|
||||
|
||||
self.executor.enqueue_requests([request])
|
||||
return DiffusionGenerationResult(req_id, self.executor)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown executor and cleanup."""
|
||||
logger.info("VisualGen: Shutting down")
|
||||
self.executor.shutdown()
|
||||
@ -16,10 +16,8 @@ from functools import wraps as _wraps
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
|
||||
|
||||
if _mpi_disabled():
|
||||
raise RuntimeError(
|
||||
"Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
|
||||
)
|
||||
# Don't raise error on import - only when Ray functionality is actually used
|
||||
_RAY_NOT_INSTALLED_MSG = "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
@ -42,6 +40,7 @@ def remote(*args, **kwargs):
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
raise RuntimeError(
|
||||
f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
|
||||
)
|
||||
msg = f'Ray not installed, so "ray.{name}" is unavailable.'
|
||||
if _mpi_disabled():
|
||||
msg = _RAY_NOT_INSTALLED_MSG
|
||||
raise RuntimeError(msg)
|
||||
|
||||
426
tensorrt_llm/serve/media_storage.py
Normal file
@ -0,0 +1,426 @@
|
||||
#!/usr/bin/env python
|
||||
"""Media Storage for generated images and videos.
|
||||
|
||||
This module provides storage handlers for persisting generated media assets
|
||||
(videos, images) and their associated metadata.
|
||||
"""
|
||||
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
class MediaStorage:
|
||||
"""Handler for storing images and videos in various formats."""
|
||||
|
||||
@staticmethod
|
||||
def save_image(
|
||||
image: Any, output_path: str, format: Optional[str] = None, quality: int = 95
|
||||
) -> str:
|
||||
"""Save image to file.
|
||||
|
||||
Args:
|
||||
image: torch.Tensor (H, W, C) uint8
|
||||
output_path: Path to save the image
|
||||
format: Image format (png, jpg, webp). If None, infer from extension
|
||||
quality: Quality for lossy formats (1-100, higher is better)
|
||||
|
||||
Returns:
|
||||
Path where the image was saved
|
||||
"""
|
||||
# Ensure output directory exists
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Convert to PIL Image if needed
|
||||
pil_image = MediaStorage._to_pil_image(image)
|
||||
|
||||
# Determine format
|
||||
if format is None:
|
||||
ext = os.path.splitext(output_path)[1].lower()
|
||||
if ext in [".png"]:
|
||||
format = "PNG"
|
||||
elif ext in [".jpg", ".jpeg"]:
|
||||
format = "JPEG"
|
||||
elif ext in [".webp"]:
|
||||
format = "WEBP"
|
||||
else:
|
||||
logger.warning(f"Unknown image extension {ext}, defaulting to PNG")
|
||||
format = "PNG"
|
||||
output_path = output_path.rsplit(".", 1)[0] + ".png"
|
||||
|
||||
# Save image with format-specific handling
|
||||
MediaStorage._save_pil_image(pil_image, output_path, format, quality)
|
||||
|
||||
logger.info(f"Saved image to {output_path} (format={format})")
|
||||
return output_path
|
||||
|
||||
@staticmethod
|
||||
def convert_image_to_bytes(image: Any, format: str = "PNG", quality: int = 95) -> bytes:
|
||||
"""Convert image to bytes buffer.
|
||||
|
||||
Args:
|
||||
image: torch.Tensor (H, W, C) uint8
|
||||
format: Image format (PNG, JPEG, WEBP)
|
||||
quality: Quality for lossy formats (1-100)
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
pil_image = MediaStorage._to_pil_image(image)
|
||||
|
||||
# Save to bytes buffer
|
||||
buffer = BytesIO()
|
||||
MediaStorage._save_pil_image(pil_image, buffer, format, quality)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
@staticmethod
|
||||
def _to_pil_image(image: torch.Tensor) -> Image.Image:
|
||||
"""Convert torch.Tensor to PIL Image.
|
||||
|
||||
Args:
|
||||
image: torch.Tensor (H, W, C) uint8
|
||||
|
||||
Returns:
|
||||
PIL Image
|
||||
"""
|
||||
if not isinstance(image, torch.Tensor):
|
||||
raise ValueError(f"Expected torch.Tensor, got {type(image)}")
|
||||
|
||||
# Convert to numpy for PIL
|
||||
image_np = image.cpu().numpy()
|
||||
return Image.fromarray(image_np)
|
||||
|
||||
@staticmethod
|
||||
def _save_pil_image(
|
||||
pil_image: Image.Image,
|
||||
output: Any, # Can be path string or BytesIO
|
||||
format: str,
|
||||
quality: int,
|
||||
):
|
||||
"""Save PIL Image to file or buffer.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image to save
|
||||
output: Output path (str) or BytesIO buffer
|
||||
format: Image format (PNG, JPEG, WEBP)
|
||||
quality: Quality for lossy formats (1-100)
|
||||
"""
|
||||
format_upper = format.upper()
|
||||
|
||||
if format_upper in ["JPEG", "JPG"]:
|
||||
# Convert RGBA to RGB for JPEG
|
||||
if pil_image.mode in ("RGBA", "LA", "P"):
|
||||
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
if pil_image.mode == "P":
|
||||
pil_image = pil_image.convert("RGBA")
|
||||
background.paste(
|
||||
pil_image, mask=pil_image.split()[-1] if pil_image.mode == "RGBA" else None
|
||||
)
|
||||
pil_image = background
|
||||
pil_image.save(output, format="JPEG", quality=quality, optimize=True)
|
||||
elif format_upper == "WEBP":
|
||||
pil_image.save(output, format="WEBP", quality=quality)
|
||||
else: # PNG or default
|
||||
pil_image.save(output, format="PNG", optimize=True)
|
||||
|
||||
@staticmethod
|
||||
def save_video(
|
||||
video: Any,
|
||||
output_path: str,
|
||||
audio: Optional[Any] = None,
|
||||
frame_rate: float = 24.0,
|
||||
format: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Save video to file with optional audio.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch.Tensor (T, H, W, C) uint8
|
||||
output_path: Path to save the video
|
||||
audio: Optional audio as torch.Tensor
|
||||
frame_rate: Frames per second (default: 24.0)
|
||||
format: Video format (mp4, gif, png). If None, infer from extension
|
||||
|
||||
Returns:
|
||||
Path where the video was saved
|
||||
"""
|
||||
# Ensure output directory exists
|
||||
if isinstance(output_path, Path):
|
||||
output_path = str(output_path)
|
||||
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Determine format
|
||||
if format is None:
|
||||
ext = os.path.splitext(output_path)[1].lower()
|
||||
format = ext[1:] if ext else "mp4"
|
||||
|
||||
format = format.lower()
|
||||
|
||||
# Save based on format
|
||||
if format == "mp4":
|
||||
MediaStorage._save_mp4(video, audio, output_path, frame_rate)
|
||||
elif format == "gif":
|
||||
MediaStorage._save_gif(video, output_path, frame_rate)
|
||||
elif format == "png":
|
||||
MediaStorage._save_middle_frame(video, output_path)
|
||||
else:
|
||||
logger.warning(f"Unsupported video format: {format}, defaulting to mp4")
|
||||
output_path = output_path.rsplit(".", 1)[0] + ".mp4"
|
||||
MediaStorage._save_mp4(video, audio, output_path, frame_rate)
|
||||
|
||||
return output_path
|
||||
|
||||
@staticmethod
|
||||
def convert_video_to_bytes(
|
||||
video: Any, audio: Optional[Any] = None, frame_rate: float = 24.0, format: str = "mp4"
|
||||
) -> bytes:
|
||||
"""Convert video to bytes buffer.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch.Tensor (T, H, W, C) uint8
|
||||
audio: Optional audio as torch.Tensor
|
||||
frame_rate: Frames per second
|
||||
format: Video format (mp4, gif)
|
||||
|
||||
Returns:
|
||||
Video bytes
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
# Create temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp_file:
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
try:
|
||||
# Save to temporary file
|
||||
MediaStorage.save_video(video, tmp_path, audio, frame_rate, format)
|
||||
|
||||
# Read bytes
|
||||
with open(tmp_path, "rb") as f:
|
||||
video_bytes = f.read()
|
||||
|
||||
return video_bytes
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@staticmethod
|
||||
def _save_mp4(
|
||||
video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float
|
||||
) -> str:
|
||||
"""Save video with optional audio as MP4.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch.Tensor (T, H, W, C) uint8
|
||||
audio: Optional audio as torch.Tensor
|
||||
output_path: Output path for MP4
|
||||
frame_rate: Frames per second
|
||||
|
||||
Returns:
|
||||
Path where the video was saved
|
||||
"""
|
||||
try:
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
|
||||
if not isinstance(video, torch.Tensor):
|
||||
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
|
||||
|
||||
# Convert video tensor to numpy: (T, H, W, C) uint8
|
||||
video_np = video.cpu().numpy()
|
||||
num_frames, height, width, channels = video_np.shape
|
||||
|
||||
# Ensure RGB format (3 channels)
|
||||
if channels != 3:
|
||||
raise ValueError(f"Expected 3-channel RGB video, got {channels} channels")
|
||||
|
||||
# Open output container
|
||||
container = av.open(output_path, mode="w")
|
||||
|
||||
# Add video stream (H.264 codec)
|
||||
video_stream = container.add_stream("libx264", rate=int(frame_rate))
|
||||
video_stream.width = width
|
||||
video_stream.height = height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
video_stream.options = {"preset": "medium", "crf": "23"}
|
||||
|
||||
# Pre-process audio and add audio stream BEFORE any muxing.
|
||||
# All streams must be registered before the first mux() call
|
||||
# (which triggers container header writing).
|
||||
audio_stream = None
|
||||
audio_tensor = None
|
||||
audio_sample_rate = 24000 # Default sample rate
|
||||
if audio is not None:
|
||||
if not isinstance(audio, torch.Tensor):
|
||||
raise ValueError(f"Expected torch.Tensor for audio, got {type(audio)}")
|
||||
|
||||
# Prepare audio tensor: convert to (samples, channels) format
|
||||
audio_tensor = audio
|
||||
|
||||
# Handle different audio tensor dimensions
|
||||
if audio_tensor.ndim == 1:
|
||||
# Mono audio: (samples,) -> (samples, 1)
|
||||
audio_tensor = audio_tensor[:, None]
|
||||
elif audio_tensor.ndim == 2:
|
||||
# If shape[1] != 2 and shape[0] == 2, transpose to (samples, channels)
|
||||
if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
|
||||
audio_tensor = audio_tensor.T
|
||||
if audio_tensor.shape[1] > 2:
|
||||
audio_tensor = audio_tensor[:, :2]
|
||||
elif audio_tensor.ndim == 3:
|
||||
if audio_tensor.shape[0] == 1:
|
||||
audio_tensor = audio_tensor.squeeze(0)
|
||||
else:
|
||||
audio_tensor = audio_tensor[0]
|
||||
if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2:
|
||||
audio_tensor = audio_tensor.T
|
||||
if audio_tensor.shape[1] > 2:
|
||||
audio_tensor = audio_tensor[:, :2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported audio tensor shape: {audio_tensor.shape}. "
|
||||
f"Expected 1D, 2D, or 3D tensor."
|
||||
)
|
||||
|
||||
if audio_tensor.shape[1] > 2:
|
||||
audio_tensor = audio_tensor[:, :2]
|
||||
|
||||
# Convert to int16 if needed
|
||||
if audio_tensor.dtype != torch.int16:
|
||||
audio_tensor = torch.clip(audio_tensor, -1.0, 1.0)
|
||||
audio_tensor = (audio_tensor * 32767.0).to(torch.int16)
|
||||
|
||||
# Add audio stream now (before any muxing)
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
|
||||
# --- Encode video frames ---
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in video_stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush video encoder
|
||||
for packet in video_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
# --- Encode audio (after video is done) ---
|
||||
if audio_stream is not None and audio_tensor is not None:
|
||||
# Build packed int16 frame: (1, samples*channels)
|
||||
audio_np = audio_tensor.contiguous().reshape(1, -1).cpu().numpy()
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(audio_np, format="s16", layout="stereo")
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
# Use AudioResampler to convert s16→fltp (AAC's native format)
|
||||
cc = audio_stream.codec_context
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=cc.format or "fltp",
|
||||
layout=cc.layout or "stereo",
|
||||
rate=cc.sample_rate or audio_sample_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = audio_sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# Flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
# Close container
|
||||
container.close()
|
||||
|
||||
logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}")
|
||||
return output_path
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"PyAV (av) library not available. "
|
||||
"Falling back to saving middle frame as PNG. "
|
||||
"Install with: pip install av"
|
||||
)
|
||||
png_path = output_path.replace(".mp4", ".png")
|
||||
return MediaStorage._save_middle_frame(video, png_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding video with PyAV: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning("Falling back to saving middle frame as PNG.")
|
||||
png_path = output_path.replace(".mp4", ".png")
|
||||
return MediaStorage._save_middle_frame(video, png_path)
|
||||
|
||||
@staticmethod
|
||||
def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float) -> str:
|
||||
"""Save video as animated GIF.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch.Tensor (T, H, W, C) uint8
|
||||
output_path: Output path for GIF
|
||||
frame_rate: Frames per second
|
||||
|
||||
Returns:
|
||||
Path where the GIF was saved
|
||||
"""
|
||||
if not isinstance(video, torch.Tensor):
|
||||
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
|
||||
|
||||
# Convert to numpy and then to list of PIL Images
|
||||
video_np = video.cpu().numpy()
|
||||
frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])]
|
||||
|
||||
# Save as GIF
|
||||
duration_ms = int(1000 / frame_rate)
|
||||
frames[0].save(
|
||||
output_path,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=False,
|
||||
duration=duration_ms,
|
||||
loop=0,
|
||||
)
|
||||
logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)")
|
||||
return output_path
|
||||
|
||||
@staticmethod
|
||||
def _save_middle_frame(video: torch.Tensor, output_path: str) -> str:
|
||||
"""Save middle frame of video as PNG.
|
||||
|
||||
Args:
|
||||
video: Video frames as torch.Tensor (T, H, W, C) uint8
|
||||
output_path: Output path for PNG
|
||||
|
||||
Returns:
|
||||
Path where the frame was saved
|
||||
"""
|
||||
if not isinstance(video, torch.Tensor):
|
||||
raise ValueError(f"Expected torch.Tensor for video, got {type(video)}")
|
||||
|
||||
# Extract middle frame
|
||||
video_np = video.cpu().numpy()
|
||||
frame_idx = video_np.shape[0] // 2
|
||||
image = Image.fromarray(video_np[frame_idx])
|
||||
|
||||
image.save(output_path)
|
||||
logger.info(f"Saved frame {frame_idx} to {output_path}")
|
||||
return output_path
|
||||
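A small usage sketch with synthetic uint8 tensors (shapes follow the docstrings above: HWC for images, THWC for video); MP4 output needs PyAV (`pip install av`), otherwise the code falls back to a PNG of the middle frame.

```python
import torch
from tensorrt_llm.serve.media_storage import MediaStorage

image = torch.randint(0, 256, (480, 832, 3), dtype=torch.uint8)
MediaStorage.save_image(image, "frame.png")
png_bytes = MediaStorage.convert_image_to_bytes(image, format="PNG")

video = torch.randint(0, 256, (33, 480, 832, 3), dtype=torch.uint8)
MediaStorage.save_video(video, "clip.mp4", frame_rate=24.0)
```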
@ -1,12 +1,14 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/4db5176d9758b720b05460c50ace3c01026eb158/vllm/entrypoints/openai/protocol.py
|
||||
import base64
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import xgrammar
|
||||
from fastapi import UploadFile
|
||||
from openai.types.chat import ChatCompletionAssistantMessageParam
|
||||
from openai.types.chat import \
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
|
||||
@ -1120,5 +1122,218 @@ def to_llm_disaggregated_params(
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Diffusion API Protocol Classes
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ImageGenerationRequest(OpenAIBaseModel):
|
||||
"""OpenAI-compatible image generation request.
|
||||
|
||||
Follows the OpenAI Images API specification:
|
||||
https://platform.openai.com/docs/api-reference/images/create
|
||||
"""
|
||||
prompt: str
|
||||
model: Optional[str] = None
|
||||
n: int = Field(default=1, ge=1, le=10)
|
||||
output_format: Literal["png", "webp", "jpeg"] = "png"
|
||||
size: Optional[str] = Field(
|
||||
default="auto",
|
||||
description=(
|
||||
"The size of the generated images. Must be in 'WxH' format like "
|
||||
"1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
|
||||
"Use 'auto' for model default size."))
|
||||
quality: Literal["standard", "hd"] = "standard"
|
||||
response_format: Literal["url", "b64_json"] = "url"
|
||||
style: Optional[Literal["vivid", "natural"]] = "vivid"
|
||||
user: Optional[str] = None
|
||||
|
||||
# Extended parameters for diffusion control
|
||||
num_inference_steps: Optional[int] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Number of denoising steps. More steps = higher quality but slower.")
|
||||
guidance_scale: Optional[float] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Classifier-free guidance scale. Higher values follow prompt more closely."
|
||||
)
|
||||
guidance_rescale: Optional[float] = Field(
|
||||
default=None, description="Classifier-free guidance rescale.")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Text describing what to avoid in the generated image.")
|
||||
seed: Optional[int] = Field(default=None,
|
||||
description="Random seed for reproducibility.")
|
||||
|
||||
@field_validator("size")
|
||||
@classmethod
|
||||
def validate_size(cls, v):
|
||||
"""Validate size format is 'WxH' or 'auto'."""
|
||||
if v is None or v == "auto":
|
||||
return v
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("size must be a string in 'WxH' format or 'auto'")
|
||||
# Check format: should be like "1024x1024"
|
||||
|
||||
if not re.match(r'^\d+x\d+$', v):
|
||||
raise ValueError(
|
||||
f"Invalid size format '{v}'. Must be in 'WxH' format "
|
||||
"(e.g., '1024x1024', '1536x1024') or 'auto'.")
|
||||
return v
|
||||
|
||||
|
||||
class ImageObject(OpenAIBaseModel):
|
||||
"""Generated image object in the response."""
|
||||
b64_json: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
revised_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class ImageGenerationResponse(OpenAIBaseModel):
|
||||
"""Response from image generation endpoint."""
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
data: List[ImageObject]
|
||||
output_format: Literal["png", "webp", "jpeg"] = "png"
|
||||
quality: Literal["low", "medium", "high"] = "medium"
|
||||
size: Optional[str] = None
|
||||
|
||||
|
||||
class ImageEditRequest(OpenAIBaseModel):
|
||||
"""Request for image editing endpoint.
|
||||
|
||||
Follows the OpenAI Images API specification:
|
||||
https://platform.openai.com/docs/api-reference/images/createEdit
|
||||
"""
|
||||
image: Union[List[str], str] = Field(
|
||||
description="Base64-encoded source image(s) to edit")
|
||||
prompt: str = Field(description="Text description of desired edits")
|
||||
model: Optional[str] = None
|
||||
mask: Optional[str] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Base64-encoded mask image (optional, black areas will be edited)")
|
||||
n: int = Field(default=1, ge=1, le=10)
|
||||
size: Optional[str] = Field(
|
||||
default="auto",
|
||||
description=(
|
||||
"The size of the edited images. Must be in 'WxH' format like "
|
||||
"1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. "
|
||||
"Use 'auto' to match source image size."))
|
||||
response_format: Literal["url", "b64_json"] = "url"
|
||||
user: Optional[str] = None
|
||||
|
||||
# Extended parameters for diffusion control
|
||||
num_inference_steps: Optional[int] = Field(
|
||||
default=None, description="Number of denoising steps.")
|
||||
guidance_scale: Optional[float] = Field(
|
||||
default=None, description="Classifier-free guidance scale.")
|
||||
guidance_rescale: Optional[float] = Field(
|
||||
default=None, description="Classifier-free guidance rescale.")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Text describing what to avoid in the edited image.")
|
||||
seed: Optional[int] = Field(default=None,
|
||||
description="Random seed for reproducibility.")
|
||||
|
||||
@field_validator("size")
|
||||
@classmethod
|
||||
def validate_size(cls, v):
|
||||
"""Validate size format is 'WxH' or 'auto'."""
|
||||
if v != "auto" and not re.match(r"^\d+x\d+$", v):
|
||||
raise ValueError(
|
||||
"Size must be 'auto' or in 'WxH' format (e.g., '1024x1024')")
|
||||
return v
|
||||
|
||||
|
||||
class VideoGenerationRequest(OpenAIBaseModel):
|
||||
"""Video generation request (extended API).
|
||||
|
||||
This is an extension to the OpenAI API for video generation support.
|
||||
"""
|
||||
prompt: str
|
||||
input_reference: Optional[Union[str, UploadFile]] = Field(
|
||||
default=None,
|
||||
description="Optional image reference that guides generation.")
|
||||
model: Optional[str] = None
|
||||
size: Optional[str] = Field(
|
||||
default="auto",
|
||||
description=
|
||||
("The size of the generated video frames. Must be in 'WxH' format like "
|
||||
"512x512, 1024x576 (landscape), 576x1024 (portrait), etc. "
|
||||
"Use 'auto' for model default size."))
|
||||
seconds: float = Field(default=2.0,
|
||||
ge=1.0,
|
||||
le=16.0,
|
||||
description="Video duration in seconds.")
|
||||
|
||||
# Extended parameters for diffusion control
|
||||
n: int = Field(default=1, ge=1, le=4)
|
||||
fps: int = Field(default=24, ge=8, le=60, description="Frames per second.")
|
||||
num_inference_steps: Optional[int] = Field(
|
||||
default=None, description="Number of denoising steps.")
|
||||
guidance_scale: Optional[float] = Field(
|
||||
default=None, description="Classifier-free guidance scale.")
|
||||
guidance_rescale: Optional[float] = Field(
|
||||
default=None, description="Classifier-free guidance rescale.")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Text describing what to avoid in the generated video.")
|
||||
seed: Optional[int] = Field(default=None,
|
||||
description="Random seed for reproducibility.")
|
||||
|
||||
@field_validator("size")
|
||||
@classmethod
|
||||
def validate_size(cls, v):
|
||||
"""Validate size format is 'WxH' or 'auto'."""
|
||||
if v is None or v == "auto":
|
||||
return v
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("size must be a string in 'WxH' format or 'auto'")
|
||||
|
||||
if not re.match(r'^\d+x\d+$', v):
|
||||
raise ValueError(
|
||||
f"Invalid size format '{v}'. Must be in 'WxH' format "
|
||||
"(e.g., '512x512', '1024x576') or 'auto'.")
|
||||
return v
|
||||
|
||||
|
||||
class VideoJob(OpenAIBaseModel):
|
||||
"""Metadata for an asynchronous video generation job.
|
||||
|
||||
Follows the OpenAI Videos API specification:
|
||||
https://platform.openai.com/docs/api-reference/videos
|
||||
"""
|
||||
completed_at: Optional[int] = Field(
|
||||
default=None, description="Unix timestamp of completion")
|
||||
created_at: int = Field(description="Unix timestamp of creation")
|
||||
error: Optional[str] = Field(default=None,
|
||||
description="Error message if failed")
|
||||
expires_at: Optional[int] = Field(
|
||||
default=None, description="Unix timestamp of expiration")
|
||||
id: str = Field(description="Unique identifier for the video")
|
||||
model: str = Field(description="The model used for generation")
|
||||
object: str = Field(default="video", description="Object type")
|
||||
progress: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Progress of the video generation job (0-100)")
|
||||
prompt: str = Field(description="The prompt used to generate the video")
|
||||
status: Literal["queued", "in_progress", "completed", "failed"] = Field(
|
||||
description="Current status of the video generation job")
|
||||
|
||||
# Video properties
|
||||
duration: Optional[float] = Field(default=None,
|
||||
description="Video duration in seconds")
|
||||
fps: Optional[int] = Field(default=None, description="Frames per second")
|
||||
size: Optional[str] = Field(default=None,
|
||||
description="Video dimensions in 'WxH' format")
|
||||
|
||||
|
||||
class VideoJobList(OpenAIBaseModel):
|
||||
"""Response from listing video jobs endpoint."""
|
||||
data: List[VideoJob] = Field(description="List of video jobs")
|
||||
object: str = Field(default="list", description="Object type")
|
||||
|
||||
|
||||
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
|
||||
UCompletionResponse = Union[CompletionResponse, ChatCompletionResponse]
|
||||
|
||||
File diff suppressed because it is too large
112
tensorrt_llm/serve/visual_gen_utils.py
Normal file
@ -0,0 +1,112 @@
|
||||
import asyncio
import base64
import os
import shutil
from typing import Any, Dict, List, Optional

from tensorrt_llm.llmapi.visual_gen import VisualGenParams
from tensorrt_llm.serve.openai_protocol import (
    ImageEditRequest,
    ImageGenerationRequest,
    VideoGenerationRequest,
)


def parse_visual_gen_params(
    request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest,
    id: str,
    media_storage_path: Optional[str] = None,
) -> VisualGenParams:
    params = VisualGenParams()
    params.prompt = request.prompt
    if request.negative_prompt is not None:
        params.negative_prompt = request.negative_prompt
    if request.size is not None and request.size != "auto":
        params.width, params.height = map(int, request.size.split("x"))
    if request.guidance_scale is not None:
        params.guidance_scale = request.guidance_scale
    if request.guidance_rescale is not None:
        params.guidance_rescale = request.guidance_rescale

    if isinstance(request, ImageGenerationRequest) or isinstance(request, ImageEditRequest):
        if request.num_inference_steps is not None:
            params.num_inference_steps = request.num_inference_steps
        elif isinstance(request, ImageGenerationRequest) and request.quality == "hd":
            params.num_inference_steps = 30
        if request.n is not None:
            params.num_images_per_prompt = request.n
        if isinstance(request, ImageEditRequest):
            if request.image is not None:
                if isinstance(request.image, list):
                    params.image = [base64.b64decode(image) for image in request.image]
                else:
                    params.image = [base64.b64decode(request.image)]
            if request.mask is not None:
                if isinstance(request.mask, list):
                    params.mask = [base64.b64decode(mask) for mask in request.mask]
                else:
                    params.mask = base64.b64decode(request.mask)

    elif isinstance(request, VideoGenerationRequest):
        if request.num_inference_steps is not None:
            params.num_inference_steps = request.num_inference_steps
        if request.input_reference is not None:
            if media_storage_path is None:
                raise ValueError("media_storage_path is required when input_reference is provided")
            params.input_reference = os.path.join(media_storage_path, f"{id}_reference.png")
            if isinstance(request.input_reference, str):
                with open(params.input_reference, "wb") as f:
                    f.write(base64.b64decode(request.input_reference))
            else:
                with open(params.input_reference, "wb") as f:
                    shutil.copyfileobj(request.input_reference.file, f)

        params.frame_rate = request.fps
        params.num_frames = int(request.seconds * request.fps)

    if request.seed is not None:
        params.seed = int(request.seed)

    return params


class AsyncDictStore:
    """A small async-safe in-memory key-value store for dict items.

    This encapsulates the usual pattern of a module-level dict guarded by
    an asyncio.Lock and provides simple CRUD methods that are safe to call
    concurrently from FastAPI request handlers and background tasks.
    """

    def __init__(self) -> None:
        self._items: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()

    async def upsert(self, key: str, value: Dict[str, Any]) -> None:
        async with self._lock:
            self._items[key] = value

    async def update_fields(self, key: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        async with self._lock:
            item = self._items.get(key)
            if item is None:
                return None
            item.update(updates)
            return item

    async def get(self, key: str) -> Optional[Dict[str, Any]]:
        async with self._lock:
            return self._items.get(key)

    async def pop(self, key: str) -> Optional[Dict[str, Any]]:
        async with self._lock:
            return self._items.pop(key, None)

    async def list_values(self) -> List[Dict[str, Any]]:
        async with self._lock:
            return list(self._items.values())


# Global stores shared by OpenAI entrypoints
# [request_id, dict]
VIDEO_STORE = AsyncDictStore()
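A small, self-contained sketch of the `AsyncDictStore` / `VIDEO_STORE` pattern the docstring describes; the job id and fields below are illustrative.

```python
# Illustrative usage of the shared VIDEO_STORE; the job id and fields are made up.
import asyncio

from tensorrt_llm.serve.visual_gen_utils import VIDEO_STORE


async def main() -> None:
    await VIDEO_STORE.upsert("video-123", {"id": "video-123", "status": "queued", "progress": 0})
    # Safe to call concurrently from request handlers and background tasks.
    await VIDEO_STORE.update_fields("video-123", {"status": "in_progress", "progress": 50})
    print(await VIDEO_STORE.get("video-123"))
    print(await VIDEO_STORE.list_values())
    await VIDEO_STORE.pop("video-123")   # remove the job once it has been served


asyncio.run(main())
```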
288 tests/integration/defs/examples/test_visual_gen.py Normal file
@@ -0,0 +1,288 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests: VBench dimension scores for WAN and LTX-2 (TRT-LLM vs diffusers reference)."""

import glob
import json
import os

import pytest
from defs.common import venv_check_call
from defs.conftest import llm_models_root
from defs.trt_test_alternative import check_call

WAN_T2V_MODEL_SUBPATH = "Wan2.1-T2V-1.3B-Diffusers"
VISUAL_GEN_OUTPUT_VIDEO = "trtllm_output.mp4"
DIFFUSERS_REFERENCE_VIDEO = "diffusers_reference.mp4"
WAN_T2V_PROMPT = "A cute cat playing piano"
WAN_T2V_HEIGHT = 480
WAN_T2V_WIDTH = 832
WAN_T2V_NUM_FRAMES = 165

# Dimensions to evaluate
VBENCH_DIMENSIONS = [
    "subject_consistency",
    "background_consistency",
    "motion_smoothness",
    "dynamic_degree",
    "aesthetic_quality",
    "imaging_quality",
]

# Golden VBench scores from HF reference video (WAN); TRT-LLM is compared against these.
VBENCH_WAN_GOLDEN_SCORES = {
    "subject_consistency": 0.9381,
    "background_consistency": 0.9535,
    "motion_smoothness": 0.9923,
    "dynamic_degree": 1.0000,
    "aesthetic_quality": 0.5033,
    "imaging_quality": 0.3033,
}

VBENCH_REPO = "https://github.com/Vchitect/VBench.git"
VBENCH_BRANCH = "master"
# Pin to a fixed commit for reproducible runs
VBENCH_COMMIT = "98b19513678e99c80d8377fda25ba53b81a491a6"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vbench_repo_root(llm_venv):
|
||||
"""Clone VBench repo into workspace and install; return repo root path."""
|
||||
workspace = llm_venv.get_working_directory()
|
||||
repo_path = os.path.join(workspace, "VBench_repo")
|
||||
if os.path.exists(repo_path):
|
||||
return repo_path
|
||||
# Clone without --depth=1 so we can checkout a specific commit
|
||||
check_call(
|
||||
["git", "clone", "--single-branch", "--branch", VBENCH_BRANCH, VBENCH_REPO, repo_path],
|
||||
shell=False,
|
||||
)
|
||||
check_call(["git", "-C", repo_path, "checkout", VBENCH_COMMIT], shell=False)
|
||||
# # Install VBench dependencies explicitly
|
||||
# llm_venv.run_cmd([
|
||||
# "-m", "pip", "install",
|
||||
# "tqdm>=4.60.0",
|
||||
# "openai-clip>=1.0",
|
||||
# "pyiqa>=0.1.0", # install this might also install transformers=4.37.2, which is incompatible
|
||||
# "easydict",
|
||||
# "decord>=0.6.0",
|
||||
# ])
|
||||
return repo_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def wan_trtllm_video_path(llm_venv, llm_root):
|
||||
"""Generate input video via visual_gen_wan_t2v.py and return path to trtllm_output.mp4."""
|
||||
scratch_space = llm_models_root()
|
||||
model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
|
||||
if not os.path.isdir(model_path):
|
||||
pytest.skip(
|
||||
f"Wan T2V model not found: {model_path} "
|
||||
f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
|
||||
)
|
||||
out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
output_path = os.path.join(out_dir, VISUAL_GEN_OUTPUT_VIDEO)
|
||||
if os.path.isfile(output_path):
|
||||
return output_path
|
||||
# Install av and diffusers from main branch
|
||||
llm_venv.run_cmd(["-m", "pip", "install", "av"])
|
||||
llm_venv.run_cmd(
|
||||
[
|
||||
"-m",
|
||||
"pip",
|
||||
"install",
|
||||
"git+https://github.com/huggingface/diffusers.git",
|
||||
]
|
||||
)
|
||||
script_path = os.path.join(llm_root, "examples", "visual_gen", "visual_gen_wan_t2v.py")
|
||||
assert os.path.isfile(script_path), f"Visual gen script not found: {script_path}"
|
||||
venv_check_call(
|
||||
llm_venv,
|
||||
[
|
||||
script_path,
|
||||
"--height",
|
||||
str(WAN_T2V_HEIGHT),
|
||||
"--width",
|
||||
str(WAN_T2V_WIDTH),
|
||||
"--num_frames",
|
||||
str(WAN_T2V_NUM_FRAMES),
|
||||
"--model_path",
|
||||
model_path,
|
||||
"--prompt",
|
||||
WAN_T2V_PROMPT,
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
)
|
||||
assert os.path.isfile(output_path), f"Visual gen did not produce {output_path}"
|
||||
return output_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def wan_reference_video_path(llm_venv, llm_root):
|
||||
"""Generate reference video via diffusers (hf_wan.py) using the same model checkpoint."""
|
||||
scratch_space = llm_models_root()
|
||||
model_path = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH)
|
||||
if not os.path.isdir(model_path):
|
||||
pytest.skip(
|
||||
f"Wan T2V model not found: {model_path} "
|
||||
f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
|
||||
)
|
||||
out_dir = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
reference_path = os.path.join(out_dir, DIFFUSERS_REFERENCE_VIDEO)
|
||||
if os.path.isfile(reference_path):
|
||||
return reference_path
|
||||
hf_script = os.path.join(llm_root, "examples", "visual_gen", "hf_wan.py")
|
||||
assert os.path.isfile(hf_script), f"Diffusers script not found: {hf_script}"
|
||||
venv_check_call(
|
||||
llm_venv,
|
||||
[
|
||||
hf_script,
|
||||
"--model_path",
|
||||
model_path,
|
||||
"--prompt",
|
||||
WAN_T2V_PROMPT,
|
||||
"--output_path",
|
||||
reference_path,
|
||||
"--height",
|
||||
str(WAN_T2V_HEIGHT),
|
||||
"--width",
|
||||
str(WAN_T2V_WIDTH),
|
||||
"--num_frames",
|
||||
str(WAN_T2V_NUM_FRAMES),
|
||||
],
|
||||
)
|
||||
assert os.path.isfile(reference_path), f"Diffusers did not produce {reference_path}"
|
||||
return reference_path
|
||||
|
||||
|
||||
def _visual_gen_out_dir(llm_venv, subdir=""):
|
||||
"""Output directory for generated media; subdir e.g. 'ltx2' for model-specific outputs."""
|
||||
base = os.path.join(llm_venv.get_working_directory(), "visual_gen_output")
|
||||
return os.path.join(base, subdir) if subdir else base
|
||||
|
||||
|
||||
def _normalize_score(val):
|
||||
"""Normalize to 0-1 scale (e.g. imaging_quality can be 0-100)."""
|
||||
if isinstance(val, bool):
|
||||
return float(val)
|
||||
if isinstance(val, (int, float)) and val > 1.5:
|
||||
return val / 100.0
|
||||
return float(val)
|
||||
|
||||
|
||||
def _get_per_video_scores(results, video_path_substr):
|
||||
"""From VBench results, get per-dimension score for the video whose path contains video_path_substr."""
|
||||
scores = {}
|
||||
for dim in VBENCH_DIMENSIONS:
|
||||
dim_result = results[dim]
|
||||
assert isinstance(dim_result, list) and len(dim_result) >= 2, (
|
||||
f"Dimension '{dim}' result must be [overall_score, video_results]; got {type(dim_result)}"
|
||||
)
|
||||
video_results = dim_result[1]
|
||||
for entry in video_results:
|
||||
if video_path_substr in entry.get("video_path", ""):
|
||||
raw = entry.get("video_results")
|
||||
scores[dim] = _normalize_score(raw)
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"No video matching '{video_path_substr}' in dimension '{dim}'; "
|
||||
f"paths: {[e.get('video_path') for e in video_results]}"
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
def _run_vbench_and_compare_to_golden(
|
||||
vbench_repo_root,
|
||||
videos_dir,
|
||||
trtllm_filename,
|
||||
golden_scores,
|
||||
llm_venv,
|
||||
title,
|
||||
max_score_diff=0.1,
|
||||
):
|
||||
"""Run VBench on videos_dir (TRT-LLM output only), compare to golden HF reference scores."""
|
||||
output_path = os.path.join(
|
||||
llm_venv.get_working_directory(), "vbench_eval_output", title.replace(" ", "_").lower()
|
||||
)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
evaluate_script = os.path.join(vbench_repo_root, "evaluate.py")
|
||||
cmd = [
|
||||
evaluate_script,
|
||||
"--videos_path",
|
||||
videos_dir,
|
||||
"--output_path",
|
||||
output_path,
|
||||
"--mode",
|
||||
"custom_input",
|
||||
]
|
||||
cmd.extend(["--dimension"] + VBENCH_DIMENSIONS)
|
||||
venv_check_call(llm_venv, cmd)
|
||||
pattern = os.path.join(output_path, "*_eval_results.json")
|
||||
result_files = glob.glob(pattern)
|
||||
assert result_files, (
|
||||
f"No eval results found matching {pattern}; output dir: {os.listdir(output_path)}"
|
||||
)
|
||||
with open(result_files[0], "r") as f:
|
||||
results = json.load(f)
|
||||
for dim in VBENCH_DIMENSIONS:
|
||||
assert dim in results, (
|
||||
f"Expected dimension '{dim}' in results; keys: {list(results.keys())}"
|
||||
)
|
||||
scores_trtllm = _get_per_video_scores(results, trtllm_filename)
|
||||
scores_ref = golden_scores
|
||||
    max_len = max(len(d) for d in VBENCH_DIMENSIONS)
    header = f"{'Dimension':<{max_len}} | {'TRT-LLM':>10} | {'HF Ref':>10} | {'Diff':>8}"
    sep = "-" * len(header)
    print("\n" + "=" * len(header))
    print(f"VBench dimension scores ({title}): TRT-LLM vs golden HF reference scores")
    print("=" * len(header))
    print(header)
    print(sep)
    max_diff_val = 0.0
    for dim in VBENCH_DIMENSIONS:
        t, r = scores_trtllm[dim], scores_ref[dim]
        diff = abs(t - r)
        max_diff_val = max(max_diff_val, diff)
        print(f"{dim:<{max_len}} | {t:>10.4f} | {r:>10.4f} | {diff:>8.4f}")
    print(sep)
    print(
        f"{' (all dimensions)':<{max_len}} | (TRT-LLM) | (golden) | max_diff={max_diff_val:.4f}"
    )
    print("=" * len(header) + "\n")
    for dim in VBENCH_DIMENSIONS:
        diff = abs(scores_trtllm[dim] - scores_ref[dim])
        assert diff < max_score_diff or scores_trtllm[dim] >= scores_ref[dim], (
            f"Dimension '{dim}' score difference {diff:.4f} >= {max_score_diff} "
            f"(TRT-LLM={scores_trtllm[dim]:.4f}, golden={scores_ref[dim]:.4f})"
        )


def test_vbench_dimension_score_wan(vbench_repo_root, wan_trtllm_video_path, llm_venv):
    """Run VBench on WAN TRT-LLM video; compare to golden HF reference scores (diff < 0.05 or TRT-LLM >= golden)."""
    videos_dir = os.path.dirname(wan_trtllm_video_path)
    assert os.path.isfile(wan_trtllm_video_path), "TRT-LLM video must exist"
    _run_vbench_and_compare_to_golden(
        vbench_repo_root,
        videos_dir,
        VISUAL_GEN_OUTPUT_VIDEO,
        VBENCH_WAN_GOLDEN_SCORES,
        llm_venv,
        title="WAN",
        max_score_diff=0.05,
    )
|
||||
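The acceptance rule used above reduces to: a dimension passes when TRT-LLM is within `max_score_diff` of the golden score, or matches/exceeds it. A hedged restatement with made-up numbers:

```python
# Restates the per-dimension pass rule from _run_vbench_and_compare_to_golden;
# the scores below are illustrative, not measured values.
def dimension_passes(trtllm: float, golden: float, max_score_diff: float = 0.05) -> bool:
    return abs(trtllm - golden) < max_score_diff or trtllm >= golden


assert dimension_passes(0.93, 0.9381)        # small regression, within tolerance
assert dimension_passes(0.97, 0.9381)        # matching or beating the golden score always passes
assert not dimension_passes(0.80, 0.9381)    # regression larger than 0.05 fails
```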
@@ -91,6 +91,17 @@ l0_b200:
- unittest/tools/test_layer_wise_benchmarks.py::test_performance_alignment[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/kv_cache_manager_v2_tests/
# ------------- Visual Gen tests ---------------
- unittest/_torch/visual_gen/test_fused_qkv.py
- unittest/_torch/visual_gen/test_quant_ops.py
- unittest/_torch/visual_gen/test_attention_integration.py
- unittest/_torch/visual_gen/test_attention_perf.py
- unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
- unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
- unittest/_torch/visual_gen/test_wan.py -k "not TestWanTwoStageTransformer"
- unittest/_torch/visual_gen/test_wan_i2v.py
- unittest/_torch/visual_gen/test_model_loader.py
# - examples/test_visual_gen.py
- condition:
    ranges:
      system_gpu_count:
@@ -161,6 +172,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype
- unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer
# ------------- AutoDeploy Backend Stages ---------------
- condition:
    ranges:

@@ -30,6 +30,12 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
# ------------- VisualGen multi-GPU tests ---------------
- unittest/_torch/visual_gen/multi_gpu
- unittest/_torch/visual_gen/test_wan.py::TestWanParallelism::test_cfg_2gpu_correctness
- unittest/_torch/visual_gen/test_wan.py::TestWanCombinedOptimizations::test_all_optimizations_combined
- unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VParallelism::test_cfg_2gpu_correctness
- unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VCombinedOptimizations::test_all_optimizations_combined
- condition:
    ranges:
      system_gpu_count:
1 tests/unittest/_torch/visual_gen/multi_gpu/__init__.py Normal file
@@ -0,0 +1 @@
|
||||
"""Multi-GPU tests for visual generation modules."""
|
||||
505 tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py Normal file
@@ -0,0 +1,505 @@
|
||||
"""Multi-GPU tests for Ulysses Attention.
|
||||
|
||||
These tests use torch.multiprocessing.spawn to launch multiple processes internally.
|
||||
Run with:
|
||||
pytest tests/visual_gen/multi_gpu/test_ulysses_attention.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Try to import the modules - skip tests if not available
|
||||
try:
|
||||
from tensorrt_llm._torch.attention_backend.interface import PredefinedAttentionMask
|
||||
from tensorrt_llm._torch.distributed import all_to_all_4d
|
||||
from tensorrt_llm._torch.visual_gen.attention_backend import UlyssesAttention, VanillaAttention
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
|
||||
MODULES_AVAILABLE = True
|
||||
except ImportError:
|
||||
MODULES_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def _cleanup_mpi_env():
|
||||
"""Clean up TLLM_DISABLE_MPI env var after tests complete."""
|
||||
yield
|
||||
os.environ.pop("TLLM_DISABLE_MPI", None)
|
||||
|
||||
|
||||
def init_distributed_worker(rank: int, world_size: int, backend: str = "gloo", port: int = 29500):
|
||||
"""Initialize distributed environment for a worker process."""
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Use gloo backend for CPU, nccl for GPU
|
||||
if backend == "nccl" and torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank % torch.cuda.device_count())
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
else:
|
||||
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def cleanup_distributed():
|
||||
"""Clean up distributed environment."""
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _distributed_worker(rank, world_size, backend, test_fn, port):
|
||||
"""Worker function that runs in each process. Module-level for pickling."""
|
||||
try:
|
||||
init_distributed_worker(rank, world_size, backend, port)
|
||||
test_fn(rank, world_size)
|
||||
except Exception as e:
|
||||
print(f"Rank {rank} failed with error: {e}")
|
||||
raise
|
||||
finally:
|
||||
cleanup_distributed()
|
||||
|
||||
|
||||
def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = True):
|
||||
"""Run a test function in a distributed environment with multiple processes.
|
||||
|
||||
Args:
|
||||
world_size: Number of processes to spawn
|
||||
test_fn: Test function to run (must be module-level for pickling).
|
||||
Should accept (rank, world_size) as arguments.
|
||||
use_cuda: Whether to use CUDA (requires sufficient GPUs)
|
||||
"""
|
||||
if not MODULES_AVAILABLE:
|
||||
pytest.skip("Required modules not available")
|
||||
|
||||
if use_cuda and torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available")
|
||||
|
||||
backend = "nccl" if use_cuda else "gloo"
|
||||
|
||||
port = get_free_port()
|
||||
|
||||
# Spawn processes
|
||||
mp.spawn(
|
||||
_distributed_worker, args=(world_size, backend, test_fn, port), nprocs=world_size, join=True
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test logic functions (module-level so they can be pickled by mp.spawn)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _logic_a2a_seq_to_head(rank, world_size):
|
||||
"""all_to_all_4d: sequence sharding to head sharding."""
|
||||
batch = 2
|
||||
seq_per_rank = 4
|
||||
heads = 8
|
||||
head_dim = 64
|
||||
|
||||
if heads % world_size != 0:
|
||||
heads = world_size * 2
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
input_tensor = (
|
||||
torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
|
||||
+ rank * 100
|
||||
)
|
||||
|
||||
output = all_to_all_4d(
|
||||
input_tensor,
|
||||
scatter_dim=2,
|
||||
gather_dim=1,
|
||||
process_group=None,
|
||||
)
|
||||
|
||||
expected_shape = (batch, seq_per_rank * world_size, heads // world_size, head_dim)
|
||||
assert output.shape == expected_shape, (
|
||||
f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
|
||||
)
|
||||
assert output.device == device
|
||||
|
||||
|
||||
def _logic_a2a_head_to_seq(rank, world_size):
|
||||
"""all_to_all_4d: head sharding to sequence sharding."""
|
||||
batch = 2
|
||||
seq = 16
|
||||
heads_per_rank = 2
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch, seq, heads_per_rank, head_dim, device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
output = all_to_all_4d(
|
||||
input_tensor,
|
||||
scatter_dim=1,
|
||||
gather_dim=2,
|
||||
process_group=None,
|
||||
)
|
||||
|
||||
expected_shape = (batch, seq // world_size, heads_per_rank * world_size, head_dim)
|
||||
assert output.shape == expected_shape, (
|
||||
f"Rank {rank}: Expected shape {expected_shape}, got {output.shape}"
|
||||
)
|
||||
|
||||
|
||||
def _logic_a2a_roundtrip(rank, world_size):
|
||||
"""all_to_all_4d: forward and backward are inverses."""
|
||||
batch = 2
|
||||
seq_per_rank = 4
|
||||
heads = world_size * 4
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
original = torch.randn(batch, seq_per_rank, heads, head_dim, device=device, dtype=torch.float32)
|
||||
|
||||
intermediate = all_to_all_4d(original, scatter_dim=2, gather_dim=1, process_group=None)
|
||||
reconstructed = all_to_all_4d(intermediate, scatter_dim=1, gather_dim=2, process_group=None)
|
||||
|
||||
assert reconstructed.shape == original.shape
|
||||
torch.testing.assert_close(reconstructed, original, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def _logic_a2a_single_process(rank, world_size):
|
||||
"""all_to_all_4d: single process returns input unchanged."""
|
||||
batch, seq, heads, head_dim = 2, 8, 4, 64
|
||||
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input_tensor = torch.randn(batch, seq, heads, head_dim, device=device)
|
||||
|
||||
output = all_to_all_4d(input_tensor, scatter_dim=2, gather_dim=1, process_group=None)
|
||||
|
||||
torch.testing.assert_close(output, input_tensor)
|
||||
|
||||
|
||||
def _logic_ulysses_init(rank, world_size):
|
||||
"""UlyssesAttention initialization."""
|
||||
num_heads = world_size * 4
|
||||
head_dim = 64
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
)
|
||||
|
||||
assert attention.num_heads == num_heads
|
||||
assert attention.head_dim == head_dim
|
||||
assert attention.world_size == world_size
|
||||
assert rank >= 0 and rank < world_size
|
||||
|
||||
|
||||
def _logic_ulysses_forward(rank, world_size):
|
||||
"""UlyssesAttention forward pass."""
|
||||
batch = 2
|
||||
seq_per_rank = 8
|
||||
num_heads = world_size * 4
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
|
||||
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
|
||||
|
||||
assert output.shape == q.shape, f"Rank {rank}: Expected shape {q.shape}, got {output.shape}"
|
||||
assert output.device == device
|
||||
|
||||
|
||||
def _logic_ulysses_with_mask(rank, world_size):
|
||||
"""UlyssesAttention with attention mask."""
|
||||
batch = 2
|
||||
seq_per_rank = 8
|
||||
seq_full = seq_per_rank * world_size
|
||||
num_heads = world_size * 4
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
|
||||
mask = PredefinedAttentionMask.CAUSAL
|
||||
|
||||
output = attention(q, k, v, batch_size=batch, seq_len=seq_full, attention_mask=mask)
|
||||
|
||||
assert output.shape == q.shape
|
||||
|
||||
|
||||
def _logic_ulysses_vs_standard_multi_gpu(rank, world_size):
|
||||
"""UlyssesAttention across multiple GPUs matches standard attention on the full sequence."""
|
||||
batch = 2
|
||||
seq_per_rank = 8
|
||||
seq_full = seq_per_rank * world_size
|
||||
num_heads = world_size * 4
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Every rank generates identical full tensors using the same seed.
|
||||
torch.manual_seed(42)
|
||||
q_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
|
||||
k_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
|
||||
v_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
|
||||
|
||||
# Each rank takes its sequence shard.
|
||||
q_shard = q_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
|
||||
k_shard = k_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
|
||||
v_shard = v_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
|
||||
|
||||
# Ulysses attention on shards.
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
ulysses_output = attention(q_shard, k_shard, v_shard, batch_size=batch, seq_len=seq_full)
|
||||
|
||||
# Standard attention on the full tensors.
|
||||
q_std = q_full.transpose(1, 2) # [B, H, S, D]
|
||||
k_std = k_full.transpose(1, 2)
|
||||
v_std = v_full.transpose(1, 2)
|
||||
|
||||
std_output = F.scaled_dot_product_attention(
|
||||
q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
|
||||
)
|
||||
std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
|
||||
|
||||
# Compare the shard slice.
|
||||
expected_shard = std_output[:, rank * seq_per_rank : (rank + 1) * seq_per_rank]
|
||||
torch.testing.assert_close(
|
||||
ulysses_output,
|
||||
expected_shard,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
msg=f"Rank {rank}: Ulysses multi-GPU output differs from standard attention",
|
||||
)
|
||||
|
||||
|
||||
def _logic_ulysses_invalid_heads(rank, world_size):
|
||||
"""Invalid head count (not divisible by world_size) cannot be sharded."""
|
||||
assert rank >= 0 and rank < world_size
|
||||
|
||||
num_heads = world_size * 4 + 1 # Not divisible
|
||||
head_dim = 64
|
||||
|
||||
# With the decorator pattern, the caller is responsible for sharding heads.
|
||||
# num_heads // world_size truncates, so the wrapper's computed full head
|
||||
# count won't match the original.
|
||||
sharded_heads = num_heads // world_size
|
||||
inner = VanillaAttention(num_heads=sharded_heads, head_dim=head_dim)
|
||||
attention = UlyssesAttention(inner_backend=inner, process_group=None)
|
||||
assert attention.num_heads != num_heads # Truncation means mismatch
|
||||
|
||||
|
||||
def _logic_different_batch_sizes(rank, world_size):
|
||||
"""Various batch sizes."""
|
||||
num_heads = world_size * 4
|
||||
head_dim = 64
|
||||
seq_per_rank = 8
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
for batch_size in [1, 2, 4, 8]:
|
||||
q = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch_size, seq_per_rank, num_heads, head_dim, device=device)
|
||||
|
||||
output = attention(q, k, v, batch_size=batch_size, seq_len=seq_per_rank * world_size)
|
||||
assert output.shape == q.shape
|
||||
|
||||
|
||||
def _logic_different_head_dims(rank, world_size):
|
||||
"""Various head dims."""
|
||||
batch = 2
|
||||
seq_per_rank = 8
|
||||
num_heads = world_size * 4
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
for head_dim in [32, 64, 128]:
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
|
||||
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
|
||||
assert output.shape == q.shape
|
||||
|
||||
|
||||
def _logic_world_size_4(rank, world_size):
|
||||
"""4-GPU test."""
|
||||
batch = 2
|
||||
seq_per_rank = 16
|
||||
num_heads = world_size * 8 # 32 heads total
|
||||
head_dim = 64
|
||||
|
||||
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads // world_size, head_dim=head_dim)
|
||||
attention = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
q = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch, seq_per_rank, num_heads, head_dim, device=device)
|
||||
|
||||
output = attention(q, k, v, batch_size=batch, seq_len=seq_per_rank * world_size)
|
||||
assert output.shape == q.shape
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAllToAll4D:
|
||||
"""Tests for all_to_all_4d function."""
|
||||
|
||||
def test_all_to_all_4d_sequence_to_head(self):
|
||||
"""Test sequence sharding to head sharding transformation."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_seq_to_head, use_cuda=True)
|
||||
|
||||
def test_all_to_all_4d_head_to_sequence(self):
|
||||
"""Test head sharding to sequence sharding transformation."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_head_to_seq, use_cuda=True)
|
||||
|
||||
def test_all_to_all_4d_roundtrip(self):
|
||||
"""Test that forward and backward all-to-all are inverses."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_a2a_roundtrip, use_cuda=True)
|
||||
|
||||
def test_all_to_all_4d_single_process(self):
|
||||
"""Test that single process returns input unchanged."""
|
||||
run_test_in_distributed(world_size=1, test_fn=_logic_a2a_single_process, use_cuda=True)
|
||||
|
||||
|
||||
class TestUlyssesAttention:
|
||||
"""Tests for UlyssesAttention module."""
|
||||
|
||||
def test_ulysses_attention_initialization(self):
|
||||
"""Test UlyssesAttention initialization."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_init, use_cuda=True)
|
||||
|
||||
def test_ulysses_attention_forward(self):
|
||||
"""Test UlyssesAttention forward pass."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_forward, use_cuda=True)
|
||||
|
||||
def test_ulysses_attention_with_mask(self):
|
||||
"""Test UlyssesAttention with attention mask."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_with_mask, use_cuda=True)
|
||||
|
||||
def test_ulysses_vs_standard_attention_single_gpu(self):
|
||||
"""Compare UlyssesAttention with standard attention on single GPU."""
|
||||
if not MODULES_AVAILABLE:
|
||||
pytest.skip("Required modules not available")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("Test requires CUDA")
|
||||
|
||||
batch = 2
|
||||
seq = 16
|
||||
num_heads = 8
|
||||
head_dim = 64
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
inner = VanillaAttention(num_heads=num_heads, head_dim=head_dim)
|
||||
ulysses_attn = UlyssesAttention(
|
||||
inner_backend=inner,
|
||||
process_group=None,
|
||||
).to(device)
|
||||
|
||||
torch.manual_seed(42)
|
||||
q = torch.randn(batch, seq, num_heads, head_dim, device=device)
|
||||
k = torch.randn(batch, seq, num_heads, head_dim, device=device)
|
||||
v = torch.randn(batch, seq, num_heads, head_dim, device=device)
|
||||
|
||||
ulysses_output = ulysses_attn(q, k, v, batch_size=batch, seq_len=seq)
|
||||
|
||||
q_std = q.transpose(1, 2) # [B, H, S, D]
|
||||
k_std = k.transpose(1, 2)
|
||||
v_std = v.transpose(1, 2)
|
||||
|
||||
std_output = F.scaled_dot_product_attention(
|
||||
q_std, k_std, v_std, scale=1.0 / math.sqrt(head_dim), dropout_p=0.0
|
||||
)
|
||||
std_output = std_output.transpose(1, 2).contiguous() # [B, S, H, D]
|
||||
|
||||
torch.testing.assert_close(
|
||||
ulysses_output,
|
||||
std_output,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
msg="Ulysses attention output differs from standard attention",
|
||||
)
|
||||
|
||||
def test_ulysses_vs_standard_attention_multi_gpu(self):
|
||||
"""Compare UlyssesAttention across GPUs with standard attention on full sequence."""
|
||||
run_test_in_distributed(
|
||||
world_size=2, test_fn=_logic_ulysses_vs_standard_multi_gpu, use_cuda=True
|
||||
)
|
||||
|
||||
def test_ulysses_attention_invalid_heads(self):
|
||||
"""Test that invalid head count raises error."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_ulysses_invalid_heads, use_cuda=False)
|
||||
|
||||
|
||||
class TestUlyssesAttentionEdgeCases:
|
||||
"""Edge case tests for UlyssesAttention."""
|
||||
|
||||
def test_different_batch_sizes(self):
|
||||
"""Test with various batch sizes."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_different_batch_sizes, use_cuda=True)
|
||||
|
||||
def test_different_head_dims(self):
|
||||
"""Test with various head dims."""
|
||||
run_test_in_distributed(world_size=2, test_fn=_logic_different_head_dims, use_cuda=True)
|
||||
|
||||
def test_world_size_4(self):
|
||||
"""Test with 4 GPUs."""
|
||||
run_test_in_distributed(world_size=4, test_fn=_logic_world_size_4, use_cuda=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
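To make the resharding that `all_to_all_4d` performs concrete, here is a single-process, shape-only sketch (plain slicing, no communication). The sizes are illustrative and chosen so both the sequence length S and head count H divide by the number of ranks P, matching what the shape assertions in the tests above expect.

```python
# Shape-only illustration of the all_to_all_4d contract exercised above; no distributed calls.
import torch

B, S, H, D, P = 2, 16, 8, 64, 4          # illustrative sizes; S and H divisible by P
full = torch.randn(B, S, H, D)

# Before: rank 0 holds a sequence shard of shape [B, S/P, H, D].
seq_shard_0 = full[:, 0 * (S // P):1 * (S // P)]
print(seq_shard_0.shape)                  # torch.Size([2, 4, 8, 64])

# After all_to_all_4d(scatter_dim=2, gather_dim=1): each rank holds the full sequence
# but only H/P heads, i.e. shape [B, S, H/P, D] -- the shape the tests assert.
head_shard_0 = full[:, :, 0 * (H // P):1 * (H // P)]
print(head_shard_0.shape)                 # torch.Size([2, 16, 2, 64])

# The inverse call (scatter_dim=1, gather_dim=2) goes back to sequence sharding,
# which is why the roundtrip test recovers the original tensor.
```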
540 tests/unittest/_torch/visual_gen/test_attention_integration.py Normal file
@@ -0,0 +1,540 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test WAN Attention Integration.
|
||||
|
||||
Compares the new integrated attention (using TRT-LLM backend) with the original
|
||||
naive implementation to ensure numerical equivalence.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
|
||||
|
||||
# Import new integrated versions
|
||||
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb
|
||||
|
||||
# ============================================================================
|
||||
# Original naive implementations for comparison
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class NaiveWanSelfAttention(nn.Module):
|
||||
"""Original naive self-attention implementation (for comparison)."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# fused QKV projection
|
||||
self.to_qkv = nn.Linear(hidden_size, 3 * hidden_size, dtype=dtype)
|
||||
self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
|
||||
self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
|
||||
self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
|
||||
|
||||
def forward(self, hidden_states, freqs_cos, freqs_sin):
|
||||
B, S = hidden_states.shape[:2]
|
||||
|
||||
q, k, v = self.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
q = self.norm_q(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.norm_k(k).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if freqs_cos is not None and freqs_sin is not None:
|
||||
q = apply_rotary_emb(q, freqs_cos, freqs_sin)
|
||||
k = apply_rotary_emb(k, freqs_cos, freqs_sin)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
out = out.transpose(1, 2).flatten(2)
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
class NaiveWanCrossAttention(nn.Module):
|
||||
"""Original naive cross-attention implementation (for comparison)."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, num_heads: int, head_dim: int, eps: float = 1e-6, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.to_q = nn.Linear(hidden_size, hidden_size, dtype=dtype)
|
||||
self.to_k = nn.Linear(hidden_size, hidden_size, dtype=dtype)
|
||||
self.to_v = nn.Linear(hidden_size, hidden_size, dtype=dtype)
|
||||
self.norm_q = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
|
||||
self.norm_k = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype, has_weights=True)
|
||||
self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=dtype)])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states):
|
||||
B, S = hidden_states.shape[:2]
|
||||
|
||||
q = self.norm_q(self.to_q(hidden_states))
|
||||
k = self.norm_k(self.to_k(encoder_hidden_states))
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
out = out.transpose(1, 2).flatten(2)
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_model_config(
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
attn_backend: str = "VANILLA",
|
||||
):
|
||||
"""Create a mock DiffusionModelConfig for testing."""
|
||||
pretrained_config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
# Create a minimal config without quantization
|
||||
config = DiffusionModelConfig(
|
||||
pretrained_config=pretrained_config,
|
||||
attention=AttentionConfig(backend=attn_backend),
|
||||
skip_create_weights_in_init=False,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def copy_weights_self_attention(naive: NaiveWanSelfAttention, integrated: Attention):
|
||||
"""Copy weights from naive to integrated self-attention."""
|
||||
# QKV projection: naive has to_qkv, integrated has qkv_proj
|
||||
integrated.qkv_proj.weight.data.copy_(naive.to_qkv.weight.data)
|
||||
if naive.to_qkv.bias is not None and integrated.qkv_proj.bias is not None:
|
||||
integrated.qkv_proj.bias.data.copy_(naive.to_qkv.bias.data)
|
||||
|
||||
# QK norms
|
||||
integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
|
||||
integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
|
||||
|
||||
# Output projection
|
||||
integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
|
||||
if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
|
||||
integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
|
||||
|
||||
|
||||
def copy_weights_cross_attention(naive: NaiveWanCrossAttention, integrated: Attention):
|
||||
"""Copy weights from naive to integrated cross-attention."""
|
||||
# Q, K, V projections
|
||||
integrated.to_q.weight.data.copy_(naive.to_q.weight.data)
|
||||
integrated.to_k.weight.data.copy_(naive.to_k.weight.data)
|
||||
integrated.to_v.weight.data.copy_(naive.to_v.weight.data)
|
||||
|
||||
if naive.to_q.bias is not None and integrated.to_q.bias is not None:
|
||||
integrated.to_q.bias.data.copy_(naive.to_q.bias.data)
|
||||
if naive.to_k.bias is not None and integrated.to_k.bias is not None:
|
||||
integrated.to_k.bias.data.copy_(naive.to_k.bias.data)
|
||||
if naive.to_v.bias is not None and integrated.to_v.bias is not None:
|
||||
integrated.to_v.bias.data.copy_(naive.to_v.bias.data)
|
||||
|
||||
# QK norms
|
||||
integrated.norm_q.weight.data.copy_(naive.norm_q.weight.data)
|
||||
integrated.norm_k.weight.data.copy_(naive.norm_k.weight.data)
|
||||
|
||||
# Output projection
|
||||
integrated.to_out[0].weight.data.copy_(naive.to_out[0].weight.data)
|
||||
if naive.to_out[0].bias is not None and integrated.to_out[0].bias is not None:
|
||||
integrated.to_out[0].bias.data.copy_(naive.to_out[0].bias.data)
|
||||
|
||||
|
||||
def generate_rope_embeddings(
|
||||
seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
|
||||
):
|
||||
"""Generate RoPE embeddings with full head_dim.
|
||||
|
||||
apply_rotary_emb expects freqs with full head_dim, then slices with [..., 0::2] and [..., 1::2].
|
||||
|
||||
Args:
|
||||
is_HSD: If True, returns [1, 1, S, D] for broadcasting with [B, H, S, D] (naive)
|
||||
If False, returns [1, S, 1, D] for broadcasting with [B, S, H, D] (integrated)
|
||||
"""
|
||||
position = torch.arange(seq_len, device=device).unsqueeze(1)
|
||||
# Use full head_dim - apply_rotary_emb will slice with 0::2 and 1::2
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
|
||||
)
|
||||
|
||||
if is_HSD:
|
||||
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
|
||||
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0) # [1, 1, S, D]
|
||||
else:
|
||||
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
|
||||
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2) # [1, S, 1, D]
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test functions
|
||||
# ============================================================================
|
||||
@pytest.mark.parametrize("attn_backend", ["VANILLA", "TRTLLM"])
|
||||
def test_self_attention_equivalence(attn_backend: str):
|
||||
"""Test that integrated self-attention produces same output as naive."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing Self-Attention Equivalence")
|
||||
print("=" * 60)
|
||||
|
||||
# Config
|
||||
batch_size = 2
|
||||
seq_len = 16
|
||||
hidden_size = 128
|
||||
num_heads = 4
|
||||
head_dim = hidden_size // num_heads
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
|
||||
|
||||
print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
|
||||
print(f"Device: {device}, dtype: {dtype}")
|
||||
|
||||
# Create models
|
||||
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
|
||||
|
||||
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
|
||||
integrated = Attention(
|
||||
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
|
||||
).to(device) # self attention
|
||||
|
||||
# Copy weights
|
||||
copy_weights_self_attention(naive, integrated)
|
||||
|
||||
# Set to eval mode
|
||||
naive.eval()
|
||||
integrated.eval()
|
||||
|
||||
# Create inputs
|
||||
torch.manual_seed(42)
|
||||
hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
|
||||
# Naive uses [1, 1, S, D] (HSD format) - broadcasts with [B, H, S, D]
|
||||
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
|
||||
# Integrated uses [1, S, 1, D] (SHD format) - broadcasts with [B, S, H, D]
|
||||
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
|
||||
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
|
||||
|
||||
# Compare (using looser tolerance for bf16)
|
||||
max_diff = (out_naive - out_integrated).abs().max().item()
|
||||
mean_diff = (out_naive - out_integrated).abs().mean().item()
|
||||
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
|
||||
|
||||
print("\nResults:")
|
||||
print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
|
||||
print(f" Max absolute difference: {max_diff:.2e}")
|
||||
print(f" Mean absolute difference: {mean_diff:.2e}")
|
||||
print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
|
||||
|
||||
if is_close:
|
||||
print(" ✅ PASS: Self-attention outputs match!")
|
||||
else:
|
||||
print(" ❌ FAIL: Self-attention outputs differ!")
|
||||
|
||||
assert is_close, (
|
||||
f"Self-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
|
||||
)
|
||||
return is_close
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attn_backend", ["VANILLA"])
|
||||
def test_cross_attention_equivalence(attn_backend: str):
|
||||
"""Test that integrated cross-attention produces same output as naive."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing Cross-Attention Equivalence")
|
||||
print("=" * 60)
|
||||
|
||||
# Config
|
||||
batch_size = 2
|
||||
seq_len = 16
|
||||
encoder_seq_len = 24 # Different from query seq_len
|
||||
hidden_size = 128
|
||||
num_heads = 4
|
||||
head_dim = hidden_size // num_heads
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.bfloat16 # Use bf16 since flashinfer doesn't support fp32
|
||||
|
||||
print(
|
||||
f"Config: B={batch_size}, S_q={seq_len}, S_kv={encoder_seq_len}, H={hidden_size}, heads={num_heads}"
|
||||
)
|
||||
print(f"Device: {device}, dtype: {dtype}")
|
||||
|
||||
# Create models
|
||||
naive = NaiveWanCrossAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
|
||||
|
||||
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=attn_backend)
|
||||
integrated = Attention(
|
||||
hidden_size, num_heads, qkv_mode=QKVMode.SEPARATE_QKV, config=model_config
|
||||
).to(device) # cross attention
|
||||
|
||||
# Copy weights
|
||||
copy_weights_cross_attention(naive, integrated)
|
||||
|
||||
# Set to eval mode
|
||||
naive.eval()
|
||||
integrated.eval()
|
||||
|
||||
# Create inputs
|
||||
torch.manual_seed(42)
|
||||
hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
|
||||
encoder_hidden_states = torch.randn(
|
||||
batch_size, encoder_seq_len, hidden_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
out_naive = naive(hidden_states, encoder_hidden_states)
|
||||
out_integrated = integrated(hidden_states, encoder_hidden_states)
|
||||
|
||||
# Compare (using looser tolerance for bf16)
|
||||
max_diff = (out_naive - out_integrated).abs().max().item()
|
||||
mean_diff = (out_naive - out_integrated).abs().mean().item()
|
||||
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
|
||||
|
||||
print("\nResults:")
|
||||
print(f" Output shape: naive={out_naive.shape}, integrated={out_integrated.shape}")
|
||||
print(f" Max absolute difference: {max_diff:.2e}")
|
||||
print(f" Mean absolute difference: {mean_diff:.2e}")
|
||||
print(f" Outputs match (rtol=1e-2, atol=1e-2): {is_close}")
|
||||
|
||||
if is_close:
|
||||
print(" ✅ PASS: Cross-attention outputs match!")
|
||||
else:
|
||||
print(" ❌ FAIL: Cross-attention outputs differ!")
|
||||
|
||||
assert is_close, (
|
||||
f"Cross-attention outputs differ: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
|
||||
)
|
||||
return is_close
|
||||
|
||||
|
||||
def test_trtllm_cached_prepare():
|
||||
"""Test that TRTLLM attention cached prepare works correctly.
|
||||
|
||||
This test verifies that when running multiple forward passes with same B/S
|
||||
but different q/k/v values, the cached prepare phase doesn't cause incorrect
|
||||
results (i.e., outputs should differ when inputs differ).
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing TRTLLM Cached Prepare Phase")
|
||||
print("=" * 60)
|
||||
|
||||
# Config - same B, S for all iterations
|
||||
batch_size = 2
|
||||
seq_len = 16
|
||||
hidden_size = 128
|
||||
num_heads = 4
|
||||
head_dim = hidden_size // num_heads
|
||||
num_iterations = 5
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}")
|
||||
print(f"Running {num_iterations} iterations with same B/S but different inputs")
|
||||
|
||||
# Create models - single instance to test caching
|
||||
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
|
||||
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
|
||||
integrated = Attention(
|
||||
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
|
||||
).to(device) # self attention
|
||||
|
||||
# Copy weights
|
||||
copy_weights_self_attention(naive, integrated)
|
||||
|
||||
naive.eval()
|
||||
integrated.eval()
|
||||
|
||||
# Generate freqs (same for all iterations since S is same)
|
||||
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
|
||||
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)
|
||||
|
||||
all_passed = True
|
||||
outputs_integrated = []
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(num_iterations):
|
||||
# Different random inputs for each iteration
|
||||
torch.manual_seed(42 + i) # Different seed each time
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, hidden_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
|
||||
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
|
||||
|
||||
# Check this iteration matches naive
|
||||
max_diff = (out_naive - out_integrated).abs().max().item()
|
||||
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
|
||||
|
||||
status = "✅" if is_close else "❌"
|
||||
print(f" Iteration {i + 1}: max_diff={max_diff:.2e} {status}")
|
||||
|
||||
if not is_close:
|
||||
all_passed = False
|
||||
|
||||
outputs_integrated.append(out_integrated.clone())
|
||||
|
||||
# Additional check: outputs should be DIFFERENT across iterations
|
||||
# (since inputs were different)
|
||||
print("\n Checking outputs differ across iterations (inputs were different):")
|
||||
outputs_differ = True
|
||||
for i in range(1, num_iterations):
|
||||
diff = (outputs_integrated[i] - outputs_integrated[0]).abs().max().item()
|
||||
if diff < 1e-6:
|
||||
print(
|
||||
f" ⚠️ Iteration {i + 1} output same as iteration 1 (diff={diff:.2e}) - possible caching bug!"
|
||||
)
|
||||
outputs_differ = False
|
||||
else:
|
||||
print(f" Iteration {i + 1} vs 1: diff={diff:.2e} ✅")
|
||||
|
||||
if all_passed and outputs_differ:
|
||||
print("\n ✅ PASS: Cached prepare works correctly!")
|
||||
else:
|
||||
print("\n ❌ FAIL: Cached prepare may have issues!")
|
||||
all_passed = False
|
||||
|
||||
assert all_passed, "Cached prepare: outputs did not match naive reference"
|
||||
assert outputs_differ, (
|
||||
"Cached prepare: outputs should differ across iterations with different inputs"
|
||||
)
|
||||
return all_passed
|
||||
|
||||
|
||||
def test_trtllm_varying_seq_len():
|
||||
"""Test TRTLLM attention with varying sequence lengths.
|
||||
|
||||
This tests that the prepare phase correctly handles different seq_lens
|
||||
and doesn't incorrectly reuse cached metadata.
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing TRTLLM with Varying Sequence Lengths")
|
||||
print("=" * 60)
|
||||
|
||||
batch_size = 2
|
||||
hidden_size = 128
|
||||
num_heads = 4
|
||||
head_dim = hidden_size // num_heads
|
||||
seq_lens = [8, 16, 32, 16, 8] # Vary seq_len, including repeats
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
print(f"Config: B={batch_size}, H={hidden_size}, heads={num_heads}")
|
||||
print(f"Testing seq_lens: {seq_lens}")
|
||||
|
||||
# Create models - single instance to test caching across different seq_lens
|
||||
naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)
|
||||
model_config = create_model_config(hidden_size, num_heads, head_dim, attn_backend="TRTLLM")
|
||||
integrated = Attention(
|
||||
hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
|
||||
).to(device) # self attention
|
||||
|
||||
copy_weights_self_attention(naive, integrated)
|
||||
|
||||
naive.eval()
|
||||
integrated.eval()
|
||||
|
||||
all_passed = True
|
||||
|
||||
with torch.no_grad():
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
torch.manual_seed(42 + i)
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, hidden_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(
|
||||
seq_len, head_dim, device, is_HSD=True
|
||||
)
|
||||
freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(
|
||||
seq_len, head_dim, device, is_HSD=False
|
||||
)
|
||||
|
||||
out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
|
||||
out_integrated = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))
|
||||
|
||||
max_diff = (out_naive - out_integrated).abs().max().item()
|
||||
is_close = torch.allclose(out_naive, out_integrated, rtol=1e-2, atol=1e-2)
|
||||
|
||||
status = "✅" if is_close else "❌"
|
||||
print(f" seq_len={seq_len:3d}: max_diff={max_diff:.2e} {status}")
|
||||
|
||||
if not is_close:
|
||||
all_passed = False
|
||||
|
||||
if all_passed:
|
||||
print("\n ✅ PASS: Varying seq_len handled correctly!")
|
||||
else:
|
||||
print("\n ❌ FAIL: Issues with varying seq_len!")
|
||||
|
||||
assert all_passed, "Varying seq_len: outputs did not match naive reference"
|
||||
return all_passed
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests and report results."""
|
||||
print("\n" + "=" * 60)
|
||||
print("WAN Attention Integration Tests")
|
||||
print("=" * 60)
|
||||
|
||||
results = {}
|
||||
|
||||
# Run self-attention tests with different backends
|
||||
for backend in ["VANILLA", "TRTLLM"]:
|
||||
results[f"self_attention_{backend}"] = test_self_attention_equivalence(backend)
|
||||
|
||||
# Run cross-attention test (VANILLA only)
|
||||
results["cross_attention_VANILLA"] = test_cross_attention_equivalence("VANILLA")
|
||||
|
||||
# Run TRTLLM-specific caching tests
|
||||
results["trtllm_cached_prepare"] = test_trtllm_cached_prepare()
|
||||
results["trtllm_varying_seq_len"] = test_trtllm_varying_seq_len()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = all(results.values())
|
||||
for name, passed in results.items():
|
||||
status = "✅ PASS" if passed else "❌ FAIL"
|
||||
print(f" {name}: {status}")
|
||||
|
||||
print()
|
||||
if all_passed:
|
||||
print("All tests passed! ✅")
|
||||
else:
|
||||
print("Some tests failed! ❌")
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all_tests()
|
||||
622
tests/unittest/_torch/visual_gen/test_attention_perf.py
Normal file
@ -0,0 +1,622 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""WAN Attention Performance Benchmark.
|
||||
|
||||
Compares VANILLA vs TRTLLM attention backends for visual generation models.
|
||||
Uses CUDA events for precise GPU timing and supports NVTX profiling.
|
||||
|
||||
Usage:
|
||||
# Run all tests
|
||||
python test_attention_perf.py
|
||||
|
||||
# With Nsight Systems profiling
|
||||
nsys profile -t cuda,nvtx --nvtx-capture=range -o wan_attn_perf python test_attention_perf.py
|
||||
|
||||
# Run specific tests with pytest
|
||||
pytest test_attention_perf.py -v -k "test_self_attention_perf"
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig
|
||||
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
|
||||
|
||||
# NVTX support for profiling
|
||||
try:
|
||||
import nvtx
|
||||
|
||||
NVTX_AVAILABLE = True
|
||||
if hasattr(nvtx, "annotate"):
|
||||
NVTX_METHOD = "annotate"
|
||||
elif hasattr(nvtx, "range_start") and hasattr(nvtx, "range_end"):
|
||||
NVTX_METHOD = "range"
|
||||
else:
|
||||
NVTX_METHOD = None
|
||||
NVTX_AVAILABLE = False
|
||||
except ImportError:
|
||||
NVTX_AVAILABLE = False
|
||||
NVTX_METHOD = None
|
||||
|
||||
# Torch profiler support
|
||||
try:
|
||||
from torch.profiler import record_function
|
||||
|
||||
TORCH_PROFILER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_PROFILER_AVAILABLE = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Timing utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cuda_timer(device: torch.device):
|
||||
"""Context manager for precise GPU timing using CUDA events."""
|
||||
if device.type == "cuda":
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
|
||||
def get_elapsed_time():
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
return start_event.elapsed_time(end_event)
|
||||
|
||||
yield get_elapsed_time
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def get_elapsed_time():
|
||||
return (time.perf_counter() - start_time) * 1000
|
||||
|
||||
yield get_elapsed_time
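# Usage sketch for cuda_timer (illustrative; `run_step` is a placeholder callable):
# calling get_time() after the timed work records the end event and synchronizes
# before returning the elapsed milliseconds.
#
#   with cuda_timer(device) as get_time:
#       run_step()
#       elapsed_ms = get_time()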
|
||||
|
||||
|
||||
@contextmanager
|
||||
def nvtx_range(name: str):
|
||||
"""Context manager for NVTX range profiling."""
|
||||
if NVTX_AVAILABLE and NVTX_METHOD:
|
||||
if NVTX_METHOD == "annotate":
|
||||
with nvtx.annotate(name):
|
||||
yield
|
||||
elif NVTX_METHOD == "range":
|
||||
range_id = nvtx.range_start(name)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
nvtx.range_end(range_id)
|
||||
else:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_profiler_range(name: str):
|
||||
"""Context manager for torch profiler range."""
|
||||
if TORCH_PROFILER_AVAILABLE:
|
||||
with record_function(name):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_model_config(
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
attn_backend: str = "VANILLA",
|
||||
) -> DiffusionModelConfig:
|
||||
"""Create a mock DiffusionModelConfig for testing."""
|
||||
pretrained_config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
config = DiffusionModelConfig(
|
||||
pretrained_config=pretrained_config,
|
||||
attention=AttentionConfig(backend=attn_backend),
|
||||
skip_create_weights_in_init=False,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def generate_rope_embeddings(
|
||||
seq_len: int, head_dim: int, device: torch.device, is_HSD: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Generate RoPE embeddings.
|
||||
|
||||
Args:
|
||||
seq_len: Sequence length
|
||||
head_dim: Head dimension
|
||||
device: Target device
|
||||
is_HSD: If True, returns [1, 1, S, D] for HSD format, else [1, S, 1, D] for SHD
|
||||
|
||||
Returns:
|
||||
Tuple of (freqs_cos, freqs_sin)
|
||||
"""
|
||||
position = torch.arange(seq_len, device=device).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, head_dim, device=device) * (-torch.log(torch.tensor(10000.0)) / head_dim)
|
||||
)
|
||||
|
||||
if is_HSD:
|
||||
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(0)
|
||||
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
freqs_cos = torch.cos(position * div_term).unsqueeze(0).unsqueeze(2)
|
||||
freqs_sin = torch.sin(position * div_term).unsqueeze(0).unsqueeze(2)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance benchmark class
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class WanAttentionPerformanceBenchmark:
|
||||
"""Performance benchmark for WAN attention backends."""
|
||||
|
||||
# WAN model configurations: (batch_size, num_heads, seq_len, head_dim, description)
|
||||
TEST_SIZES = [
|
||||
# Wan2.1-T2V-1.3B configurations
|
||||
(1, 24, 14040, 64, "Wan-1.3B 480p 2s"),
|
||||
(1, 24, 3510, 64, "Wan-1.3B 480p 2s ring4"),
|
||||
(1, 24, 7020, 64, "Wan-1.3B 480p 2s ring2"),
|
||||
# Wan2.1-T2V-14B configurations
|
||||
(1, 40, 75600, 128, "Wan-14B 720p 5s"),
|
||||
(1, 40, 37800, 128, "Wan-14B 720p 5s ring2"),
|
||||
(1, 40, 18900, 128, "Wan-14B 720p 5s ring4"),
|
||||
(1, 40, 9450, 128, "Wan-14B 720p 5s ring8"),
|
||||
# Ulysses parallelism configurations
|
||||
(1, 20, 75600, 128, "Wan-14B 720p ulysses2"),
|
||||
(1, 10, 75600, 128, "Wan-14B 720p ulysses4"),
|
||||
(1, 5, 75600, 128, "Wan-14B 720p ulysses8"),
|
||||
# Smaller test cases for quick validation
|
||||
(2, 24, 1024, 64, "Small batch2"),
|
||||
(1, 24, 4096, 64, "Medium 4k"),
|
||||
(1, 40, 8192, 128, "Large 8k"),
|
||||
]
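# Note on the sequence lengths above (a rough sanity check, assuming Wan's 4x
# temporal / 8x spatial VAE compression and 2x2 patchify): 480x832 at 33 frames
# gives ((33 - 1) / 4 + 1) * (480 / 16) * (832 / 16) = 9 * 30 * 52 = 14040 tokens,
# and 720p (720x1280) at 81 frames gives 21 * 45 * 80 = 75600 tokens.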
|
||||
|
||||
# Quick test sizes for CI/pytest
|
||||
QUICK_TEST_SIZES = [
|
||||
(1, 24, 1024, 64, "Quick 1k"),
|
||||
(1, 24, 2048, 64, "Quick 2k"),
|
||||
(2, 24, 1024, 64, "Quick batch2"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
warmup_iterations: int = 10,
|
||||
benchmark_iterations: int = 50,
|
||||
):
|
||||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.dtype = dtype
|
||||
self.warmup_iterations = warmup_iterations
|
||||
self.benchmark_iterations = benchmark_iterations
|
||||
self.backends = ["VANILLA", "TRTLLM"]
|
||||
|
||||
def create_attention_model(
|
||||
self, hidden_size: int, num_heads: int, head_dim: int, backend: str
|
||||
) -> Attention:
|
||||
"""Create a WAN self-attention model with specified backend."""
|
||||
config = create_model_config(hidden_size, num_heads, head_dim, attn_backend=backend)
|
||||
model = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=config).to(
|
||||
self.device
|
||||
)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def create_test_data(
|
||||
self, batch_size: int, seq_len: int, hidden_size: int, head_dim: int
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Create test input data and RoPE embeddings."""
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, hidden_size, device=self.device, dtype=self.dtype
|
||||
)
|
||||
freqs = generate_rope_embeddings(seq_len, head_dim, self.device, is_HSD=False)
|
||||
return hidden_states, freqs
|
||||
|
||||
def estimate_memory_gb(
|
||||
self, batch_size: int, num_heads: int, seq_len: int, head_dim: int
|
||||
) -> float:
|
||||
"""Estimate tensor memory usage in GB."""
|
||||
hidden_size = num_heads * head_dim
|
||||
# Input: [B, S, H] + Q, K, V: [B, S, num_heads, head_dim] each
|
||||
bytes_per_element = 2 # bf16
|
||||
input_bytes = batch_size * seq_len * hidden_size * bytes_per_element
|
||||
qkv_bytes = 3 * batch_size * seq_len * num_heads * head_dim * bytes_per_element
|
||||
output_bytes = batch_size * seq_len * hidden_size * bytes_per_element
|
||||
# Attention matrix can be O(S^2) but flash attention avoids materializing it
|
||||
return (input_bytes + qkv_bytes + output_bytes) / (1024**3)
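# Worked example of the estimate above for the largest listed config,
# (1, 40, 75600, 128): hidden_size = 40 * 128 = 5120, so input + output + QKV
# total 5 * 75600 * 5120 * 2 bytes ~= 3.6 GB, comfortably under the 8 GB skip
# threshold applied in benchmark_single().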
|
||||
|
||||
def benchmark_single(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
backend: str,
|
||||
verbose: bool = True,
|
||||
) -> Optional[Dict]:
|
||||
"""Benchmark a single configuration.
|
||||
|
||||
Returns:
|
||||
Dict with timing statistics or None if test failed/skipped
|
||||
"""
|
||||
hidden_size = num_heads * head_dim
|
||||
|
||||
# Memory check
|
||||
est_memory = self.estimate_memory_gb(batch_size, num_heads, seq_len, head_dim)
|
||||
if est_memory > 8.0:
|
||||
if verbose:
|
||||
print(f" Skipping - estimated memory {est_memory:.2f}GB > 8GB limit")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Create model and data
|
||||
model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
|
||||
hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim)
|
||||
|
||||
# Warmup
|
||||
with nvtx_range(f"warmup_{backend}"):
|
||||
with torch_profiler_range(f"warmup_{backend}"):
|
||||
with torch.no_grad():
|
||||
for _ in range(self.warmup_iterations):
|
||||
_ = model(hidden_states, freqs=freqs)
|
||||
|
||||
if self.device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
with nvtx_range(f"benchmark_{backend}"):
|
||||
with torch_profiler_range(f"benchmark_{backend}"):
|
||||
with torch.no_grad():
|
||||
for i in range(self.benchmark_iterations):
|
||||
with nvtx_range(f"iter_{backend}_{i}"):
|
||||
with cuda_timer(self.device) as get_time:
|
||||
_ = model(hidden_states, freqs=freqs)
|
||||
times.append(get_time())
|
||||
|
||||
# Statistics
|
||||
times_tensor = torch.tensor(times)
|
||||
stats = {
|
||||
"avg_ms": times_tensor.mean().item(),
|
||||
"min_ms": times_tensor.min().item(),
|
||||
"max_ms": times_tensor.max().item(),
|
||||
"std_ms": times_tensor.std().item(),
|
||||
"median_ms": times_tensor.median().item(),
|
||||
"p95_ms": torch.quantile(times_tensor, 0.95).item(),
|
||||
"p99_ms": torch.quantile(times_tensor, 0.99).item(),
|
||||
}
|
||||
|
||||
# Calculate throughput (approximate TOPS)
|
||||
total_ops = batch_size * num_heads * seq_len * seq_len * head_dim
|
||||
stats["throughput_tops"] = (total_ops / 1e12) / (stats["avg_ms"] / 1000)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f" {backend}: avg={stats['avg_ms']:.3f}ms, "
|
||||
f"median={stats['median_ms']:.3f}ms, "
|
||||
f"throughput={stats['throughput_tops']:.2f} TOPS"
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print(f" {backend}: ERROR - {e}")
|
||||
return None
|
||||
|
||||
def benchmark_comparison(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
description: str = "",
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Optional[Dict]]:
|
||||
"""Benchmark and compare all backends for a given configuration."""
|
||||
if verbose:
|
||||
print(
|
||||
f"\nBenchmarking: ({batch_size}, {num_heads}, {seq_len}, {head_dim}) {description}"
|
||||
)
|
||||
print(f" Device: {self.device}, dtype: {self.dtype}")
|
||||
print(f" Warmup: {self.warmup_iterations}, Iterations: {self.benchmark_iterations}")
|
||||
|
||||
results = {}
|
||||
for backend in self.backends:
|
||||
results[backend] = self.benchmark_single(
|
||||
batch_size, num_heads, seq_len, head_dim, backend, verbose
|
||||
)
|
||||
|
||||
# Print comparison
|
||||
if verbose and results.get("VANILLA") and results.get("TRTLLM"):
|
||||
vanilla_avg = results["VANILLA"]["avg_ms"]
|
||||
trtllm_avg = results["TRTLLM"]["avg_ms"]
|
||||
speedup = vanilla_avg / trtllm_avg
|
||||
print(f" TRTLLM vs VANILLA: {speedup:.2f}x {'faster' if speedup > 1 else 'slower'}")
|
||||
|
||||
return results
|
||||
|
||||
def run_full_benchmark(self, use_quick_sizes: bool = False) -> Dict:
|
||||
"""Run benchmark on all configured sizes."""
|
||||
test_sizes = self.QUICK_TEST_SIZES if use_quick_sizes else self.TEST_SIZES
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("WAN ATTENTION PERFORMANCE BENCHMARK")
|
||||
print("=" * 70)
|
||||
print(f"Device: {self.device}")
|
||||
print(f"dtype: {self.dtype}")
|
||||
print(f"Backends: {self.backends}")
|
||||
print(f"NVTX: {'Enabled' if NVTX_AVAILABLE else 'Disabled'}")
|
||||
print(f"Torch Profiler: {'Enabled' if TORCH_PROFILER_AVAILABLE else 'Disabled'}")
|
||||
|
||||
all_results = {}
|
||||
|
||||
with nvtx_range("wan_attention_benchmark"):
|
||||
with torch_profiler_range("wan_attention_benchmark"):
|
||||
for batch_size, num_heads, seq_len, head_dim, desc in test_sizes:
|
||||
key = f"{desc}_{batch_size}x{num_heads}x{seq_len}x{head_dim}"
|
||||
results = self.benchmark_comparison(
|
||||
batch_size, num_heads, seq_len, head_dim, desc
|
||||
)
|
||||
all_results[key] = {
|
||||
"config": {
|
||||
"batch_size": batch_size,
|
||||
"num_heads": num_heads,
|
||||
"seq_len": seq_len,
|
||||
"head_dim": head_dim,
|
||||
"description": desc,
|
||||
},
|
||||
"results": results,
|
||||
}
|
||||
|
||||
# Print summary
|
||||
self._print_summary(all_results)
|
||||
return all_results
|
||||
|
||||
def _print_summary(self, all_results: Dict) -> None:
|
||||
"""Print benchmark summary table."""
|
||||
print("\n" + "=" * 70)
|
||||
print("BENCHMARK SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f"{'Configuration':<40} {'VANILLA (ms)':<15} {'TRTLLM (ms)':<15} {'Speedup':<10}")
|
||||
print("-" * 70)
|
||||
|
||||
for key, data in all_results.items():
|
||||
desc = data["config"]["description"]
|
||||
results = data["results"]
|
||||
|
||||
vanilla = results.get("VANILLA")
|
||||
trtllm = results.get("TRTLLM")
|
||||
|
||||
vanilla_str = f"{vanilla['avg_ms']:.2f}" if vanilla else "N/A"
|
||||
trtllm_str = f"{trtllm['avg_ms']:.2f}" if trtllm else "N/A"
|
||||
|
||||
if vanilla and trtllm:
|
||||
speedup = vanilla["avg_ms"] / trtllm["avg_ms"]
|
||||
speedup_str = f"{speedup:.2f}x"
|
||||
else:
|
||||
speedup_str = "N/A"
|
||||
|
||||
print(f"{desc:<40} {vanilla_str:<15} {trtllm_str:<15} {speedup_str:<10}")
|
||||
|
||||
def test_memory_usage(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_heads: int = 24,
|
||||
seq_len: int = 4096,
|
||||
head_dim: int = 64,
|
||||
) -> Dict[str, Dict]:
|
||||
"""Test memory usage of different backends."""
|
||||
if self.device.type != "cuda":
|
||||
print("Memory test requires CUDA device")
|
||||
return {}
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("MEMORY USAGE TEST")
|
||||
print("=" * 70)
|
||||
print(f"Config: ({batch_size}, {num_heads}, {seq_len}, {head_dim})")
|
||||
|
||||
hidden_size = num_heads * head_dim
|
||||
memory_results = {}
|
||||
|
||||
for backend in self.backends:
|
||||
print(f"\nTesting {backend}...")
|
||||
|
||||
try:
|
||||
# Clear cache
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Create model and data
|
||||
model = self.create_attention_model(hidden_size, num_heads, head_dim, backend)
|
||||
hidden_states, freqs = self.create_test_data(
|
||||
batch_size, seq_len, hidden_size, head_dim
|
||||
)
|
||||
|
||||
# Warmup
|
||||
with torch.no_grad():
|
||||
_ = model(hidden_states, freqs=freqs)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Forward pass
|
||||
with nvtx_range(f"memory_test_{backend}"):
|
||||
with torch.no_grad():
|
||||
_ = model(hidden_states, freqs=freqs)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
current_memory_gb = torch.cuda.memory_allocated() / (1024**3)
|
||||
|
||||
memory_results[backend] = {
|
||||
"peak_memory_gb": peak_memory_gb,
|
||||
"current_memory_gb": current_memory_gb,
|
||||
}
|
||||
|
||||
print(f" Peak memory: {peak_memory_gb:.3f} GB")
|
||||
print(f" Current memory: {current_memory_gb:.3f} GB")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
memory_results[backend] = None
|
||||
|
||||
return memory_results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pytest test functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestWanAttentionPerformance:
|
||||
"""Pytest test class for WAN attention performance."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
"""Setup test environment."""
|
||||
self.benchmark = WanAttentionPerformanceBenchmark(
|
||||
warmup_iterations=5,
|
||||
benchmark_iterations=20,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"])
|
||||
def test_self_attention_perf(self, backend: str):
|
||||
"""Test that attention backend runs without errors."""
|
||||
batch_size, num_heads, seq_len, head_dim = 1, 24, 1024, 64
|
||||
|
||||
result = self.benchmark.benchmark_single(
|
||||
batch_size, num_heads, seq_len, head_dim, backend, verbose=True
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
assert result["avg_ms"] > 0, "Average time should be positive"
|
||||
assert result["min_ms"] <= result["avg_ms"], "Min should be <= avg"
|
||||
assert result["max_ms"] >= result["avg_ms"], "Max should be >= avg"
|
||||
print(f" {backend}: avg={result['avg_ms']:.3f}ms OK")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,num_heads,seq_len,head_dim",
|
||||
[
|
||||
(1, 24, 1024, 64),
|
||||
(1, 24, 2048, 64),
|
||||
(2, 24, 1024, 64),
|
||||
],
|
||||
)
|
||||
def test_backend_comparison(self, batch_size: int, num_heads: int, seq_len: int, head_dim: int):
|
||||
"""Test VANILLA vs TRTLLM comparison."""
|
||||
results = self.benchmark.benchmark_comparison(
|
||||
batch_size, num_heads, seq_len, head_dim, verbose=True
|
||||
)
|
||||
|
||||
# At least one backend should work
|
||||
assert any(r is not None for r in results.values()), "All backends failed"
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
|
||||
def test_memory_usage(self):
|
||||
"""Test memory usage tracking."""
|
||||
memory_results = self.benchmark.test_memory_usage(
|
||||
batch_size=1, num_heads=24, seq_len=2048, head_dim=64
|
||||
)
|
||||
|
||||
for backend, result in memory_results.items():
|
||||
if result is not None:
|
||||
assert result["peak_memory_gb"] > 0, f"{backend} peak memory should be positive"
|
||||
|
||||
def test_quick_benchmark(self):
|
||||
"""Run quick benchmark for CI validation."""
|
||||
results = self.benchmark.run_full_benchmark(use_quick_sizes=True)
|
||||
assert len(results) > 0, "Should have benchmark results"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main entry point
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main():
|
||||
"""Run full benchmark suite."""
|
||||
print("\n" + "=" * 70)
|
||||
print("WAN ATTENTION PERFORMANCE BENCHMARK SUITE")
|
||||
print("=" * 70)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("WARNING: CUDA not available, results will not be meaningful")
|
||||
|
||||
# Print profiling instructions
|
||||
if torch.cuda.is_available():
|
||||
print("\nPROFILING INSTRUCTIONS:")
|
||||
print("-" * 50)
|
||||
if NVTX_AVAILABLE:
|
||||
print("NVTX Profiling (Nsight Systems):")
|
||||
print(" nsys profile -t cuda,nvtx --nvtx-capture=range \\")
|
||||
print(" -o wan_attn_perf python test_attention_perf.py")
|
||||
else:
|
||||
print("NVTX not available. Install with: pip install nvtx")
|
||||
|
||||
print("\nPyTorch Profiler:")
|
||||
print(" The benchmark includes record_function() calls for profiling")
|
||||
print("-" * 50)
|
||||
|
||||
# Create benchmark instance
|
||||
benchmark = WanAttentionPerformanceBenchmark(
|
||||
warmup_iterations=10,
|
||||
benchmark_iterations=50,
|
||||
)
|
||||
|
||||
# Run full benchmark
|
||||
print("\n" + "=" * 70)
|
||||
print("FULL BENCHMARK")
|
||||
print("=" * 70)
|
||||
all_results = benchmark.run_full_benchmark(use_quick_sizes=False)
|
||||
|
||||
# Memory test
|
||||
if torch.cuda.is_available():
|
||||
benchmark.test_memory_usage(batch_size=1, num_heads=24, seq_len=4096, head_dim=64)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("BENCHMARK COMPLETE")
|
||||
print("=" * 70)
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
126
tests/unittest/_torch/visual_gen/test_fused_qkv.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""Tests for fused QKV support in diffusion models.
|
||||
|
||||
Tests:
|
||||
1. Model structure with fuse_qkv=True (default) vs fuse_qkv=False
|
||||
2. Weight loading works for fused QKV layers
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
|
||||
|
||||
|
||||
def _create_test_config(hidden_size: int = 64) -> DiffusionModelConfig:
|
||||
"""Create a test DiffusionModelConfig."""
|
||||
num_heads = hidden_size // 8 # e.g., 64 // 8 = 8 heads
|
||||
head_dim = 8
|
||||
return DiffusionModelConfig(
|
||||
pretrained_config=SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
num_layers=2,
|
||||
ffn_dim=256,
|
||||
out_channels=16,
|
||||
patch_size=[1, 2, 2],
|
||||
in_channels=16,
|
||||
text_dim=64,
|
||||
freq_dim=32,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestFusedQKVWeightLoading(unittest.TestCase):
|
||||
"""Test weight loading for fused QKV layers."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
torch.manual_seed(42)
|
||||
self.hidden_size = 64
|
||||
|
||||
def _create_mock_checkpoint_weights(self) -> Dict[str, torch.Tensor]:
|
||||
"""Create mock checkpoint weights with separate to_q, to_k, to_v."""
|
||||
dtype = torch.bfloat16 # Match model dtype
|
||||
weights = {}
|
||||
for block_idx in range(2):
|
||||
for attn_name in ["attn1", "attn2"]:
|
||||
prefix = f"blocks.{block_idx}.{attn_name}"
|
||||
# Separate QKV weights (as in checkpoint)
|
||||
weights[f"{prefix}.to_q.weight"] = torch.randn(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype
|
||||
)
|
||||
weights[f"{prefix}.to_q.bias"] = torch.randn(self.hidden_size, dtype=dtype)
|
||||
weights[f"{prefix}.to_k.weight"] = torch.randn(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype
|
||||
)
|
||||
weights[f"{prefix}.to_k.bias"] = torch.randn(self.hidden_size, dtype=dtype)
|
||||
weights[f"{prefix}.to_v.weight"] = torch.randn(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype
|
||||
)
|
||||
weights[f"{prefix}.to_v.bias"] = torch.randn(self.hidden_size, dtype=dtype)
|
||||
# Output projection
|
||||
weights[f"{prefix}.to_out.0.weight"] = torch.randn(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype
|
||||
)
|
||||
weights[f"{prefix}.to_out.0.bias"] = torch.randn(self.hidden_size, dtype=dtype)
|
||||
|
||||
# FFN weights
|
||||
ffn_dim = 256
|
||||
prefix = f"blocks.{block_idx}.ffn"
|
||||
weights[f"{prefix}.net.0.proj.weight"] = torch.randn(
|
||||
ffn_dim, self.hidden_size, dtype=dtype
|
||||
)
|
||||
weights[f"{prefix}.net.0.proj.bias"] = torch.randn(ffn_dim, dtype=dtype)
|
||||
weights[f"{prefix}.net.2.weight"] = torch.randn(self.hidden_size, ffn_dim, dtype=dtype)
|
||||
weights[f"{prefix}.net.2.bias"] = torch.randn(self.hidden_size, dtype=dtype)
|
||||
|
||||
# proj_out
|
||||
weights["proj_out.weight"] = torch.randn(64, self.hidden_size, dtype=dtype)
|
||||
weights["proj_out.bias"] = torch.randn(64, dtype=dtype)
|
||||
|
||||
return weights
|
||||
|
||||
def test_load_weights_fused(self):
|
||||
"""Test loading weights with fused QKV (default for self-attention)."""
|
||||
from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel
|
||||
|
||||
config = _create_test_config(self.hidden_size)
|
||||
|
||||
# Create model - self-attention (attn1) uses fused QKV by default
|
||||
model = WanTransformer3DModel(model_config=config)
|
||||
weights = self._create_mock_checkpoint_weights()
|
||||
|
||||
# Load weights (model handles fused QKV internally via DynamicLinearWeightLoader)
|
||||
model.load_weights(weights)
|
||||
|
||||
# Verify fused weights were loaded correctly for self-attention
|
||||
attn1 = model.blocks[0].attn1
|
||||
qkv_weight = attn1.qkv_proj.weight.data
|
||||
|
||||
# Expected: concatenation of to_q, to_k, to_v weights
|
||||
expected_weight = torch.cat(
|
||||
[
|
||||
weights["blocks.0.attn1.to_q.weight"],
|
||||
weights["blocks.0.attn1.to_k.weight"],
|
||||
weights["blocks.0.attn1.to_v.weight"],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
self.assertEqual(qkv_weight.shape, expected_weight.shape)
|
||||
self.assertTrue(torch.allclose(qkv_weight, expected_weight))
|
||||
|
||||
# Also verify cross-attention (attn2) uses separate Q/K/V
|
||||
attn2 = model.blocks[0].attn2
|
||||
self.assertTrue(hasattr(attn2, "to_q"), "Cross-attention should have separate to_q")
|
||||
self.assertTrue(
|
||||
torch.allclose(attn2.to_q.weight.data, weights["blocks.0.attn2.to_q.weight"])
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
494
tests/unittest/_torch/visual_gen/test_model_loader.py
Normal file
@ -0,0 +1,494 @@
|
||||
"""Test PipelineLoader with DiffusionArgs API."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.config import PipelineComponent
|
||||
|
||||
|
||||
def _llm_models_root() -> str:
|
||||
"""Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
|
||||
root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
|
||||
if "LLM_MODELS_ROOT" in os.environ:
|
||||
root = Path(os.environ["LLM_MODELS_ROOT"])
|
||||
if not root.exists():
|
||||
root = Path("/scratch.trt_llm_data/llm-models/")
|
||||
assert root.exists(), (
|
||||
"You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
|
||||
)
|
||||
return str(root)
|
||||
|
||||
|
||||
# Skip if checkpoint not available
|
||||
# Set DIFFUSION_MODEL_PATH env var to run integration tests
|
||||
CHECKPOINT_PATH = os.environ.get(
|
||||
"DIFFUSION_MODEL_PATH",
|
||||
os.path.join(_llm_models_root(), "Wan2.1-T2V-1.3B-Diffusers"),
|
||||
)
|
||||
|
||||
# Skip heavy components (text_encoder ~44GB, vae ~300MB) to speed up tests
|
||||
# These components are loaded via diffusers and don't need quantization testing
|
||||
SKIP_HEAVY_COMPONENTS = [
|
||||
PipelineComponent.TEXT_ENCODER,
|
||||
PipelineComponent.VAE,
|
||||
PipelineComponent.TOKENIZER,
|
||||
PipelineComponent.SCHEDULER,
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def checkpoint_exists():
|
||||
return CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH)
|
||||
|
||||
|
||||
def test_meta_init_mode_creates_meta_tensors(checkpoint_exists):
|
||||
"""Test that MetaInitMode creates tensors on meta device (no GPU memory)."""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs
|
||||
from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig
|
||||
from tensorrt_llm._torch.visual_gen.models import AutoPipeline
|
||||
|
||||
# Load config directly
|
||||
args = DiffusionArgs(checkpoint_path=CHECKPOINT_PATH)
|
||||
config = DiffusionModelConfig.from_pretrained(
|
||||
CHECKPOINT_PATH,
|
||||
args=args,
|
||||
)
|
||||
|
||||
# Create pipeline WITH MetaInitMode
|
||||
with MetaInitMode():
|
||||
pipeline = AutoPipeline.from_config(config, CHECKPOINT_PATH)
|
||||
|
||||
# Verify tensors are on meta device (no GPU memory allocated)
|
||||
param = next(pipeline.transformer.parameters())
|
||||
assert param.device.type == "meta", f"Expected meta device, got {param.device}"
|
||||
|
||||
|
||||
def test_load_wan_pipeline_basic(checkpoint_exists):
|
||||
"""Test basic loading without quantization using DiffusionArgs."""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
|
||||
|
||||
# Simple one-liner with DiffusionArgs
|
||||
# Skip text_encoder/vae to speed up test (focus on transformer)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline = PipelineLoader(args).load()
|
||||
|
||||
# Verify pipeline type
|
||||
assert pipeline.__class__.__name__ == "WanPipeline"
|
||||
assert pipeline.transformer is not None
|
||||
|
||||
# Verify text_encoder/vae were skipped
|
||||
assert pipeline.text_encoder is None, "text_encoder should be skipped"
|
||||
assert pipeline.vae is None, "vae should be skipped"
|
||||
|
||||
# Verify weights are loaded (not meta tensors)
|
||||
param = next(pipeline.transformer.parameters())
|
||||
assert param.device.type == "cuda"
|
||||
assert param.dtype in [torch.float32, torch.bfloat16, torch.float16]
|
||||
|
||||
|
||||
def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists):
|
||||
"""Test loading with FP8 dynamic quantization using DiffusionArgs.
|
||||
|
||||
Verifies the dynamic quantization flow:
|
||||
1. Config has dynamic_weight_quant=True when quant_config specifies dynamic FP8
|
||||
2. Model Linear layers have FP8 weight buffers
|
||||
3. BF16 checkpoint weights are quantized on-the-fly
|
||||
4. Quantized weights are in FP8 format
|
||||
"""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.modules.linear import Linear
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
|
||||
|
||||
# Use DiffusionArgs with FP8 quantization
|
||||
# Skip text_encoder/vae to speed up test (focus on transformer quantization)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
quant_config={"quant_algo": "FP8", "dynamic": True},
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline = PipelineLoader(args).load()
|
||||
|
||||
# Verify model config has dynamic_weight_quant enabled
|
||||
assert pipeline.model_config.dynamic_weight_quant is True, (
|
||||
"dynamic_weight_quant should be True when linear.type specifies FP8"
|
||||
)
|
||||
|
||||
# Verify FP8 weights in transformer Linear layers
|
||||
found_fp8_linear = False
|
||||
for name, module in pipeline.transformer.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.dtype == torch.float8_e4m3fn, (
|
||||
f"Linear {name} weight dtype is {module.weight.dtype}, expected float8_e4m3fn"
|
||||
)
|
||||
assert hasattr(module, "weight_scale") and module.weight_scale is not None, (
|
||||
f"Linear {name} missing weight_scale buffer"
|
||||
)
|
||||
found_fp8_linear = True
|
||||
break
|
||||
|
||||
assert found_fp8_linear, "No FP8 Linear modules found in transformer"
|
||||
|
||||
|
||||
def test_load_wan_pipeline_with_fp8_blockwise(checkpoint_exists):
|
||||
"""Test loading with FP8 blockwise quantization using DiffusionArgs."""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.modules.linear import Linear
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
|
||||
|
||||
# Skip text_encoder/vae to speed up test (focus on transformer quantization)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline = PipelineLoader(args).load()
|
||||
|
||||
# Verify FP8 weights
|
||||
for name, module in pipeline.transformer.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.dtype == torch.float8_e4m3fn, (
|
||||
f"Linear {name} should have FP8 weight"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
def test_diffusion_args_to_quant_config():
|
||||
"""Test that DiffusionArgs correctly parses quant_config dict to QuantConfig."""
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
# Default - no quantization
|
||||
args = DiffusionArgs(checkpoint_path="/fake/path")
|
||||
assert args.quant_config.quant_algo is None
|
||||
|
||||
# FP8 per-tensor (dict is coerced to QuantConfig by model_validator)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/fake/path",
|
||||
quant_config={"quant_algo": "FP8", "dynamic": True},
|
||||
)
|
||||
qc = args.quant_config
|
||||
assert qc is not None
|
||||
assert qc.quant_algo == QuantAlgo.FP8
|
||||
assert args.dynamic_weight_quant is True
|
||||
|
||||
# FP8 blockwise
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/fake/path",
|
||||
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
|
||||
)
|
||||
qc = args.quant_config
|
||||
assert qc.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
|
||||
|
||||
# NVFP4
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/fake/path",
|
||||
quant_config={"quant_algo": "NVFP4", "dynamic": True},
|
||||
)
|
||||
qc = args.quant_config
|
||||
assert qc.quant_algo == QuantAlgo.NVFP4
|
||||
|
||||
# With ignore patterns (exclude_modules)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/fake/path",
|
||||
quant_config={
|
||||
"quant_algo": "FP8",
|
||||
"ignore": ["blocks.0.attn1.*", "proj_out"],
|
||||
"config_groups": {
|
||||
"group_0": {
|
||||
"weights": {"dynamic": True, "num_bits": 8, "type": "float"},
|
||||
"targets": ["Linear"],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
qc = args.quant_config
|
||||
assert qc is not None
|
||||
assert qc.quant_algo == QuantAlgo.FP8
|
||||
assert qc.exclude_modules == ["blocks.0.attn1.*", "proj_out"]
|
||||
assert args.dynamic_weight_quant is True
|
||||
|
||||
|
||||
def test_diffusion_args_to_mapping():
|
||||
"""Test that DiffusionArgs correctly generates Mapping from ParallelConfig."""
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, ParallelConfig
|
||||
|
||||
# ParallelConfig validator requires WORLD_SIZE >= total parallel (tp*cp = 4)
|
||||
old_world = os.environ.get("WORLD_SIZE")
|
||||
try:
|
||||
os.environ["WORLD_SIZE"] = "4"
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path="/fake/path",
|
||||
parallel=ParallelConfig(dit_tp_size=2, dit_cp_size=2),
|
||||
)
|
||||
mapping = args.to_mapping()
|
||||
assert mapping.tp_size == 2
|
||||
assert mapping.cp_size == 2
|
||||
# world_size = tp_size * pp_size * cp_size (DP is handled separately)
|
||||
assert mapping.world_size == 4
|
||||
finally:
|
||||
if old_world is not None:
|
||||
os.environ["WORLD_SIZE"] = old_world
|
||||
elif "WORLD_SIZE" in os.environ:
|
||||
del os.environ["WORLD_SIZE"]
|
||||
|
||||
|
||||
def test_load_without_quant_config_no_fp8(checkpoint_exists):
|
||||
"""Test that loading without quant_config does NOT produce FP8 weights."""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.modules.linear import Linear
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
|
||||
|
||||
# No quantization specified
|
||||
# Skip text_encoder/vae to speed up test (focus on transformer)
|
||||
args = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline = PipelineLoader(args).load()
|
||||
|
||||
# Verify dynamic_weight_quant is False
|
||||
assert pipeline.model_config.dynamic_weight_quant is False, (
|
||||
"dynamic_weight_quant should be False when no quant_config"
|
||||
)
|
||||
|
||||
# Verify NO FP8 weights
|
||||
for name, module in pipeline.transformer.named_modules():
|
||||
if isinstance(module, Linear):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.dtype != torch.float8_e4m3fn, (
|
||||
f"Linear {name} should NOT be FP8 without quant_config"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
def test_diffusion_args_from_dict():
|
||||
"""Test DiffusionArgs can be created from a dictionary."""
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
config_dict = {
|
||||
"checkpoint_path": "/path/to/model",
|
||||
"quant_config": {"quant_algo": "FP8", "dynamic": True},
|
||||
"parallel": {"dit_tp_size": 2},
|
||||
"pipeline": {"fuse_qkv": True},
|
||||
}
|
||||
# ParallelConfig validator requires WORLD_SIZE >= total parallel (dit_tp_size=2)
|
||||
old_world = os.environ.get("WORLD_SIZE")
|
||||
try:
|
||||
os.environ["WORLD_SIZE"] = "2"
|
||||
args = DiffusionArgs.from_dict(config_dict)
|
||||
assert args.checkpoint_path == "/path/to/model"
|
||||
assert args.quant_config.quant_algo == QuantAlgo.FP8
|
||||
assert args.dynamic_weight_quant is True
|
||||
assert args.parallel.dit_tp_size == 2
|
||||
assert args.pipeline.fuse_qkv is True
|
||||
finally:
|
||||
if old_world is not None:
|
||||
os.environ["WORLD_SIZE"] = old_world
|
||||
elif "WORLD_SIZE" in os.environ:
|
||||
del os.environ["WORLD_SIZE"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Memory and Performance Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_module_memory_gb(module):
|
||||
"""Get GPU memory usage of a module in GB."""
|
||||
return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3
|
||||
|
||||
|
||||
def _get_cuda_memory_gb():
|
||||
"""Get current CUDA memory allocated in GB."""
|
||||
return torch.cuda.memory_allocated() / 1024**3
|
||||
|
||||
|
||||
def _get_cuda_peak_memory_gb():
|
||||
"""Get peak CUDA memory allocated in GB."""
|
||||
return torch.cuda.max_memory_allocated() / 1024**3
|
||||
|
||||
|
||||
def test_fp8_vs_bf16_memory_comparison(checkpoint_exists):
|
||||
"""Test FP8 dynamic quant uses ~2x less memory than BF16, including peak memory.
|
||||
|
||||
This test verifies that dynamic quantization doesn't create unnecessary
|
||||
intermediate buffers that would negate the memory savings.
|
||||
|
||||
Expected for Wan 1.3B transformer:
|
||||
- BF16: ~2.6 GB model memory, similar peak during loading
|
||||
- FP8: ~1.3 GB model memory, peak at least 1.5x smaller than the BF16 peak
|
||||
"""
|
||||
if not checkpoint_exists:
|
||||
pytest.skip("Checkpoint not available")
|
||||
|
||||
from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader
|
||||
|
||||
# =========================================================================
|
||||
# Test 1: Load BF16 (no quantization)
|
||||
# =========================================================================
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
args_bf16 = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline_bf16 = PipelineLoader(args_bf16).load()
|
||||
|
||||
bf16_model_mem = _get_module_memory_gb(pipeline_bf16.transformer)
|
||||
bf16_total_mem = _get_cuda_memory_gb()
|
||||
bf16_peak_mem = _get_cuda_peak_memory_gb()
|
||||
|
||||
print(f"\n[BF16] Transformer model memory: {bf16_model_mem:.2f} GB")
|
||||
print(f"[BF16] Total CUDA memory: {bf16_total_mem:.2f} GB")
|
||||
print(f"[BF16] Peak CUDA memory: {bf16_peak_mem:.2f} GB")
|
||||
|
||||
# Cleanup BF16
|
||||
del pipeline_bf16
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# =========================================================================
|
||||
# Test 2: Load FP8 (dynamic quantization)
|
||||
# =========================================================================
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
args_fp8 = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
quant_config={"quant_algo": "FP8", "dynamic": True},
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline_fp8 = PipelineLoader(args_fp8).load()
|
||||
|
||||
fp8_model_mem = _get_module_memory_gb(pipeline_fp8.transformer)
|
||||
fp8_total_mem = _get_cuda_memory_gb()
|
||||
fp8_peak_mem = _get_cuda_peak_memory_gb()
|
||||
|
||||
print(f"\n[FP8] Transformer model memory: {fp8_model_mem:.2f} GB")
|
||||
print(f"[FP8] Total CUDA memory: {fp8_total_mem:.2f} GB")
|
||||
print(f"[FP8] Peak CUDA memory: {fp8_peak_mem:.2f} GB")
|
||||
|
||||
# =========================================================================
|
||||
# Verify memory savings
|
||||
# =========================================================================
|
||||
model_mem_ratio = bf16_model_mem / fp8_model_mem
|
||||
peak_mem_ratio = bf16_peak_mem / fp8_peak_mem
|
||||
|
||||
print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x")
|
||||
print(f"[Comparison] Peak memory ratio (BF16/FP8): {peak_mem_ratio:.2f}x")
|
||||
|
||||
# Model memory should be ~2x smaller for FP8
|
||||
assert model_mem_ratio > 1.8, (
|
||||
f"FP8 model memory should be ~2x smaller than BF16, got {model_mem_ratio:.2f}x"
|
||||
)
|
||||
|
||||
# Peak memory during loading should also show savings
|
||||
# Allow some overhead for dynamic quant, but should still be significantly better
|
||||
assert peak_mem_ratio > 1.5, (
|
||||
f"FP8 peak memory should be significantly smaller than BF16, got {peak_mem_ratio:.2f}x. "
|
||||
f"This may indicate unnecessary intermediate buffers during dynamic quantization."
|
||||
)
|
||||
|
||||
# FP8 peak should not be much higher than FP8 final (no large temp buffers)
|
||||
fp8_peak_overhead = fp8_peak_mem / fp8_total_mem
|
||||
print(f"[FP8 Per-Tensor] Peak/Final memory ratio: {fp8_peak_overhead:.2f}x")
|
||||
|
||||
# Peak should be close to final (< 2x overhead during loading)
|
||||
assert fp8_peak_overhead < 2.0, (
|
||||
f"FP8 peak memory ({fp8_peak_mem:.2f} GB) is too high compared to final "
|
||||
f"({fp8_total_mem:.2f} GB). Ratio: {fp8_peak_overhead:.2f}x. "
|
||||
f"This suggests unnecessary buffer allocation during dynamic quantization."
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
del pipeline_fp8
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# =========================================================================
|
||||
# Test 3: Load FP8 Blockwise (dynamic quantization with block scales)
|
||||
# =========================================================================
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
args_fp8_block = DiffusionArgs(
|
||||
checkpoint_path=CHECKPOINT_PATH,
|
||||
quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
|
||||
skip_components=SKIP_HEAVY_COMPONENTS,
|
||||
)
|
||||
pipeline_fp8_block = PipelineLoader(args_fp8_block).load()
|
||||
|
||||
fp8_block_model_mem = _get_module_memory_gb(pipeline_fp8_block.transformer)
|
||||
fp8_block_total_mem = _get_cuda_memory_gb()
|
||||
fp8_block_peak_mem = _get_cuda_peak_memory_gb()
|
||||
|
||||
print(f"\n[FP8 Blockwise] Transformer model memory: {fp8_block_model_mem:.2f} GB")
|
||||
print(f"[FP8 Blockwise] Total CUDA memory: {fp8_block_total_mem:.2f} GB")
|
||||
print(f"[FP8 Blockwise] Peak CUDA memory: {fp8_block_peak_mem:.2f} GB")
|
||||
|
||||
# =========================================================================
|
||||
# Verify FP8 Blockwise memory savings
|
||||
# =========================================================================
|
||||
block_model_mem_ratio = bf16_model_mem / fp8_block_model_mem
|
||||
block_peak_mem_ratio = bf16_peak_mem / fp8_block_peak_mem
|
||||
|
||||
print(f"\n[Comparison] Model memory ratio (BF16/FP8-Block): {block_model_mem_ratio:.2f}x")
|
||||
print(f"[Comparison] Peak memory ratio (BF16/FP8-Block): {block_peak_mem_ratio:.2f}x")
|
||||
|
||||
# FP8 Blockwise has additional scale tensors, so slightly less than 2x savings
|
||||
# But should still be significantly better than BF16
|
||||
assert block_model_mem_ratio > 1.5, (
|
||||
f"FP8 Blockwise model memory should be significantly smaller than BF16, got {block_model_mem_ratio:.2f}x"
|
||||
)
|
||||
|
||||
# Peak memory check
|
||||
assert block_peak_mem_ratio > 1.3, (
|
||||
f"FP8 Blockwise peak memory should be smaller than BF16, got {block_peak_mem_ratio:.2f}x"
|
||||
)
|
||||
|
||||
fp8_block_peak_overhead = fp8_block_peak_mem / fp8_block_total_mem
|
||||
print(f"[FP8 Blockwise] Peak/Final memory ratio: {fp8_block_peak_overhead:.2f}x")
|
||||
|
||||
assert fp8_block_peak_overhead < 2.0, (
|
||||
f"FP8 Blockwise peak memory ({fp8_block_peak_mem:.2f} GB) is too high compared to final "
|
||||
f"({fp8_block_total_mem:.2f} GB). Ratio: {fp8_block_peak_overhead:.2f}x."
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
del pipeline_fp8_block
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# =========================================================================
|
||||
# Summary
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"BF16: {bf16_model_mem:.2f} GB model, {bf16_peak_mem:.2f} GB peak")
|
||||
print(
|
||||
f"FP8 Per-Tensor: {fp8_model_mem:.2f} GB model, {fp8_peak_mem:.2f} GB peak "
|
||||
f"({model_mem_ratio:.2f}x savings)"
|
||||
)
|
||||
print(
|
||||
f"FP8 Blockwise: {fp8_block_model_mem:.2f} GB model, {fp8_block_peak_mem:.2f} GB peak "
|
||||
f"({block_model_mem_ratio:.2f}x savings)"
|
||||
)
|
||||
120
tests/unittest/_torch/visual_gen/test_quant_ops.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Unit tests for diffusion quantization operations."""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.visual_gen.quantization.ops import (
|
||||
quantize_fp8_blockwise,
|
||||
quantize_fp8_per_tensor,
|
||||
)
|
||||
|
||||
|
||||
def _dequant_fp8_per_tensor(qweight, scale):
|
||||
"""Dequantize per-tensor FP8 weight."""
|
||||
return qweight.to(torch.float32) * scale
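# A pure-PyTorch reference of per-tensor FP8 quantization, matching the
# dequantization convention used above (illustrative sketch only; the CUDA
# kernel under test may differ in rounding and scale handling).
def _reference_quantize_fp8_per_tensor(weight):
    """Quantize to float8_e4m3fn with a single scale so that weight ~= qweight * scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
    scale = weight.abs().max().to(torch.float32) / fp8_max
    scale = torch.clamp(scale, min=1e-12)  # keep all-zero weights well defined
    qweight = (weight.to(torch.float32) / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return qweight, scale.reshape(1, 1)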
|
||||
|
||||
|
||||
class TestQuantOps(unittest.TestCase):
|
||||
"""Test quantization operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set random seed for reproducibility."""
|
||||
torch.manual_seed(42)
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("CUDA not available")
|
||||
|
||||
def test_fp8_per_tensor(self):
|
||||
"""Test FP8 per-tensor quantization using CUDA kernel."""
|
||||
weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
|
||||
qweight, scale = quantize_fp8_per_tensor(weight)
|
||||
|
||||
# Check output types
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qweight.shape, weight.shape)
|
||||
self.assertEqual(scale.dtype, torch.float32)
|
||||
self.assertEqual(scale.shape, (1, 1))
|
||||
|
||||
# Verify dequantization (approximate)
|
||||
dequant = _dequant_fp8_per_tensor(qweight, scale)
|
||||
error = (dequant - weight.to(torch.float32)).abs().mean()
|
||||
self.assertLess(error, 0.15) # Reasonable quantization error
|
||||
|
||||
def test_fp8_per_tensor_different_shapes(self):
|
||||
"""Test FP8 per-tensor quantization with various shapes."""
|
||||
shapes = [(128, 256), (256, 512), (512, 1024), (1024, 2048)]
|
||||
for shape in shapes:
|
||||
with self.subTest(shape=shape):
|
||||
weight = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
|
||||
qweight, scale = quantize_fp8_per_tensor(weight)
|
||||
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qweight.shape, weight.shape)
|
||||
self.assertEqual(scale.dtype, torch.float32)
|
||||
|
||||
def test_fp8_blockwise(self):
|
||||
"""Test FP8 128x128 blockwise quantization."""
|
||||
weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
|
||||
block_size = 128
|
||||
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
|
||||
|
||||
# Check output types
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qweight.shape, weight.shape)
|
||||
self.assertEqual(scales.dtype, torch.float32)
|
||||
|
||||
# Check scales shape: (num_blocks_out, num_blocks_in) for 128x128 blocks
|
||||
num_blocks_out = (512 + block_size - 1) // block_size # 4
|
||||
num_blocks_in = (512 + block_size - 1) // block_size # 4
|
||||
self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
|
||||
|
||||
def test_fp8_blockwise_non_divisible(self):
|
||||
"""Test FP8 blockwise quantization with non-divisible dimensions."""
|
||||
weight = torch.randn(300, 500, dtype=torch.bfloat16, device="cuda")
|
||||
block_size = 128
|
||||
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
|
||||
|
||||
# Check output types
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qweight.shape, weight.shape)
|
||||
|
||||
# Check scales shape (should handle non-divisible dimensions)
|
||||
num_blocks_out = (300 + block_size - 1) // block_size # 3
|
||||
num_blocks_in = (500 + block_size - 1) // block_size # 4
|
||||
self.assertEqual(scales.shape, (num_blocks_out, num_blocks_in))
|
||||
|
||||
def test_fp8_blockwise_different_block_sizes(self):
|
||||
"""Test FP8 blockwise quantization with different block sizes."""
|
||||
weight = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
for block_size in [64, 128, 256]:
|
||||
with self.subTest(block_size=block_size):
|
||||
qweight, scales = quantize_fp8_blockwise(weight, block_size=block_size)
|
||||
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qweight.shape, weight.shape)
|
||||
|
||||
num_blocks = (256 + block_size - 1) // block_size
|
||||
self.assertEqual(scales.shape, (num_blocks, num_blocks))
|
||||
|
||||
def test_fp8_per_tensor_zero_weight(self):
|
||||
"""Test FP8 per-tensor quantization with zero weight."""
|
||||
weight = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
|
||||
qweight, scale = quantize_fp8_per_tensor(weight)
|
||||
|
||||
# Should handle zero weights gracefully
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
|
||||
|
||||
def test_fp8_blockwise_zero_weight(self):
|
||||
"""Test FP8 blockwise quantization with zero weight."""
|
||||
weight = torch.zeros(256, 256, dtype=torch.bfloat16, device="cuda")
|
||||
qweight, scales = quantize_fp8_blockwise(weight, block_size=128)
|
||||
|
||||
# Should handle zero weights gracefully
|
||||
self.assertEqual(qweight.dtype, torch.float8_e4m3fn)
|
||||
self.assertTrue(torch.all(qweight.to(torch.float32) == 0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
398
tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py
Normal file
@ -0,0 +1,398 @@
|
||||
"""End-to-end tests for trtllm-serve visual_gen with real models.
|
||||
|
||||
Tests text-to-video (t2v) and text+image-to-video (ti2v) generation through
|
||||
the full ``trtllm-serve`` stack backed by real VisualGen models.
|
||||
|
||||
The server is launched as a subprocess (same pattern as
|
||||
``tests/unittest/llmapi/apps/openai_server.py``), so each test class gets an
|
||||
isolated ``trtllm-serve`` process.
|
||||
|
||||
Usage::
|
||||
|
||||
# Run all real-model tests (requires GPU + models under LLM_MODELS_ROOT)
|
||||
pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v
|
||||
|
||||
# Run only t2v tests
|
||||
pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanT2V
|
||||
|
||||
# Run only ti2v tests
|
||||
pytest tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py -v -k TestWanI2V
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _llm_models_root() -> str:
|
||||
"""Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
|
||||
root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
|
||||
if "LLM_MODELS_ROOT" in os.environ:
|
||||
root = Path(os.environ["LLM_MODELS_ROOT"])
|
||||
if not root.exists():
|
||||
root = Path("/scratch.trt_llm_data/llm-models/")
|
||||
assert root.exists(), (
|
||||
"You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
|
||||
)
|
||||
return str(root)
|
||||
|
||||
|
||||
_WAN_T2V_PATH = Path(_llm_models_root()) / "Wan2.1-T2V-1.3B-Diffusers"
|
||||
_WAN_I2V_PATH = Path(_llm_models_root()) / "Wan2.2-I2V-A14B-Diffusers"
|
||||
|
||||
# Reference image used for image-to-video (ti2v) tests
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[4] # repo root
|
||||
_REF_IMAGE_PATH = _PROJECT_ROOT / "examples" / "visual_gen" / "cat_piano.png"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Remote server helper (follows RemoteOpenAIServer pattern)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RemoteVisualGenServer:
|
||||
"""Launch ``trtllm-serve`` for a visual-gen model as a subprocess.
|
||||
|
||||
Mirrors the interface of ``tests.unittest.llmapi.apps.openai_server.RemoteOpenAIServer``
|
||||
adapted for diffusion / visual-gen models.
|
||||
"""
|
||||
|
||||
MAX_SERVER_START_WAIT_S = 1200 # 20 min – large models need time to load
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
extra_visual_gen_options: Optional[dict] = None,
|
||||
cli_args: Optional[List[str]] = None,
|
||||
host: str = "localhost",
|
||||
port: Optional[int] = None,
|
||||
env: Optional[dict] = None,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port if port is not None else get_free_port()
|
||||
self._config_file: Optional[str] = None
|
||||
self.proc: Optional[subprocess.Popen] = None
|
||||
|
||||
args = ["--host", self.host, "--port", str(self.port)]
|
||||
if cli_args:
|
||||
args += cli_args
|
||||
|
||||
# Write the visual-gen YAML config to a temp file
|
||||
if extra_visual_gen_options:
|
||||
fd, self._config_file = tempfile.mkstemp(suffix=".yml", prefix="vg_cfg_")
|
||||
with os.fdopen(fd, "w") as f:
|
||||
yaml.dump(extra_visual_gen_options, f)
|
||||
args += ["--extra_visual_gen_options", self._config_file]
|
||||
|
||||
launch_cmd = ["trtllm-serve", model] + args
|
||||
|
||||
if env is None:
|
||||
env = os.environ.copy()
|
||||
|
||||
self.proc = subprocess.Popen(
|
||||
launch_cmd,
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
self._wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S)
|
||||
|
||||
# -- lifecycle ---------------------------------------------------------
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.terminate()
|
||||
|
||||
def terminate(self):
|
||||
if self.proc is None:
|
||||
return
|
||||
self.proc.terminate()
|
||||
try:
|
||||
self.proc.wait(timeout=30)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.proc.kill()
|
||||
self.proc.wait(timeout=30)
|
||||
self.proc = None
|
||||
|
||||
if self._config_file:
|
||||
try:
|
||||
os.remove(self._config_file)
|
||||
except OSError:
|
||||
pass
|
||||
self._config_file = None
|
||||
|
||||
# -- readiness ---------------------------------------------------------
|
||||
|
||||
def _wait_for_server(self, timeout: float):
|
||||
url = self.url_for("health")
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get(url).status_code == 200:
|
||||
return
|
||||
except Exception as err:
|
||||
result = self.proc.poll()
|
||||
if result is not None and result != 0:
|
||||
raise RuntimeError("Visual-gen server exited unexpectedly.") from err
|
||||
time.sleep(1)
|
||||
if time.time() - start > timeout:
|
||||
self.terminate()
|
||||
raise RuntimeError(f"Visual-gen server failed to start within {timeout}s.")
|
||||
|
||||
# -- URL helpers -------------------------------------------------------
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
def url_for(self, *parts: str) -> str:
|
||||
return self.url_root + "/" + "/".join(parts)
|
||||
|
||||
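
# Illustrative usage sketch (exposition only; the class-scoped ``server``
# fixtures below wrap the same pattern):
#
#     with RemoteVisualGenServer(model=str(_WAN_T2V_PATH)) as srv:
#         assert requests.get(srv.url_for("health")).status_code == 200
#
# url_for("v1", "videos") resolves to e.g. "http://localhost:<port>/v1/videos".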


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _model_available(path: Path) -> bool:
    return path.is_dir()


def _av_available() -> bool:
    """Check if PyAV is installed (required for video encoding in E2E tests)."""
    try:
        import av  # noqa: F401

        return True
    except ImportError:
        return False


def _make_visual_gen_options(**extra) -> dict:
    """Build the YAML dict passed via ``--extra_visual_gen_options``."""
    config = {
        "linear": {"type": "default"},
        "parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1},
    }
    config.update(extra)
    return config
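
# For reference, the default options built above serialize to YAML roughly as
# (exact key order depends on yaml.dump):
#
#     linear:
#       type: default
#     parallel:
#       dit_cfg_size: 1
#       dit_ulysses_size: 1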


# =========================================================================
# WAN 2.1 – Text-to-Video (t2v)
# =========================================================================


@pytest.mark.skipif(
    not _model_available(_WAN_T2V_PATH), reason=f"Wan2.1-T2V model not found at {_WAN_T2V_PATH}"
)
@pytest.mark.skipif(
    not _av_available(), reason="PyAV (av) not installed — required for video encoding in E2E tests"
)
class TestWanTextToVideo:
    """Test Wan2.1-T2V-1.3B-Diffusers text-to-video generation via serve API."""

    @pytest.fixture(scope="class")
    def server(self):
        with RemoteVisualGenServer(
            model=str(_WAN_T2V_PATH),
            extra_visual_gen_options=_make_visual_gen_options(),
        ) as srv:
            yield srv

    # ------------------------------------------------------------------

    def test_health(self, server):
        resp = requests.get(server.url_for("health"))
        assert resp.status_code == 200

    def test_t2v_sync(self, server):
        """Synchronous text-to-video via POST /v1/videos/generations."""
        resp = requests.post(
            server.url_for("v1", "videos", "generations"),
            json={
                "prompt": "A cute cat playing piano",
                "size": "480x320",
                "seconds": 1.0,
                "fps": 8,
                "num_inference_steps": 4,
                "seed": 42,
            },
        )
        assert resp.status_code == 200, resp.text
        assert resp.headers["content-type"] == "video/mp4"
        assert len(resp.content) > 1000, "Video file too small"

    def test_t2v_async_lifecycle(self, server):
        """Async video generation: create job → poll → download → delete."""
        base = server.url_for("v1", "videos")

        # 1. Create job
        create_resp = requests.post(
            base,
            json={
                "prompt": "A rocket launching into a starry sky",
                "size": "480x320",
                "seconds": 1.0,
                "fps": 8,
                "num_inference_steps": 4,
                "seed": 42,
            },
        )
        assert create_resp.status_code == 202, create_resp.text
        job = create_resp.json()
        video_id = job["id"]
        assert video_id.startswith("video_")
        assert job["status"] == "queued"

        # 2. Poll until completed (or timeout)
        deadline = time.time() + 600  # 10 min
        status = "queued"
        while status not in ("completed", "failed") and time.time() < deadline:
            time.sleep(2)
            meta_resp = requests.get(f"{base}/{video_id}")
            assert meta_resp.status_code == 200
            status = meta_resp.json()["status"]

        assert status == "completed", f"Video generation did not complete: {status}"

        # 3. Download video content
        content_resp = requests.get(f"{base}/{video_id}/content")
        assert content_resp.status_code == 200
        assert "video/mp4" in content_resp.headers.get("content-type", "")
        assert len(content_resp.content) > 1000

        # 4. Verify it appears in list
        list_resp = requests.get(base)
        assert list_resp.status_code == 200
        ids = [v["id"] for v in list_resp.json()["data"]]
        assert video_id in ids

        # 5. Delete
        del_resp = requests.delete(f"{base}/{video_id}")
        assert del_resp.status_code == 200
        assert del_resp.json()["deleted"] is True

        # 6. Confirm gone
        gone_resp = requests.get(f"{base}/{video_id}")
        assert gone_resp.status_code == 404
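
# Note on the async lifecycle check above: the test only distinguishes the
# terminal statuses surfaced by the API ("completed" / "failed"); any
# intermediate status the server may report while rendering is simply polled
# through until the deadline expires.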


# =========================================================================
# WAN 2.2 – Image-to-Video (ti2v)
# =========================================================================


@pytest.mark.skipif(
    not _model_available(_WAN_I2V_PATH), reason=f"Wan2.2-I2V model not found at {_WAN_I2V_PATH}"
)
@pytest.mark.skipif(
    not _REF_IMAGE_PATH.is_file(), reason=f"Reference image not found at {_REF_IMAGE_PATH}"
)
@pytest.mark.skipif(
    not _av_available(), reason="PyAV (av) not installed — required for video encoding in E2E tests"
)
class TestWanImageToVideo:
    """Test Wan2.2-I2V-A14B-Diffusers image-to-video generation via serve API."""

    @pytest.fixture(scope="class")
    def server(self):
        with RemoteVisualGenServer(
            model=str(_WAN_I2V_PATH),
            extra_visual_gen_options=_make_visual_gen_options(),
        ) as srv:
            yield srv

    # ------------------------------------------------------------------

    def test_health(self, server):
        resp = requests.get(server.url_for("health"))
        assert resp.status_code == 200

    def test_ti2v_sync(self, server):
        """Synchronous image-to-video via multipart POST /v1/videos/generations."""
        with open(_REF_IMAGE_PATH, "rb") as f:
            resp = requests.post(
                server.url_for("v1", "videos", "generations"),
                data={
                    "prompt": "The cat starts playing piano, keys moving",
                    "size": "480x320",
                    "seconds": "1.0",
                    "fps": "8",
                    "num_inference_steps": "4",
                    "seed": "42",
                },
                files={
                    "input_reference": ("cat_piano.png", f, "image/png"),
                },
            )
        assert resp.status_code == 200, resp.text
        assert resp.headers["content-type"] == "video/mp4"
        assert len(resp.content) > 1000, "Video file too small"

    def test_ti2v_async_lifecycle(self, server):
        """Async i2v: create job with image → poll → download → delete."""
        base = server.url_for("v1", "videos")

        # 1. Create job via multipart
        with open(_REF_IMAGE_PATH, "rb") as f:
            create_resp = requests.post(
                base,
                data={
                    "prompt": "Snow falls on the piano and the cat",
                    "size": "480x320",
                    "seconds": "1.0",
                    "fps": "8",
                    "num_inference_steps": "4",
                    "seed": "42",
                },
                files={
                    "input_reference": ("cat_piano.png", f, "image/png"),
                },
            )
        assert create_resp.status_code == 202, create_resp.text
        job = create_resp.json()
        video_id = job["id"]
        assert job["status"] == "queued"

        # 2. Poll until completed
        deadline = time.time() + 600
        status = "queued"
        while status not in ("completed", "failed") and time.time() < deadline:
            time.sleep(2)
            meta_resp = requests.get(f"{base}/{video_id}")
            assert meta_resp.status_code == 200
            status = meta_resp.json()["status"]

        assert status == "completed", f"Video generation did not complete: {status}"

        # 3. Download
        content_resp = requests.get(f"{base}/{video_id}/content")
        assert content_resp.status_code == 200
        assert "video/mp4" in content_resp.headers.get("content-type", "")
        assert len(content_resp.content) > 1000

        # 4. Delete
        del_resp = requests.delete(f"{base}/{video_id}")
        assert del_resp.status_code == 200
        assert del_resp.json()["deleted"] is True

        # 5. Confirm gone
        gone_resp = requests.get(f"{base}/{video_id}")
        assert gone_resp.status_code == 404
876
tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py
Normal file
@ -0,0 +1,876 @@
"""trtllm-serve visual_gen endpoints tests.

Tests all endpoints registered for the VISUAL_GEN server role
in OpenAIServer.register_visual_gen_routes():

    POST   /v1/images/generations
    POST   /v1/images/edits
    POST   /v1/videos/generations          (sync)
    POST   /v1/videos                      (async)
    GET    /v1/videos                      (list)
    GET    /v1/videos/{video_id}           (metadata)
    GET    /v1/videos/{video_id}/content   (download)
    DELETE /v1/videos/{video_id}           (delete)
"""

import asyncio
import base64
import os
from io import BytesIO
from typing import Optional
from unittest.mock import patch

import pytest
import torch
from fastapi.testclient import TestClient
from PIL import Image

from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm.serve.media_storage import MediaStorage
from tensorrt_llm.serve.openai_protocol import VideoJob
from tensorrt_llm.serve.visual_gen_utils import VIDEO_STORE

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_dummy_image_tensor(height: int = 64, width: int = 64) -> torch.Tensor:
    """Create a small dummy uint8 image tensor (H, W, C)."""
    return torch.randint(0, 256, (height, width, 3), dtype=torch.uint8)


def _make_dummy_video_tensor(
    num_frames: int = 4, height: int = 64, width: int = 64
) -> torch.Tensor:
    """Create a small dummy uint8 video tensor (T, H, W, C)."""
    return torch.randint(0, 256, (num_frames, height, width, 3), dtype=torch.uint8)


def _make_dummy_audio_tensor(length: int = 16000) -> torch.Tensor:
    """Create a small dummy float32 audio tensor."""
    return torch.randn(1, length, dtype=torch.float32)


def _b64_white_png_1x1() -> str:
    """Return a base64-encoded 1x1 white PNG for image edit tests."""
    buf = BytesIO()
    Image.new("RGB", (1, 1), (255, 255, 255)).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _run_async(coro):
    """Run an async coroutine in a new event loop (for test helpers)."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
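
# _run_async exists because VIDEO_STORE exposes coroutine helpers (see
# TestGetVideoContent._insert_video_job below); for these synchronous test
# helpers it is roughly equivalent to asyncio.run(coro), kept explicit so the
# event loop is always closed even when the coroutine raises.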


# ---------------------------------------------------------------------------
# Mock VisualGen
# ---------------------------------------------------------------------------


class MockVisualGen:
    """Lightweight stand-in for VisualGen that avoids GPU / model loading."""

    def __init__(
        self,
        image_output: Optional[torch.Tensor] = None,
        video_output: Optional[torch.Tensor] = None,
        audio_output: Optional[torch.Tensor] = None,
        should_fail: bool = False,
    ):
        self._image = image_output
        self._video = video_output
        self._audio = audio_output
        self._should_fail = should_fail
        self._healthy = True
        self.req_counter = 0

    # --- VisualGen interface ---

    def generate(self, inputs=None, params=None) -> MediaOutput:
        if self._should_fail:
            raise RuntimeError("Generation intentionally failed")
        return MediaOutput(
            image=self._image,
            video=self._video,
            audio=self._audio,
        )

    def generate_async(self, inputs=None, params=None) -> "MockDiffusionGenerationResult":
        return MockDiffusionGenerationResult(
            image=self._image,
            video=self._video,
            audio=self._audio,
            should_fail=self._should_fail,
        )

    def _check_health(self) -> bool:
        return self._healthy

    async def get_stats_async(self, timeout: int):
        return

    def shutdown(self):
        pass


class MockDiffusionGenerationResult:
    """Mock future-like result for generate_async."""

    def __init__(
        self,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        audio: Optional[torch.Tensor] = None,
        should_fail: bool = False,
    ):
        self._image = image
        self._video = video
        self._audio = audio
        self._should_fail = should_fail

    async def result(self, timeout=None):
        if self._should_fail:
            raise RuntimeError("Async generation intentionally failed")
        return MediaOutput(
            image=self._image,
            video=self._video,
            audio=self._audio,
        )
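
# How the server is expected to consume these mocks (a sketch inferred from the
# mock interface above, not a verbatim copy of the server code):
#
#     result = generator.generate_async(inputs, params)
#     media = await result.result(timeout=...)   # -> MediaOutput(image/video/audio)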


# ---------------------------------------------------------------------------
# Server factory
# ---------------------------------------------------------------------------


def _create_server(generator: MockVisualGen, model_name: str = "test-model") -> TestClient:
    """Instantiate an OpenAIServer for VISUAL_GEN with a mocked generator.

    We patch the ``VisualGen`` name inside the ``openai_server`` module so that
    ``isinstance(generator, VisualGen)`` returns True for our mock.
    """
    from tensorrt_llm.llmapi.disagg_utils import ServerRole
    from tensorrt_llm.serve.openai_server import OpenAIServer

    with patch("tensorrt_llm.serve.openai_server.VisualGen", MockVisualGen):
        server = OpenAIServer(
            generator=generator,
            model=model_name,
            tool_parser=None,
            server_role=ServerRole.VISUAL_GEN,
            metadata_server_cfg=None,
        )
    return TestClient(server.app)
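
# Typical usage, mirroring the fixtures below:
#
#     client = _create_server(MockVisualGen(image_output=_make_dummy_image_tensor()))
#     resp = client.post("/v1/images/generations",
#                        json={"prompt": "A cat", "response_format": "b64_json"})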


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def image_client(tmp_path):
    """TestClient backed by a MockVisualGen that produces images."""
    gen = MockVisualGen(image_output=_make_dummy_image_tensor())
    os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
    client = _create_server(gen)
    yield client
    os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)


@pytest.fixture()
def video_client(tmp_path):
    """TestClient backed by a MockVisualGen that produces videos."""
    gen = MockVisualGen(video_output=_make_dummy_video_tensor())
    os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
    client = _create_server(gen)
    yield client
    os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)


@pytest.fixture()
def video_audio_client(tmp_path):
    """TestClient backed by a MockVisualGen that produces videos with audio."""
    gen = MockVisualGen(
        video_output=_make_dummy_video_tensor(),
        audio_output=_make_dummy_audio_tensor(),
    )
    os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
    client = _create_server(gen)
    yield client
    os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)


@pytest.fixture()
def failing_client(tmp_path):
    """TestClient backed by a MockVisualGen that always fails."""
    gen = MockVisualGen(should_fail=True)
    os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
    client = _create_server(gen)
    yield client
    os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)


@pytest.fixture(autouse=True)
def _clear_video_store():
    """Reset the global VIDEO_STORE before each test."""
    VIDEO_STORE._items.clear()
    yield
    VIDEO_STORE._items.clear()


@pytest.fixture(autouse=True)
def _mock_video_encoding():
    """Mock MP4 encoding to avoid PyAV dependency in unit tests.

    Replaces MediaStorage._save_mp4 with a stub that writes a small
    dummy file so FileResponse can serve it.
    """

    def _dummy_save_mp4(video, audio, output_path, frame_rate):
        os.makedirs(os.path.dirname(str(output_path)) or ".", exist_ok=True)
        with open(str(output_path), "wb") as f:
            f.write(b"\x00\x00\x00\x1cftypisom" + b"\x00" * 32)
        return str(output_path)

    with patch.object(MediaStorage, "_save_mp4", staticmethod(_dummy_save_mp4)):
        yield
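
# The stub above only writes a few bytes beginning with an MP4 "ftyp" box
# marker so that FileResponse has a real file to serve; it is not a decodable
# video, which is fine because these endpoint tests never decode the output.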


# =========================================================================
# POST /v1/images/generations
# =========================================================================


class TestImageGeneration:
    def test_basic_image_generation_b64(self, image_client):
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "A cat sitting on a mat",
                "response_format": "b64_json",
                "size": "64x64",
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "data" in data
        assert len(data["data"]) >= 1
        img_obj = data["data"][0]
        assert img_obj["b64_json"] is not None
        # Verify it decodes to valid bytes
        decoded = base64.b64decode(img_obj["b64_json"])
        assert len(decoded) > 0
        assert img_obj["revised_prompt"] == "A cat sitting on a mat"

    def test_image_generation_with_optional_params(self, image_client):
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "Sunset over ocean",
                "response_format": "b64_json",
                "size": "128x64",
                "num_inference_steps": 20,
                "guidance_scale": 7.5,
                "seed": 123,
                "negative_prompt": "blurry",
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["size"] == "128x64"

    def test_image_generation_url_format_not_supported(self, image_client):
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "A dog",
                "response_format": "url",
            },
        )
        assert resp.status_code == 501

    def test_image_generation_auto_size(self, image_client):
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "A tree",
                "response_format": "b64_json",
                "size": "auto",
            },
        )
        assert resp.status_code == 200

    def test_image_generation_failure(self, failing_client):
        resp = failing_client.post(
            "/v1/images/generations",
            json={
                "prompt": "A bird",
                "response_format": "b64_json",
            },
        )
        assert resp.status_code == 400

    def test_image_generation_invalid_size(self, image_client):
        """Invalid size triggers RequestValidationError → custom handler → 400."""
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "A mountain",
                "response_format": "b64_json",
                "size": "invalid",
            },
        )
        assert resp.status_code == 400

    def test_image_generation_null_output(self, tmp_path):
        """Generator returns MediaOutput with image=None."""
        gen = MockVisualGen(image_output=None)
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)
        resp = client.post(
            "/v1/images/generations",
            json={
                "prompt": "null image",
                "response_format": "b64_json",
            },
        )
        assert resp.status_code == 500
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)

    def test_image_generation_multiple_n(self, image_client):
        """Request n=2 images in one call."""
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "Flowers",
                "response_format": "b64_json",
                "size": "64x64",
                "n": 2,
            },
        )
        assert resp.status_code == 200

    def test_image_generation_hd_quality(self, image_client):
        resp = image_client.post(
            "/v1/images/generations",
            json={
                "prompt": "HD landscape",
                "response_format": "b64_json",
                "quality": "hd",
            },
        )
        assert resp.status_code == 200

    def test_missing_prompt_image_generation(self, image_client):
        """Missing required field → RequestValidationError → custom handler → 400."""
        resp = image_client.post(
            "/v1/images/generations",
            json={},
        )
        assert resp.status_code == 400


# =========================================================================
# POST /v1/images/edits
# =========================================================================


class TestImageEdit:
    def test_basic_image_edit(self, image_client):
        b64_img = _b64_white_png_1x1()
        resp = image_client.post(
            "/v1/images/edits",
            json={
                "image": b64_img,
                "prompt": "Make it blue",
                "num_inference_steps": 10,
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "data" in data
        assert len(data["data"]) >= 1
        assert data["data"][0]["b64_json"] is not None

    def test_image_edit_with_list_images(self, image_client):
        b64_img = _b64_white_png_1x1()
        resp = image_client.post(
            "/v1/images/edits",
            json={
                "image": [b64_img, b64_img],
                "prompt": "Merge them",
                "num_inference_steps": 10,
            },
        )
        assert resp.status_code == 200

    def test_image_edit_with_mask(self, image_client):
        b64_img = _b64_white_png_1x1()
        b64_mask = _b64_white_png_1x1()
        resp = image_client.post(
            "/v1/images/edits",
            json={
                "image": b64_img,
                "prompt": "Remove object",
                "mask": b64_mask,
                "num_inference_steps": 10,
            },
        )
        assert resp.status_code == 200

    def test_image_edit_with_optional_params(self, image_client):
        b64_img = _b64_white_png_1x1()
        resp = image_client.post(
            "/v1/images/edits",
            json={
                "image": b64_img,
                "prompt": "Enhance colors",
                "size": "128x128",
                "guidance_scale": 8.0,
                "num_inference_steps": 15,
                "seed": 42,
                "negative_prompt": "dark",
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["size"] == "128x128"

    def test_image_edit_failure(self, failing_client):
        b64_img = _b64_white_png_1x1()
        resp = failing_client.post(
            "/v1/images/edits",
            json={
                "image": b64_img,
                "prompt": "Edit this",
                "num_inference_steps": 10,
            },
        )
        assert resp.status_code == 500

    def test_missing_image_for_edit(self, image_client):
        """Missing required field → RequestValidationError → custom handler → 400."""
        resp = image_client.post(
            "/v1/images/edits",
            json={
                "prompt": "Edit without image",
            },
        )
        assert resp.status_code == 400


# =========================================================================
# POST /v1/videos/generations (synchronous)
# =========================================================================


@pytest.mark.threadleak(enabled=False)  # FileResponse spawns AnyIO worker threads
class TestVideoGenerationSync:
    def test_basic_sync_video_generation(self, video_client):
        resp = video_client.post(
            "/v1/videos/generations",
            json={
                "prompt": "A rocket launching",
                "size": "64x64",
                "seconds": 1.0,
                "fps": 8,
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "video/mp4"
        assert len(resp.content) > 0

    def test_sync_video_generation_with_params(self, video_client):
        resp = video_client.post(
            "/v1/videos/generations",
            json={
                "prompt": "Ocean waves",
                "size": "64x64",
                "seconds": 2.0,
                "fps": 8,
                "num_inference_steps": 10,
                "guidance_scale": 5.0,
                "seed": 42,
                "negative_prompt": "blurry",
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 200
        assert len(resp.content) > 0

    def test_sync_video_generation_multipart(self, video_client):
        # Use files={} with a dummy file to ensure multipart/form-data
        dummy_file = BytesIO(b"")
        resp = video_client.post(
            "/v1/videos/generations",
            data={
                "prompt": "Mountain sunrise",
                "size": "64x64",
                "seconds": "1.0",
                "fps": "8",
            },
            files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
        )
        # The server will parse fields; _dummy is ignored since it's not "input_reference"
        assert resp.status_code == 200
        assert len(resp.content) > 0

    def test_sync_video_generation_multipart_with_reference(self, video_client, tmp_path):
        # Create a dummy reference image file
        ref_path = tmp_path / "ref.png"
        Image.new("RGB", (4, 4), (128, 128, 128)).save(str(ref_path))

        with open(ref_path, "rb") as f:
            resp = video_client.post(
                "/v1/videos/generations",
                data={
                    "prompt": "Animate this image",
                    "size": "64x64",
                    "seconds": "1.0",
                    "fps": "8",
                },
                files={"input_reference": ("ref.png", f, "image/png")},
            )
        assert resp.status_code == 200
        assert len(resp.content) > 0

    def test_sync_video_failure(self, failing_client):
        resp = failing_client.post(
            "/v1/videos/generations",
            json={
                "prompt": "Should fail",
                "size": "64x64",
                "seconds": 1.0,
                "fps": 8,
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 400

    def test_sync_video_null_output(self, tmp_path):
        """Generator returns MediaOutput with video=None."""
        gen = MockVisualGen(video_output=None)
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)
        resp = client.post(
            "/v1/videos/generations",
            json={"prompt": "null video", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 500
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)

    def test_sync_video_unsupported_content_type(self, video_client):
        resp = video_client.post(
            "/v1/videos/generations",
            content=b"some raw bytes",
            headers={"content-type": "text/plain"},
        )
        assert resp.status_code == 400

    def test_sync_video_missing_prompt_json(self, video_client):
        """Missing required prompt → Pydantic ValidationError → 400."""
        resp = video_client.post(
            "/v1/videos/generations",
            json={"size": "64x64"},
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 400

    def test_sync_video_missing_prompt_multipart(self, video_client):
        """Missing prompt in multipart form → ValueError → 400."""
        dummy_file = BytesIO(b"")
        resp = video_client.post(
            "/v1/videos/generations",
            data={"size": "64x64"},
            files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
        )
        assert resp.status_code == 400


# =========================================================================
# POST /v1/videos (asynchronous)
# =========================================================================


class TestVideoGenerationAsync:
    def test_async_video_returns_202(self, video_client):
        resp = video_client.post(
            "/v1/videos",
            json={
                "prompt": "A dancing robot",
                "size": "64x64",
                "seconds": 1.0,
                "fps": 8,
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 202
        data = resp.json()
        assert data["status"] == "queued"
        assert data["object"] == "video"
        assert data["prompt"] == "A dancing robot"
        assert data["id"].startswith("video_")

    def test_async_video_job_metadata_fields(self, video_client):
        resp = video_client.post(
            "/v1/videos",
            json={
                "prompt": "Starry night",
                "size": "64x64",
                "seconds": 2.0,
                "fps": 12,
            },
            headers={"content-type": "application/json"},
        )
        data = resp.json()
        assert "created_at" in data
        assert data["duration"] == 2.0
        assert data["fps"] == 12
        assert data["size"] == "64x64"

    def test_async_video_multipart(self, video_client):
        """Multipart encoding requires a file field to trigger the correct content-type."""
        dummy_file = BytesIO(b"")
        resp = video_client.post(
            "/v1/videos",
            data={
                "prompt": "A sunset",
                "size": "64x64",
                "seconds": "1.0",
                "fps": "8",
            },
            files={"_dummy": ("dummy", dummy_file, "application/octet-stream")},
        )
        assert resp.status_code == 202

    def test_async_video_invalid_seconds(self, video_client):
        """Seconds must be between 1.0 and 16.0. Validation error → 400."""
        resp = video_client.post(
            "/v1/videos",
            json={
                "prompt": "Too short",
                "seconds": 0.1,
                "size": "64x64",
                "fps": 8,
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 400

    def test_async_video_invalid_fps(self, video_client):
        """Fps must be between 8 and 60. Validation error → 400."""
        resp = video_client.post(
            "/v1/videos",
            json={
                "prompt": "Bad fps",
                "seconds": 1.0,
                "fps": 2,
                "size": "64x64",
            },
            headers={"content-type": "application/json"},
        )
        assert resp.status_code == 400


# =========================================================================
# GET /v1/videos (list)
# =========================================================================


class TestListVideos:
    def test_list_videos_empty(self, video_client):
        resp = video_client.get("/v1/videos")
        assert resp.status_code == 200
        data = resp.json()
        assert data["object"] == "list"
        assert data["data"] == []

    def test_list_videos_after_creation(self, video_client):
        # Create two video jobs
        video_client.post(
            "/v1/videos",
            json={"prompt": "First video", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        video_client.post(
            "/v1/videos",
            json={"prompt": "Second video", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )

        resp = video_client.get("/v1/videos")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["data"]) == 2


# =========================================================================
# GET /v1/videos/{video_id} (metadata)
# =========================================================================


class TestGetVideoMetadata:
    def test_get_video_metadata_success(self, video_client):
        create_resp = video_client.post(
            "/v1/videos",
            json={"prompt": "Space walk", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        video_id = create_resp.json()["id"]

        resp = video_client.get(f"/v1/videos/{video_id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == video_id
        assert data["object"] == "video"
        assert data["prompt"] == "Space walk"

    def test_get_video_metadata_not_found(self, video_client):
        resp = video_client.get("/v1/videos/video_nonexistent")
        assert resp.status_code == 404


# =========================================================================
# GET /v1/videos/{video_id}/content (download)
# =========================================================================


@pytest.mark.threadleak(enabled=False)  # FileResponse spawns AnyIO worker threads
class TestGetVideoContent:
    def _insert_video_job(self, video_id: str, status: str = "queued"):
        import time as _time

        job = VideoJob(
            created_at=int(_time.time()),
            id=video_id,
            model="test-model",
            prompt="test prompt",
            status=status,
        )
        _run_async(VIDEO_STORE.upsert(video_id, job))

    def test_get_video_content_success(self, tmp_path):
        gen = MockVisualGen(video_output=_make_dummy_video_tensor())
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)

        video_id = "video_testcontent"
        self._insert_video_job(video_id, status="completed")

        # Write a dummy mp4 file so FileResponse can serve it
        video_path = tmp_path / f"{video_id}.mp4"
        video_path.write_bytes(b"\x00\x00\x00\x1cftyp" + b"\x00" * 16)

        resp = client.get(f"/v1/videos/{video_id}/content")
        assert resp.status_code == 200
        assert "video/mp4" in resp.headers.get("content-type", "")
        assert len(resp.content) > 0
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)

    def test_get_video_content_not_found(self, video_client):
        resp = video_client.get("/v1/videos/video_nonexistent/content")
        assert resp.status_code == 404

    def test_get_video_content_not_ready(self, tmp_path):
        """A queued video should return 400 when its content is requested."""
        gen = MockVisualGen(video_output=_make_dummy_video_tensor())
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)

        video_id = "video_notready"
        self._insert_video_job(video_id, status="queued")

        resp = client.get(f"/v1/videos/{video_id}/content")
        assert resp.status_code == 400
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)

    def test_get_video_content_completed_but_file_missing(self, tmp_path):
        """Video marked completed but file deleted from disk → 404."""
        gen = MockVisualGen(video_output=_make_dummy_video_tensor())
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)

        video_id = "video_nofile"
        self._insert_video_job(video_id, status="completed")
        # Do NOT write a file

        resp = client.get(f"/v1/videos/{video_id}/content")
        assert resp.status_code == 404
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)


# =========================================================================
# DELETE /v1/videos/{video_id}
# =========================================================================


class TestDeleteVideo:
    def test_delete_video_success(self, tmp_path):
        gen = MockVisualGen(video_output=_make_dummy_video_tensor())
        os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path)
        client = _create_server(gen)

        create_resp = client.post(
            "/v1/videos",
            json={"prompt": "Delete me", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        video_id = create_resp.json()["id"]

        # Write a dummy video file
        (tmp_path / f"{video_id}.mp4").write_bytes(b"\x00" * 32)

        resp = client.delete(f"/v1/videos/{video_id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["deleted"] is True

        # Verify it's gone from the store
        resp = client.get(f"/v1/videos/{video_id}")
        assert resp.status_code == 404

        # Verify file is deleted
        assert not (tmp_path / f"{video_id}.mp4").exists()
        os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None)

    def test_delete_video_not_found(self, video_client):
        resp = video_client.delete("/v1/videos/video_nonexistent")
        assert resp.status_code == 404

    def test_delete_video_without_file_on_disk(self, video_client):
        """Delete a video job that exists in the store but has no file on disk."""
        create_resp = video_client.post(
            "/v1/videos",
            json={"prompt": "No file", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        video_id = create_resp.json()["id"]

        resp = video_client.delete(f"/v1/videos/{video_id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["deleted"] is True

    def test_delete_video_then_list_empty(self, video_client):
        """After deleting the only video, the list should be empty."""
        create_resp = video_client.post(
            "/v1/videos",
            json={"prompt": "Ephemeral", "size": "64x64", "seconds": 1.0, "fps": 8},
            headers={"content-type": "application/json"},
        )
        video_id = create_resp.json()["id"]

        video_client.delete(f"/v1/videos/{video_id}")

        resp = video_client.get("/v1/videos")
        assert resp.status_code == 200
        assert resp.json()["data"] == []
3094
tests/unittest/_torch/visual_gen/test_wan.py
Normal file
File diff suppressed because it is too large
1491
tests/unittest/_torch/visual_gen/test_wan_i2v.py
Normal file
File diff suppressed because it is too large