# MMDiT in SD 3 & SD 3.5
This document shows how to build and run an [MMDiT](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_sd3.py) model from Stable Diffusion 3/3.5 with TensorRT-LLM.
## Overview
The TensorRT-LLM implementation of MMDiT can be found in [tensorrt_llm/models/mmdit_sd3/model.py](../../../../tensorrt_llm/models/mmdit_sd3/model.py). The TensorRT-LLM MMDiT (SD 3/3.5) example code is located in [`examples/models/contrib/mmdit`](./). There are two main files to build and run MMDiT with TensorRT-LLM:
* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert the MMDiT model into the TensorRT-LLM checkpoint format.
* [`sample.py`](./sample.py) to run the [diffusers](https://huggingface.co/docs/diffusers/index) pipeline with TensorRT engine(s) to generate images.
## Support Matrix
- [x] TP
- [x] CP
- [ ] FP8
## Usage
The TensorRT-LLM MMDiT example code is located in [examples/models/contrib/mmdit](./). It takes a HuggingFace checkpoint as input and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference.
### Build MMDiT TensorRT engine(s)
The HuggingFace checkpoint is first converted to the TensorRT-LLM checkpoint format by [`convert_checkpoint.py`](./convert_checkpoint.py). After that, we can build TensorRT engine(s) from the TensorRT-LLM checkpoint.
```
# Convert to TRT-LLM
python convert_checkpoint.py --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
--max_batch_size=2 \
--remove_input_padding=disable \
--bert_attention_plugin=auto
```
Set `--max_batch_size` to the maximum number of images you would like to generate at once. We disable `--remove_input_padding` since we do not need to pad MMDiT's patches.
After the build, we can find a `./engine_output` directory; it is ready for running the MMDiT model with TensorRT-LLM.
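As a quick sanity check (assuming the `./engine_output` directory mentioned above), the build output typically contains a `config.json` plus one serialized engine per rank; the file names below are typical TensorRT-LLM output, shown for illustration:
```
ls ./engine_output
# config.json  rank0.engine
```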
### Generate images
A [`sample.py`](./sample.py) is provided to generate images with the optimized TensorRT engine(s).
If using `float16` for inference, the `FusedRMSNorm` from `apex` used by the T5 encoder should be disabled in [huggingface/transformers](https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/t5/modeling_t5.py#L259) (by commenting out the assignment as shown below), or `apex` should simply be uninstalled:
```python
try:
    from apex.normalization import FusedRMSNorm

    # [NOTE] Avoid using `FusedRMSNorm` for the T5 encoder.
    # T5LayerNorm = FusedRMSNorm  # noqa
    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
```
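Alternatively, as noted above, removing `apex` from the environment has the same effect:
```
pip uninstall -y apex
```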
Just run `python sample.py` and an image named `sd3.5-mmdit.png` will be generated:
![sd3.5-mmdit.png](./assets/sd3.5-mmdit.png)
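`sample.py` also accepts the prompt as a positional argument (as in the parallel examples below), so a custom prompt can be passed directly:
```
python sample.py "A capybara holding a sign that reads 'Hello World' in the forest."
```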
### Tensor Parallel
```
# Convert to TRT-LLM
python convert_checkpoint.py --tp_size=2 --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
--max_batch_size=2 \
--remove_input_padding=disable \
--bert_attention_plugin=auto
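# Run with one MPI rank per GPU; -n must match the --tp_size used above.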
mpirun -n 2 --allow-run-as-root python sample.py "A capybara holding a sign that reads 'Hello World' in the forest."
```
### Context Parallel
The pipeline with CP is similar to the one with TP, but it does not support the `BertAttention` plugin. Also, make sure that `tensorrt>=10.8.0.43` is installed.
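A quick way to check the installed TensorRT version:
```
python -c "import tensorrt; print(tensorrt.__version__)"
```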
```
# Convert to TRT-LLM
python convert_checkpoint.py --cp_size=2 --model_path='stabilityai/stable-diffusion-3.5-medium'
trtllm-build --checkpoint_dir=./tllm_checkpoint/ \
--max_batch_size=2 \
--remove_input_padding=disable \
--bert_attention_plugin=disable
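# As with TP, -n must match the --cp_size used above (one MPI rank per GPU).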
mpirun -n 2 --allow-run-as-root python sample.py "A capybara holding a sign that reads 'Hello World' in the forest."
```