Compare commits

..

23 Commits

Author SHA1 Message Date
Aryan e8aacda762 Release: v0.32.1 2024-12-25 11:34:06 +01:00
Aryan 12184f4015 Fix TorchAO related bugs; revert device_map changes (#10371)
* Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)"

This reverts commit 41ba8c0bf6.

* update tests

* udpate

* update

* update

* update device map tests

* apply review suggestions

* update

* make style

* fix

* update docs

* update tests

* update workflow

* update

* improve tests

* allclose tolerance

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* improve tests

* fix

* update correct slices

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-25 11:31:21 +01:00
Sayak Paul 6e1d2da194 fix test pypi installation in the release workflow (#10360)
fix
2024-12-25 11:25:34 +01:00
YiYi Xu 11b1151840 make style for https://github.com/huggingface/diffusers/pull/10368 (#10370)
* fix bug for torch.uint1-7 not support in torch<2.6

* up

---------

Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn>
2024-12-25 11:25:14 +01:00
sayakpaul cd4d0d8ffb Release: v0.32.0 2024-12-23 20:26:28 +05:30
Aryan 4b557132ce [core] LTX Video 0.9.1 (#10330)
* update

* make style

* update

* update

* update

* make style

* single file related changes

* update

* fix

* update single file urls and docs

* update

* fix
2024-12-23 19:51:33 +05:30
Sayak Paul 851dfa30ae [Tests] Fix more tests sayak (#10359)
* fixes to tests

* fixture

* fixes
2024-12-23 19:11:21 +05:30
Sayak Paul ea1ba0ba53 [LoRA] test fix (#10351)
updates
2024-12-23 15:45:45 +05:30
Aryan 9d27df8071 Rename LTX blocks and docs title (#10213)
* rename blocks and docs

* fix docs

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-12-23 15:29:10 +05:30
Aryan 055d95543a Fix failing CogVideoX LoRA fuse test (#10352)
fix
2024-12-23 14:22:09 +05:30
hlky 71cc2013fe Fix FluxIPAdapterTesterMixin (#10354) 2024-12-23 14:20:06 +05:30
Sayak Paul c34fc34563 [Tests] QoL improvements to the LoRA test suite (#10304)
* misc lora test improvements.

* updates

* fixes to tests
2024-12-23 13:59:55 +05:30
Dhruv Nair 5fcee4a447 [Single File] Fix loading (#10349)
update
2024-12-23 13:12:23 +05:30
Sayak Paul 76e2727b5c [SANA LoRA] sana lora training tests and misc. (#10296)
* sana lora training tests and misc.

* remove push to hub

* Update examples/dreambooth/train_dreambooth_lora_sana.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-23 12:35:13 +05:30
Aryan 02c777c065 [tests] Refactor TorchAO serialization fast tests (#10271)
refactor
2024-12-23 11:04:57 +05:30
Sayak Paul 6a970a45c5 [docs] fix: torchao example. (#10278)
fix: torchao example.
2024-12-23 11:03:50 +05:30
Aryan ffc0eaab6d Bump minimum TorchAO version to 0.7.0 (#10293)
* bump min torchao version to 0.7.0

* update
2024-12-23 11:03:04 +05:30
Thien Tran 3c2e2aa8a9 .from_single_file() - Add missing .shape (#10332)
Add missing `.shape`
2024-12-23 08:57:25 +05:30
Junsong Chen b58868e6f4 [Sana bug] bug fix for 2K model config (#10340)
* fix the Positinoal Embedding bug in 2K model;

* Change the default model to the BF16 one for more stable training and output

* make style

* substract buffer size

* add compute_module_persistent_sizes

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-12-23 08:56:25 +05:30
Dhruv Nair da21d590b5 [Single File] Add Single File support for HunYuan video (#10320)
* update

* Update src/diffusers/loaders/single_file_utils.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-23 08:44:58 +05:30
YiYi Xu 7c2f0afb1c update get_parameter_dtype (#10342)
add:
q
2024-12-23 08:14:13 +05:30
hlky f615f00f58 Fix enable_sequential_cpu_offload in test_kandinsky_combined (#10324)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-22 15:28:28 -10:00
Aryan 6aaa0518e3 Community hosted weights for diffusers format HunyuanVideo weights (#10344)
update docs and example to use community weights
2024-12-22 15:26:28 -10:00
96 changed files with 1838 additions and 665 deletions
+2
View File
@@ -359,6 +359,8 @@ jobs:
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
runs-on:
group: aws-g6e-xlarge-plus
container:
+1 -1
View File
@@ -68,7 +68,7 @@ jobs:
- name: Test installing diffusers and importing
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://testpypi.python.org/pypi diffusers
pip install -i https://test.pypi.org/simple/ diffusers
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
+1 -1
View File
@@ -429,7 +429,7 @@
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
title: LTXVideo
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanVideo
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
```
## AutoencoderKLHunyuanVideo
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTXVideo
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanVideoTransformer3DModel
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanVideoTransformer3DModel
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTXVideoTransformer3DModel
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SanaTransformer2DModel
@@ -29,7 +29,7 @@ Recommendations for inference:
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
## HunyuanVideoPipeline
+40 -2
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX
# LTX Video
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
Available models:
| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
single_file_url, torch_dtype=torch.bfloat16
@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
<!-- TODO(aryan): Update this when official weights are supported -->
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=768,
height=512,
num_frames=161,
decode_timestep=0.03,
decode_noise_scale=0.025,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
## LTXPipeline
+1 -1
View File
@@ -32,9 +32,9 @@ Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
+66 -2
View File
@@ -25,9 +25,10 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/Flux.1-Dev"
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
@@ -44,8 +45,14 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")
# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```
@@ -86,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
## Serializing and Deserializing quantized models
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
class MarigoldDepthOutput(BaseOutput):
"""
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
def test_dreambooth_lora_sana(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# `self.transformer_layer_type` should be in the state dict.
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 166
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()
resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -943,7 +943,7 @@ def main(args):
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@ def main(args):
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
@@ -993,6 +984,15 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
@@ -1182,6 +1182,7 @@ def main(args):
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
prompt_embeds = prompt_embeds.to(transformer.dtype)
return prompt_embeds, prompt_attention_mask
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
@@ -1216,7 +1217,7 @@ def main(args):
vae_config_scaling_factor = vae.config.scaling_factor
if args.cache_latents:
latents_cache = []
vae = vae.to("cuda")
vae = vae.to(accelerator.device)
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -54,7 +54,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
+99 -11
View File
@@ -1,7 +1,9 @@
import argparse
from pathlib import Path
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"vae": remove_keys_,
}
VAE_KEYS_RENAME_DICT = {
# decoder
@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
@@ -80,13 +105,16 @@ def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -97,16 +125,21 @@ def convert_transformer(
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(ckpt_path: str, dtype: torch.dtype):
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "vae."
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
with init_empty_weights():
vae = AutoencoderKLLTXVideo(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_vae_config(version: str) -> Dict[str, Any]:
if version == "0.9.0":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"timestep_conditioning": False,
}
elif version == "0.9.1":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -139,6 +222,9 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
)
return parser.parse_args()
@@ -161,6 +247,7 @@ if __name__ == "__main__":
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
@@ -169,13 +256,14 @@ if __name__ == "__main__":
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
config = get_vae_config(args.version)
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
+6
View File
@@ -88,13 +88,18 @@ def main(args):
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
if is_accelerate_available():
+1 -1
View File
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.32.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.32.0.dev0"
__version__ = "0.32.1"
from typing import TYPE_CHECKING
+7 -1
View File
@@ -28,6 +28,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
@@ -101,6 +102,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -220,6 +225,7 @@ class FromOriginalModelMixin:
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
@@ -297,7 +303,7 @@ class FromOriginalModelMixin:
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=revision,
revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+161 -2
View File
@@ -108,6 +108,7 @@ CHECKPOINT_KEY_NAMES = {
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -156,12 +157,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
}
# Use to configure model sample size when original config is provided
@@ -603,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
@@ -624,6 +630,9 @@ def infer_diffusers_model_type(checkpoint):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
model_type = "mochi-1-preview"
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
else:
model_type = "v1"
@@ -2333,12 +2342,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
@@ -2522,3 +2551,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
return new_state_dict
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
def update_state_dict_(state_dict, old_key, new_key):
state_dict[new_key] = state_dict.pop(old_key)
for key in list(checkpoint.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(checkpoint, key, new_key)
for key in list(checkpoint.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, checkpoint)
return checkpoint
@@ -177,3 +177,5 @@ class FluxTransformer2DLoadersMixin:
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
@@ -22,13 +22,14 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
class LTXCausalConv3d(nn.Module):
class LTXVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -79,9 +80,9 @@ class LTXCausalConv3d(nn.Module):
return hidden_states
class LTXResnetBlock3d(nn.Module):
class LTXVideoResnetBlock3d(nn.Module):
r"""
A 3D ResNet block used in the LTX model.
A 3D ResNet block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -109,7 +110,9 @@ class LTXResnetBlock3d(nn.Module):
elementwise_affine: bool = False,
non_linearity: str = "swish",
is_causal: bool = True,
):
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
@@ -117,13 +120,13 @@ class LTXResnetBlock3d(nn.Module):
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.conv1 = LTXCausalConv3d(
self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
self.conv2 = LTXCausalConv3d(
self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
@@ -131,22 +134,58 @@ class LTXResnetBlock3d(nn.Module):
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
self.conv_shortcut = LTXCausalConv3d(
self.conv_shortcut = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
self.per_channel_scale1 = None
self.per_channel_scale2 = None
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.scale_shift_table = None
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
def forward(
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
hidden_states = inputs
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale_1) + shift_1
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.per_channel_scale1 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
hidden_states = hidden_states * (1 + scale_2) + shift_2
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.per_channel_scale2 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
if self.norm3 is not None:
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
@@ -157,20 +196,24 @@ class LTXResnetBlock3d(nn.Module):
return hidden_states
class LTXUpsampler3d(nn.Module):
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.residual = residual
self.upscale_factor = upscale_factor
out_channels = in_channels * stride[0] * stride[1] * stride[2]
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTXCausalConv3d(
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
@@ -181,6 +224,15 @@ class LTXUpsampler3d(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.residual:
residual = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
)
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
residual = residual.repeat(1, repeats, 1, 1, 1)
residual = residual[:, :, self.stride[0] - 1 :]
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
@@ -188,12 +240,15 @@ class LTXUpsampler3d(nn.Module):
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
if self.residual:
hidden_states = hidden_states + residual
return hidden_states
class LTXDownBlock3D(nn.Module):
class LTXVideoDownBlock3D(nn.Module):
r"""
Down block used in the LTX model.
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -235,7 +290,7 @@ class LTXDownBlock3D(nn.Module):
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -250,7 +305,7 @@ class LTXDownBlock3D(nn.Module):
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
LTXCausalConv3d(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
@@ -262,7 +317,7 @@ class LTXDownBlock3D(nn.Module):
self.conv_out = None
if in_channels != out_channels:
self.conv_out = LTXResnetBlock3d(
self.conv_out = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -273,7 +328,12 @@ class LTXDownBlock3D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
@@ -285,24 +345,26 @@ class LTXDownBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.conv_out is not None:
hidden_states = self.conv_out(hidden_states)
hidden_states = self.conv_out(hidden_states, temb, generator)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXMidBlock3d(nn.Module):
class LTXVideoMidBlock3d(nn.Module):
r"""
A middle block used in the LTX model.
A middle block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -329,28 +391,51 @@ class LTXMidBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXMidBlock3D` class."""
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -360,16 +445,18 @@ class LTXMidBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
class LTXUpBlock3d(nn.Module):
class LTXVideoUpBlock3d(nn.Module):
r"""
Up block used in the LTX model.
Up block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -403,45 +490,82 @@ class LTXUpBlock3d(nn.Module):
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
upscale_factor: int = 1,
):
super().__init__()
out_channels = out_channels or in_channels
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
self.conv_in = None
if in_channels != out_channels:
self.conv_in = LTXResnetBlock3d(
self.conv_in = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
stride=(2, 2, 2),
is_causal=is_causal,
residual=upsample_residual,
upscale_factor=upscale_factor,
)
]
)
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
if self.conv_in is not None:
hidden_states = self.conv_in(hidden_states)
hidden_states = self.conv_in(hidden_states, temb, generator)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -456,16 +580,18 @@ class LTXUpBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
class LTXEncoder3d(nn.Module):
class LTXVideoEncoder3d(nn.Module):
r"""
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
representation.
Args:
@@ -509,7 +635,7 @@ class LTXEncoder3d(nn.Module):
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
@@ -524,7 +650,7 @@ class LTXEncoder3d(nn.Module):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
down_block = LTXDownBlock3D(
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
@@ -536,7 +662,7 @@ class LTXEncoder3d(nn.Module):
self.down_blocks.append(down_block)
# mid block
self.mid_block = LTXMidBlock3d(
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
@@ -546,14 +672,14 @@ class LTXEncoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
r"""The forward method of the `LTXVideoEncoder3d` class."""
p = self.patch_size
p_t = self.patch_size_t
@@ -599,9 +725,10 @@ class LTXEncoder3d(nn.Module):
return hidden_states
class LTXDecoder3d(nn.Module):
class LTXVideoDecoder3d(nn.Module):
r"""
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, defaults to 128):
@@ -622,6 +749,8 @@ class LTXDecoder3d(nn.Module):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `False`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
timestep_conditioning (`bool`, defaults to `False`):
Whether to condition the model on timesteps.
"""
def __init__(
@@ -635,6 +764,10 @@ class LTXDecoder3d(nn.Module):
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: Tuple[bool, ...] = (False, False, False, False),
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
) -> None:
super().__init__()
@@ -645,30 +778,42 @@ class LTXDecoder3d(nn.Module):
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
layers_per_block = tuple(reversed(layers_per_block))
inject_noise = tuple(reversed(inject_noise))
upsample_residual = tuple(reversed(upsample_residual))
upsample_factor = tuple(reversed(upsample_factor))
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)
self.mid_block = LTXMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[0],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
inject_noise=inject_noise[0],
timestep_conditioning=timestep_conditioning,
)
# up blocks
num_block_out_channels = len(block_out_channels)
self.up_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i]
input_channel = output_channel // upsample_factor[i]
output_channel = block_out_channels[i] // upsample_factor[i]
up_block = LTXUpBlock3d(
up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
upscale_factor=upsample_factor[i],
)
self.up_blocks.append(up_block)
@@ -676,13 +821,20 @@ class LTXDecoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -693,17 +845,33 @@ class LTXDecoder3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb
)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
else:
hidden_states = self.mid_block(hidden_states)
hidden_states = self.mid_block(hidden_states, temb)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = up_block(hidden_states, temb)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
temb = temb + self.scale_shift_table[None, ..., None, None, None]
shift, scale = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
@@ -766,8 +934,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -777,7 +952,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
) -> None:
super().__init__()
self.encoder = LTXEncoder3d(
self.encoder = LTXVideoEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
@@ -788,16 +963,20 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
self.decoder = LTXDecoder3d(
self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=decoder_causal,
timestep_conditioning=timestep_conditioning,
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
)
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
@@ -837,7 +1016,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = 448
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value
def enable_tiling(
@@ -936,13 +1115,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def _decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
return self.tiled_decode(z, temb, return_dict=return_dict)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
@@ -952,7 +1133,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z)
dec = self.decoder(z, temb)
if not return_dict:
return (dec,)
@@ -960,7 +1141,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -975,10 +1158,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
if temb is not None:
decoded_slices = [
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
]
else:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
decoded = self._decode(z, temb).sample
if not return_dict:
return (decoded,)
@@ -1060,7 +1248,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1101,7 +1291,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
row.append(time)
rows.append(row)
@@ -1129,6 +1321,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
@@ -1139,7 +1332,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return dec
+1 -1
View File
@@ -748,10 +748,10 @@ class CogVideoXPatchEmbed(nn.Module):
pos_embedding = self._get_positional_embeddings(
height, width, pre_time_compression_frames, device=embeds.device
)
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
return embeds
+1 -1
View File
@@ -228,7 +228,7 @@ def load_model_dict_into_meta(
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if is_quantized and (
+4 -4
View File
@@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
if is_bnb_quantization_method and device_map is not None:
if device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None and is_bnb_quantization_method:
# TODO: https://github.com/huggingface/diffusers/issues/10013
if hf_quantizer is not None:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
@@ -242,6 +242,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
super().__init__()
@@ -249,14 +250,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
interpolation_scale=interpolation_scale,
)
# 2. Additional condition embeddings
@@ -18,6 +18,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -500,7 +502,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LTXAttentionProcessor2_0:
class LTXVideoAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -92,7 +92,7 @@ class LTXAttentionProcessor2_0:
return hidden_states
class LTXRotaryPosEmbed(nn.Module):
class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
@@ -164,7 +164,7 @@ class LTXRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
class LTXTransformerBlock(nn.Module):
class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -208,7 +208,7 @@ class LTXTransformerBlock(nn.Module):
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -221,7 +221,7 @@ class LTXTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.rope = LTXRotaryPosEmbed(
self.rope = LTXVideoRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.transformer_blocks = nn.ModuleList(
[
LTXTransformerBlock(
LTXVideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
@@ -39,7 +39,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import export_to_video
>>> model_id = "tencent/HunyuanVideo"
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
@@ -193,15 +193,15 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -419,8 +419,8 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -660,8 +660,8 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
+25 -1
View File
@@ -511,6 +511,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -563,6 +565,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -753,7 +759,25 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
@@ -571,6 +571,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -625,6 +627,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -849,7 +855,25 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaPAGPipeline
>>> pipe = SanaPAGPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
... pag_applied_layers=["transformer_blocks.8"],
... torch_dtype=torch.float32,
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaPipeline
>>> pipe = SanaPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ..base import DiffusersQuantizer
@@ -35,21 +35,28 @@ if is_torch_available():
import torch
import torch.nn as nn
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
if is_torch_version(">=", "2.5"):
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
else:
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
)
if is_torchao_available():
from torchao.quantization import quantize_
@@ -93,6 +100,11 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False
@@ -120,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int"):
if quant_type.startswith("int") or quant_type.startswith("uint"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+2 -2
View File
@@ -490,11 +490,11 @@ def require_gguf_version_greater_or_equal(gguf_version):
return decorator
def require_torchao_version_greater(torchao_version):
def require_torchao_version_greater_or_equal(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) > version.parse(torchao_version)
) >= version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
+1 -41
View File
@@ -15,8 +15,6 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -29,16 +27,13 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -123,41 +118,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+20 -73
View File
@@ -36,7 +36,6 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_peft_backend,
require_peft_version_greater,
require_torch_gpu,
slow,
torch_device,
@@ -331,7 +330,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -340,85 +340,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
for name, module in pipe.transformer.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser_lora_config.lora_bias = False
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")
denoiser_lora_config.lora_bias = True
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
# for now this is flux control lora specific but can be generalized later and added to ./utils.py
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.delete_adapters("adapter-1")
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_rank
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
pipe.transformer.delete_adapters("adapter-1")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_alpha
}
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_lora_expanding_shape_with_normal_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
# load shape expanded LoRA (such as Control LoRA).
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
+1 -45
View File
@@ -15,8 +15,6 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -28,16 +26,14 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -144,46 +140,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
prompt=inputs["prompt"],
height=inputs["height"],
width=inputs["width"],
num_frames=inputs["num_frames"],
num_inference_steps=inputs["num_inference_steps"],
max_sequence_length=inputs["max_sequence_length"],
output_type="np",
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+12 -46
View File
@@ -15,8 +15,6 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -26,18 +24,12 @@ from diffusers import (
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -60,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
transformer_cls = LTXVideoTransformer3DModel
vae_kwargs = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
@@ -107,41 +108,6 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+1 -39
View File
@@ -15,24 +15,20 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -103,40 +99,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+10 -10
View File
@@ -16,7 +16,7 @@ import sys
import unittest
import torch
from transformers import Gemma2ForCausalLM, GemmaTokenizer
from transformers import Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
vae_cls = AutoencoderDC
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
@property
def output_shape(self):
@@ -105,34 +105,34 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Sana.")
@unittest.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+112 -2
View File
@@ -1528,7 +1528,7 @@ class PeftLoraLoaderMixinTests:
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
@@ -1568,7 +1568,7 @@ class PeftLoraLoaderMixinTests:
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe("test", num_inference_steps=2, output_type="np")[0]
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
@@ -1988,3 +1988,113 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser_lora_config.lora_bias = False
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")
denoiser_lora_config.lora_bias = True
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, _ in denoiser.named_modules():
if "to_k" in name and "attn" in name and "lora" not in name:
module_name_to_rank_update = name.replace(".base_layer.", ".")
break
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
@@ -0,0 +1,169 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AutoencoderKLLTXVideo
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
return {"sample": image, "temb": timestep}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
+11
View File
@@ -2,10 +2,12 @@ import tempfile
import unittest
import numpy as np
import pytest
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
from diffusers.utils.testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
@@ -79,6 +81,15 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
class DeprecatedAttentionBlockTests(unittest.TestCase):
@pytest.fixture(scope="session")
def is_dist_enabled(pytestconfig):
return pytestconfig.getoption("dist") == "loadfile"
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
strict=True,
)
def test_conversion_when_using_device_map(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+78 -10
View File
@@ -22,12 +22,14 @@ import traceback
import unittest
import unittest.mock as mock
import uuid
from typing import Dict, List, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
out_queue.join()
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
@@ -1012,7 +1080,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
@@ -1076,7 +1144,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
@@ -1104,7 +1172,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ class ModelTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1164,7 +1232,7 @@ class ModelTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ class ModelTesterMixin:
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -30,6 +30,8 @@ class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = MochiTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
# Overriding it because of the transformer size.
model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):
@@ -14,6 +14,7 @@
import unittest
import pytest
import torch
from diffusers import SanaTransformer2DModel
@@ -80,3 +81,27 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_cpu_offload(self):
return super().test_cpu_offload()
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_disk_offload_with_safetensors(self):
return super().test_disk_offload_with_safetensors()
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_disk_offload_without_safetensors(self):
return super().test_disk_offload_without_safetensors()
+10 -1
View File
@@ -63,10 +63,19 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
spatio_temporal_scaling=(True, True, False, False),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
+10 -1
View File
@@ -68,10 +68,19 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
spatio_temporal_scaling=(True, True, False, False),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
+3 -3
View File
@@ -18,7 +18,7 @@ import unittest
import numpy as np
import torch
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.utils.testing_utils import (
@@ -101,7 +101,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
text_encoder_config = Gemma2Config(
head_dim=16,
hidden_size=32,
hidden_size=8,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=8192,
@@ -112,7 +112,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
vocab_size=8,
attn_implementation="eager",
)
text_encoder = Gemma2ForCausalLM(text_encoder_config)
text_encoder = Gemma2Model(text_encoder_config)
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
components = {
+336 -184
View File
@@ -36,7 +36,7 @@ from diffusers.utils.testing_utils import (
nightly,
require_torch,
require_torch_gpu,
require_torchao_version_greater,
require_torchao_version_greater_or_equal,
slow,
torch_device,
)
@@ -74,13 +74,13 @@ if is_torch_available():
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.utils import get_model_size_in_bytes
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -125,25 +125,28 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
model_id = "hf-internal-testing/tiny-flux-pipe"
def get_dummy_components(
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
):
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
@@ -209,10 +212,10 @@ class TorchAoTest(unittest.TestCase):
"timestep": timestep,
}
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
components = self.get_dummy_components(quantization_config)
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
components = self.get_dummy_components(quantization_config, model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
pipe.to(device=torch_device)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
@@ -221,44 +224,45 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
self._test_quant_type(quantization_config, expected_slice)
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
self._test_quant_type(quantization_config, expected_slice, model_id)
def test_int4wo_quant_bfloat16_conversion(self):
"""
@@ -276,15 +280,16 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
self.assertEqual(weight.quant_min, 0)
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
def test_device_map(self):
# Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
# it would have errored out. Now, we do. So, device_map basically never worked with or without
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
"""
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
"""
custom_device_map_dict = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
@@ -296,51 +301,73 @@ class TorchAoTest(unittest.TestCase):
}
device_maps = ["auto", custom_device_map_dict]
inputs = self.get_dummy_tensor_inputs(torch_device)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
# inputs = self.get_dummy_tensor_inputs(torch_device)
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# Test non-sharded model - should work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
# Test sharded model - should not work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
quantized_layer = quantized_model_with_not_convert.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
quantization_config = TorchAoConfig("int8_weight_only")
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -348,14 +375,10 @@ class TorchAoTest(unittest.TestCase):
torch_dtype=torch.bfloat16,
)
unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
size_quantized = get_model_size_in_bytes(quantized_model)
quantized_layer = quantized_model.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
self.assertTrue(size_quantized < size_quantized_with_not_convert)
def test_training(self):
quantization_config = TorchAoConfig("int8_weight_only")
@@ -391,83 +414,82 @@ class TorchAoTest(unittest.TestCase):
@nightly
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with torchao quantization."""
quantization_config = TorchAoConfig("int8_weight_only")
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
quantization_config = TorchAoConfig("int8_weight_only")
components = self.get_dummy_components(quantization_config, model_id=model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device)
inputs = self.get_dummy_inputs(torch_device)
normal_output = pipe(**inputs)[0].flatten()[-32:]
inputs = self.get_dummy_inputs(torch_device)
normal_output = pipe(**inputs)[0].flatten()[-32:]
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
inputs = self.get_dummy_inputs(torch_device)
compile_output = pipe(**inputs)[0].flatten()[-32:]
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
inputs = self.get_dummy_inputs(torch_device)
compile_output = pipe(**inputs)[0].flatten()[-32:]
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
@staticmethod
def _get_memory_footprint(module):
quantized_param_memory = 0.0
unquantized_param_memory = 0.0
for param in module.parameters():
if param.__class__.__name__ == "AffineQuantizedTensor":
data, scale, zero_point = param.layout_tensor.get_plain()
quantized_param_memory += data.numel() + data.element_size()
quantized_param_memory += scale.numel() + scale.element_size()
quantized_param_memory += zero_point.numel() + zero_point.element_size()
else:
unquantized_param_memory += param.data.numel() * param.data.element_size()
total_memory = quantized_param_memory + unquantized_param_memory
return total_memory, quantized_param_memory, unquantized_param_memory
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
def test_memory_footprint(self):
r"""
A simple test to check if the model conversion has been done correctly by checking on the
memory footprint of the converted model and the class type of the linear layers of the converted models
"""
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
transformer_bf16 = self.get_dummy_components(None)["transformer"]
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(
TorchAoConfig("int4wo", group_size=32), model_id=model_id
)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
transformer_int4wo_gs32
)
total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
for block in transformer_int4wo.transformer_blocks:
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
# int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
# int4 with default group size quantized very few linear layers compared to a smaller group size of 32
self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
# int8 quantizes more layers compare to int4 with default group size
self.assertTrue(quantized_int8wo < quantized_int4wo)
# Will quantize all the linear layers except x_embedder
for name, module in transformer_int4wo_gs32.named_modules():
if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
# Will quantize all the linear layers
for module in transformer_int8wo.modules():
if isinstance(module, nn.Linear):
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
total_bf16 = get_model_size_in_bytes(transformer_bf16)
# TODO: refactor to align with other quantization tests
# Latter has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int4wo < total_int4wo_gs32)
# int8 quantizes more layers compare to int4 with default group size
self.assertTrue(total_int8wo < total_int4wo)
# int4wo does not quantize too many layers because of default group size, but for the layers it does
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)
def test_wrong_config(self):
with self.assertRaises(ValueError):
self.get_dummy_components(TorchAoConfig("int42"))
# This class is not to be run as a test by itself. See the tests that follow this class
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
quant_method, quant_method_kwargs = None, None
device = "cuda"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_model(self, device=None):
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
quantized_model = FluxTransformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
@@ -503,21 +525,23 @@ class TorchAoSerializationTest(unittest.TestCase):
"timestep": timestep,
}
def test_original_model_expected_slice(self):
quantized_model = self.get_dummy_model(torch_device)
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def check_serialization_expected_slice(self, expected_slice):
quantized_model = self.get_dummy_model(self.device)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False
)
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
).to(device=torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = loaded_quantized_model(**inputs)[0]
@@ -530,42 +554,39 @@ class TorchAoSerializationTest(unittest.TestCase):
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_serialization_expected_slice(self):
self.check_serialization_expected_slice(self.serialized_expected_slice)
def test_int_a8w8_cuda(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cuda"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
def test_int_a16w8_cuda(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cuda"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"
def test_int_a8w8_cpu(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"
class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"
class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"
def test_int_a16w8_cpu(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -574,18 +595,25 @@ class SlowTorchAoTests(unittest.TestCase):
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
cache_dir = None
model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
@@ -617,13 +645,15 @@ class SlowTorchAoTests(unittest.TestCase):
def _test_quant_type(self, quantization_config, expected_slice):
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()
output_slice = np.concatenate((output[:16], output[-16:]))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
@@ -636,7 +666,7 @@ class SlowTorchAoTests(unittest.TestCase):
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on
@@ -646,3 +676,125 @@ class SlowTorchAoTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
weight = pipe.transformer.x_embedder.weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()[:128]
with tempfile.TemporaryDirectory() as tmp_dir:
pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
pipe.remove_all_hooks()
del pipe.transformer
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
transformer = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
)
pipe.transformer = transformer
pipe.enable_model_cpu_offload()
weight = transformer.x_embedder.weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
loaded_output = pipe(**inputs)[0].flatten()[:128]
# Seems to require higher tolerance depending on which machine it is being run.
# A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
# 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
# on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
def test_memory_footprint_int4wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 6.0
quantization_config = TorchAoConfig("int4wo")
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
def test_memory_footprint_int8wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 12.0
quantization_config = TorchAoConfig("int8wo")
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator().manual_seed(seed)
inputs = {
"prompt": "an astronaut riding a horse in space",
"height": 512,
"width": 512,
"num_inference_steps": 20,
"output_type": "np",
"generator": generator,
}
return inputs
def test_transformer_int8wo(self):
# fmt: off
expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
# fmt: on
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
torch_dtype=torch.bfloat16,
use_safetensors=False,
cache_dir=cache_dir,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
pipe.enable_model_cpu_offload()
# Verify that all linear layer weights are quantized
for name, module in pipe.transformer.named_modules():
if isinstance(module, nn.Linear):
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
# Verify outputs match expected slice
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()
output_slice = np.concatenate((output[:16], output[-16:]))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))