Visual Generation (Diffusion Models) [Beta]
Background and Motivation
Visual generation models based on diffusion transformers (DiT) have become the standard for high-quality image and video synthesis. These models iteratively denoise latent representations through a learned transformer backbone, then decode the final latents with a VAE to produce pixels. As model sizes and output resolutions grow, efficient inference becomes critical — demanding multi-GPU parallelism, weight quantization, and runtime caching to achieve practical throughput and latency.
The TensorRT-LLM VisualGen module provides a unified inference stack for diffusion models. Key capabilities include (subject to change as the feature matures):
- A shared pipeline abstraction for diffusion model families, covering the denoising loop, guidance strategies, and component loading.
- Pluggable attention backends.
- Quantization support (dynamic and static) using the ModelOpt configuration format.
- Multi-GPU parallelism strategies.
- TeaCache — a runtime caching optimization for the transformer backbone.
- trtllm-serve integration with OpenAI-compatible API endpoints.
Note: This is the initial release of TensorRT-LLM VisualGen. APIs, supported models, and optimization options are actively evolving and may change in future releases.
Quick Start
Prerequisites
pip install -r requirements-dev.txt
pip install git+https://github.com/huggingface/diffusers.git
pip install av
Python API
The example scripts under examples/visual_gen/ demonstrate direct Python usage. For Wan2.1 text-to-video generation:
cd examples/visual_gen
python visual_gen_wan_t2v.py \
--model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--height 480 --width 832 --num_frames 33 \
--output_path output.mp4
Run python visual_gen_wan_t2v.py --help for the full list of arguments. Key options control resolution, denoising steps, quantization mode, attention backend, parallelism, and TeaCache settings.
Usage with trtllm-serve
The trtllm-serve command automatically detects diffusion models (by the presence of model_index.json) and launches an OpenAI-compatible visual generation server.
1. Create a YAML configuration file:
# wan_config.yml
linear:
  type: default
teacache:
  enable_teacache: true
  teacache_thresh: 0.2
parallel:
  dit_cfg_size: 1
  dit_ulysses_size: 1
2. Launch the server:
trtllm-serve Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--extra_visual_gen_options wan_config.yml
3. Send requests using curl or any OpenAI-compatible client:
Synchronous video generation:
curl -X POST "http://localhost:8000/v1/videos/generations" \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cool cat on a motorcycle in the night",
"seconds": 4.0,
"fps": 24,
"size": "480x832"
}' -o output.mp4
Asynchronous video generation:
# Submit the job
curl -X POST "http://localhost:8000/v1/videos" \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cool cat on a motorcycle in the night",
"seconds": 4.0,
"fps": 24,
"size": "480x832"
}'
# Returns: {"id": "<video_id>", "status": "processing", ...}
# Poll for status
curl -X GET "http://localhost:8000/v1/videos/<video_id>"
# Download when complete
curl -X GET "http://localhost:8000/v1/videos/<video_id>/content" -o output.mp4
The server exposes OpenAI-compatible endpoints for image generation (/v1/images/generations), video generation (/v1/videos, /v1/videos/generations), video management, and standard health/model info endpoints.
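Any HTTP client can drive these endpoints programmatically. The sketch below mirrors the asynchronous curl workflow above using Python's requests library; the endpoint paths and the id/status fields come from the examples above, while the polling interval and the terminal status strings ("completed", "failed") are assumptions about the server's responses.

```python
import time

import requests

BASE_URL = "http://localhost:8000"

# Submit the job (same payload as the asynchronous curl example above).
job = requests.post(
    f"{BASE_URL}/v1/videos",
    json={
        "prompt": "A cool cat on a motorcycle in the night",
        "seconds": 4.0,
        "fps": 24,
        "size": "480x832",
    },
).json()
video_id = job["id"]

# Poll until the job leaves the "processing" state (terminal status names are assumed).
while True:
    status = requests.get(f"{BASE_URL}/v1/videos/{video_id}").json()["status"]
    if status != "processing":
        break
    time.sleep(5)

# Download the result if generation succeeded ("completed" is an assumed value).
if status == "completed":
    video = requests.get(f"{BASE_URL}/v1/videos/{video_id}/content")
    with open("output.mp4", "wb") as f:
        f.write(video.content)
```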
The --extra_visual_gen_options YAML file configures quantization (linear), TeaCache (teacache), and parallelism (parallel). See examples/visual_gen/serve/configs/ for reference configurations.
Quantization
TensorRT-LLM VisualGen supports both dynamic quantization (on-the-fly at weight-loading time from BF16 checkpoints) and static quantization (loading pre-quantized checkpoints with embedded scales). Both modes use the same ModelOpt quantization_config format.
Quick start — dynamic quantization via --linear_type:
python visual_gen_wan_t2v.py \
--model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--prompt "A cute cat playing piano" \
--linear_type trtllm-fp8-per-tensor \
--output_path output_fp8.mp4
Supported --linear_type values: default (BF16/FP16), trtllm-fp8-per-tensor, trtllm-fp8-blockwise, svd-nvfp4.
ModelOpt quantization_config format:
Both dynamic and static quantization use the ModelOpt quantization_config format — the same format found in a model's config.json under the quantization_config field. This config can be passed as a dict to DiffusionArgs.quant_config when constructing the pipeline programmatically:
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs

args = DiffusionArgs(
    checkpoint_path="/path/to/model",
    quant_config={"quant_algo": "FP8", "dynamic": True},  # dynamic FP8
)
The --linear_type CLI flag is a convenience shorthand that maps to these configs internally (e.g., trtllm-fp8-per-tensor → {"quant_algo": "FP8", "dynamic": True}).
Key fields: "dynamic" controls load-time quantization (true) vs pre-quantized checkpoint (false); "ignore" excludes specific modules from quantization.
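For example, the --linear_type shorthand above expands to a dynamic config, while a pre-quantized checkpoint would use a static config. The sketch below shows both; the "ignore" module names are illustrative placeholders, not values from a real checkpoint.

```python
from tensorrt_llm._torch.visual_gen.config import DiffusionArgs

# Dynamic FP8: weights are quantized on the fly from the BF16 checkpoint
# (equivalent to --linear_type trtllm-fp8-per-tensor).
dynamic_fp8 = {"quant_algo": "FP8", "dynamic": True}

# Static FP8: scales are read from a pre-quantized checkpoint. The "ignore"
# entries are illustrative module names to keep unquantized, not values taken
# from a real checkpoint.
static_fp8 = {
    "quant_algo": "FP8",
    "dynamic": False,
    "ignore": ["proj_out", "time_embedder"],
}

args = DiffusionArgs(
    checkpoint_path="/path/to/prequantized/model",
    quant_config=static_fp8,
)
```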
Developer Guide
This section describes the TensorRT-LLM VisualGen module architecture and guides developers on how to add support for new diffusion model families.
Architecture Overview
The VisualGen module lives under tensorrt_llm._torch.visual_gen. At a high level, the flow is:
- Config — User-facing DiffusionArgs (CLI / YAML) is merged with checkpoint metadata into DiffusionModelConfig.
- Pipeline creation & loading — AutoPipeline detects the model type from model_index.json, instantiates the matching BasePipeline subclass, and loads weights (with optional dynamic quantization) and standard components (VAE, text encoder, tokenizer, scheduler).
- Execution — DiffusionExecutor coordinates multi-GPU inference via worker processes.
Note: Internal module structure is subject to change. Refer to inline docstrings in tensorrt_llm/_torch/visual_gen/ for the latest details.
Implementing a New Diffusion Model
Adding a new model (e.g., a hypothetical "MyDiT") requires four steps. The framework handles weight loading, parallelism, quantization, and serving automatically once the pipeline is registered.
1. Create the Transformer Module
Create the DiT backbone in tensorrt_llm/_torch/visual_gen/models/mydit/transformer_mydit.py. It should be an nn.Module that:
- Uses existing modules (e.g., Attention with a configurable attention backend, Linear for built-in linear ops) wherever possible.
- Implements load_weights(weights: Dict[str, torch.Tensor]) to map checkpoint weight names to module parameters.
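A minimal skeleton along these lines might look as follows. It is purely illustrative: the class name and layer layout are placeholders, and a real backbone would reuse the shared Attention and Linear modules so that attention-backend selection and quantization apply.

```python
from typing import Dict

import torch
from torch import nn


class MyDiTTransformer(nn.Module):
    """Placeholder DiT backbone for the hypothetical MyDiT family."""

    def __init__(self, hidden_size: int = 1536, num_layers: int = 30):
        super().__init__()
        # A real implementation would use the shared Attention/Linear modules
        # here so that attention backends and quantization are configurable.
        self.blocks = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        )

    def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        # Map checkpoint weight names onto this module's parameters; real
        # implementations typically rename keys before loading.
        self.load_state_dict(weights, strict=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Simplified forward pass; a real DiT also conditions each block on
        # timestep and text embeddings.
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
```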
2. Create the Pipeline Class
Create a pipeline class extending BasePipeline in tensorrt_llm/_torch/visual_gen/models/mydit/. Override methods for transformer initialization, component loading, and inference. BasePipeline provides the denoising loop, CFG handling, and TeaCache integration — your pipeline only needs to implement model-specific logic. See WanPipeline for a reference implementation.
3. Register the Pipeline
Use the @register_pipeline("MyDiTPipeline") decorator on your pipeline class to register it in the global PIPELINE_REGISTRY. Make sure to export it from models/__init__.py.
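Steps 2 and 3 together might look like the sketch below. The import paths and the init_transformer/load_components hook names are assumptions; post_load_weights and _setup_teacache are the names used in this guide, and WanPipeline remains the authoritative reference.

```python
# Illustrative only: import paths and the init_transformer/load_components hook
# names are assumptions about BasePipeline; follow WanPipeline for the real
# interface.
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline  # assumed path
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline  # assumed path

from .transformer_mydit import MyDiTTransformer


@register_pipeline("MyDiTPipeline")
class MyDiTPipeline(BasePipeline):
    def init_transformer(self):  # hypothetical hook name
        # Build the DiT backbone; PipelineLoader populates (optionally
        # quantized) weights afterwards.
        self.transformer = MyDiTTransformer()

    def load_components(self):  # hypothetical hook name
        # Load the VAE, text encoder, tokenizer, and scheduler for this family.
        ...

    def post_load_weights(self):
        # Opt in to TeaCache once weights are in place.
        self._setup_teacache()
```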
4. Update AutoPipeline Detection
In pipeline_registry.py, add detection logic for your model's _class_name in model_index.json.
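The detection itself can be as simple as reading the _class_name field from model_index.json; the helper below is a hypothetical sketch rather than the actual code in pipeline_registry.py.

```python
import json
from pathlib import Path


def detect_pipeline_class(checkpoint_path: str) -> str:
    """Hypothetical sketch of _class_name-based detection for AutoPipeline."""
    with open(Path(checkpoint_path) / "model_index.json") as f:
        class_name = json.load(f)["_class_name"]
    if class_name == "MyDiTPipeline":  # new branch for the MyDiT family
        return "MyDiTPipeline"
    raise ValueError(f"No registered pipeline for _class_name={class_name!r}")
```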
After these steps, the framework automatically handles:
- Weight loading with optional dynamic quantization via PipelineLoader
- Multi-GPU execution via DiffusionExecutor
- TeaCache integration (if you call self._setup_teacache() in post_load_weights())
- Serving via trtllm-serve with the full endpoint set
Summary and Future Work
Current Status
Supported models: Wan2.1 and Wan2.2 families (text-to-video, image-to-video; 1.3B and 14B variants).
Supported features:
| Feature | Details |
|---|---|
| Multi-GPU Parallelism | CFG parallel, Ulysses sequence parallel (more strategies planned) |
| TeaCache | Caches transformer outputs when timestep embeddings change slowly |
| Quantization | Dynamic (on-the-fly from BF16) and static (pre-quantized checkpoints), both via ModelOpt quantization_config format |
| Attention Backends | Vanilla (torch SDPA) and TRT-LLM optimized fused kernels |
| trtllm-serve | OpenAI-compatible endpoints for image/video generation (sync + async) |
Future Work
- Additional model support: Extend to more diffusion model families.
- More attention backends: Add backends beyond the current vanilla SDPA and TRT-LLM fused-kernel options.
- Advanced parallelism: Additional parallelism strategies for larger models and higher resolutions.
- Serving enhancements: Improved throughput and user experience for production serving workloads.