# MMDiT in SD 3 & SD 3.5
This document shows how to build and run the [MMDiT](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_sd3.py) model from Stable Diffusion 3/3.5 with TensorRT-LLM.
## Overview
The TensorRT LLM implementation of MMDiT can be found in [tensorrt_llm/models/mmdit_sd3/model.py](../../../../tensorrt_llm/models/mmdit_sd3/model.py). The TensorRT LLM MMDiT (SD 3/3.5) example code is located in [`examples/models/contrib/mmdit`](./). There are two main files to build and run MMDiT with TensorRT-LLM:
* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert the MMDiT model into TensorRT LLM checkpoint format.
* [`sample.py`](./sample.py) to run the [diffusers](https://huggingface.co/docs/diffusers/index) pipeline with TensorRT engine(s) to generate images.
## Support Matrix
- [x] TP
- [x] CP
- [ ] FP8
## Usage
The TensorRT LLM MMDiT example code is located at [examples/models/contrib/mmdit](./). It takes a HuggingFace checkpoint as input and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference.
### Build MMDiT TensorRT engine(s)
The HuggingFace checkpoint is first converted to the TensorRT LLM checkpoint format by [`convert_checkpoint.py`](./convert_checkpoint.py). After that, we can build TensorRT engine(s) from the TensorRT LLM checkpoint.
```bash
# Convert to TRT-LLM
python convert_checkpoint.py --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
    --max_batch_size=2 \
    --remove_input_padding=disable \
    --bert_attention_plugin=auto
```
Set `--max_batch_size` to the maximum number of images you would like to generate per run. `--remove_input_padding` is disabled because MMDiT's patch sequences do not need padding.
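As a rough illustration of why padding is unnecessary: at a fixed image resolution, every sample produces the same number of patches, so the transformer sequence length is constant across the batch. The sizes below (1024x1024 output, 8x VAE downsampling, patch size 2) are assumptions based on typical SD 3.x configurations, not values read from this example:

```python
# Hedged sketch: patch count for a fixed-resolution MMDiT input.
# Assumed values (typical for SD 3.x, not taken from this example):
image_size = 1024       # output resolution (pixels per side)
vae_scale = 8           # VAE spatial downsampling factor
patch_size = 2          # MMDiT patchify size

latent_size = image_size // vae_scale           # 128x128 latent
num_patches = (latent_size // patch_size) ** 2  # tokens per image

print(num_patches)  # → 4096, identical for every image in the batch
```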
After the build completes, the engines are written to the `./engine_output` directory and are ready for running the MMDiT model with TensorRT LLM.
### Generate images
A [`sample.py`](./sample.py) is provided to generate images with the optimized TensorRT engine(s).
If using `float16` for inference, the `FusedRMSNorm` from `Apex` used by the T5 encoder should be disabled in [huggingface/transformers](https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/t5/modeling_t5.py#L259), or simply uninstall `apex`:
```python
try:
    from apex.normalization import FusedRMSNorm

    # [NOTE] Avoid using `FusedRMSNorm` for T5 encoder.
    # T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
```
Run `python sample.py` and an image named `sd3.5-mmdit.png` will be generated:
![sd3.5-mmdit](./sd3.5-mmdit.png)
### Tensor Parallel
```bash
# Convert to TRT-LLM
python convert_checkpoint.py --tp_size=2 --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
    --max_batch_size=2 \
    --remove_input_padding=disable \
    --bert_attention_plugin=auto
mpirun -n 2 --allow-run-as-root python sample.py "A capybara holding a sign that reads 'Hello World' in the forest."
```
### Context Parallel
The pipeline with CP is similar to the one with TP, but it does not support the `BertAttention` plugin. Also make sure `tensorrt>=10.8.0.43` is installed.
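One quick way to verify the requirement is to compare dotted version strings numerically. The helper below is an illustrative sketch, not part of this example; it assumes purely numeric version components (no pre-release suffixes), and in practice the installed string would come from `tensorrt.__version__`:

```python
# Hedged sketch: check an installed version string against the
# `tensorrt>=10.8.0.43` requirement by comparing numeric tuples.
def meets_requirement(installed: str, minimum: str = "10.8.0.43") -> bool:
    # Assumes purely numeric dotted versions, e.g. "10.8.0.43".
    as_tuple = lambda v: tuple(int(p) for p in v.split("."))
    return as_tuple(installed) >= as_tuple(minimum)

print(meets_requirement("10.8.0.43"))  # True
print(meets_requirement("10.7.1"))     # False
```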
```bash
# Convert to TRT-LLM
python convert_checkpoint.py --cp_size=2 --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
    --max_batch_size=2 \
    --remove_input_padding=disable \
    --bert_attention_plugin=disable
mpirun -n 2 --allow-run-as-root python sample.py "A capybara holding a sign that reads 'Hello World' in the forest."
```