Compare commits

...

36 Commits

Author SHA1 Message Date
DN6 560fb5f4d6 Release: v0.32.2 2025-01-15 18:16:32 +05:30
Dhruv Nair 8ab26ac9bf [Single File] Fix loading Flux Dev finetunes with Comfy Prefix (#10545)
* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-15 18:04:52 +05:30
Dhruv Nair 9f305e7ce2 [CI] Update HF Token on Fast GPU Model Tests (#10570)
update
2025-01-15 18:04:18 +05:30
Dhruv Nair 2c25bf5bef [CI] Update HF Token in Fast GPU Tests (#10568)
update
2025-01-15 18:04:05 +05:30
hlky 0e14cacffc Fix batch > 1 in HunyuanVideo (#10548) 2025-01-15 18:02:51 +05:30
hlky 13ea83f0fa Fix HunyuanVideo produces NaN on PyTorch<2.5 (#10482)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-15 18:01:57 +05:30
Aryan 2b432ac5a8 Fix hunyuan video attention mask dim (#10454)
* fix

* add coauthor

Co-Authored-By: Nerogar <nerogar@arcor.de>

---------

Co-authored-by: Nerogar <nerogar@arcor.de>
2025-01-15 17:58:25 +05:30
Sayak Paul 263b973466 [LoRA] feat: support loading loras into 4bit quantized Flux models. (#10578)
* feat: support loading loras into 4bit quantized models.

* updates

* update

* remove weight check.
2025-01-15 17:56:14 +05:30
Sayak Paul a663a67ea2 [LoRA] clean up load_lora_into_text_encoder() and fuse_lora() copied from (#10495)
* factor out text encoder loading.

* make fix-copies

* remove copied from fuse_lora and unfuse_lora as needed.

* remove unused imports
2025-01-15 17:55:47 +05:30
Aryan 526858c801 [LoRA] Support original format loras for HunyuanVideo (#10376)
* update

* fix make copies

* update

* add relevant markers to the integration test suite.

* add copied.

* fox-copies

* temporarily add print.

* directly place on CUDA as CPU isn't that big on the CIO.

* fixes to fuse_lora, aryan was right.

* fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-15 17:55:40 +05:30
Sayak Paul b3d2dd36d7 [LoRA] fix: lora unloading when using expanded Flux LoRAs. (#10397)
* fix: lora unloading when using expanded Flux LoRAs.

* fix argument name.

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* docs.

---------

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
2025-01-15 17:55:19 +05:30
maxs-kan abfa922410 Fix Flux multiple Lora loading bug (#10388)
* check for base_layer key in transformer state dict

* test_lora_expansion_works_for_absent_keys

* check

* Update tests/lora/test_lora_layers_flux.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* check

* test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys

* absent->extra

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-15 17:55:05 +05:30
Sayak Paul 6a7b01f60f [LoRA] feat: support unload_lora_weights() for Flux Control. (#10206)
* feat: support unload_lora_weights() for Flux Control.

* tighten test

* minor

* updates

* meta device fixes.
2025-01-15 17:54:56 +05:30
Aryan e8aacda762 Release: v0.32.1 2024-12-25 11:34:06 +01:00
Aryan 12184f4015 Fix TorchAO related bugs; revert device_map changes (#10371)
* Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)"

This reverts commit 41ba8c0bf6.

* update tests

* udpate

* update

* update

* update device map tests

* apply review suggestions

* update

* make style

* fix

* update docs

* update tests

* update workflow

* update

* improve tests

* allclose tolerance

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* improve tests

* fix

* update correct slices

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-25 11:31:21 +01:00
Sayak Paul 6e1d2da194 fix test pypi installation in the release workflow (#10360)
fix
2024-12-25 11:25:34 +01:00
YiYi Xu 11b1151840 make style for https://github.com/huggingface/diffusers/pull/10368 (#10370)
* fix bug for torch.uint1-7 not support in torch<2.6

* up

---------

Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn>
2024-12-25 11:25:14 +01:00
sayakpaul cd4d0d8ffb Release: v0.32.0 2024-12-23 20:26:28 +05:30
Aryan 4b557132ce [core] LTX Video 0.9.1 (#10330)
* update

* make style

* update

* update

* update

* make style

* single file related changes

* update

* fix

* update single file urls and docs

* update

* fix
2024-12-23 19:51:33 +05:30
Sayak Paul 851dfa30ae [Tests] Fix more tests sayak (#10359)
* fixes to tests

* fixture

* fixes
2024-12-23 19:11:21 +05:30
Sayak Paul ea1ba0ba53 [LoRA] test fix (#10351)
updates
2024-12-23 15:45:45 +05:30
Aryan 9d27df8071 Rename LTX blocks and docs title (#10213)
* rename blocks and docs

* fix docs

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-12-23 15:29:10 +05:30
Aryan 055d95543a Fix failing CogVideoX LoRA fuse test (#10352)
fix
2024-12-23 14:22:09 +05:30
hlky 71cc2013fe Fix FluxIPAdapterTesterMixin (#10354) 2024-12-23 14:20:06 +05:30
Sayak Paul c34fc34563 [Tests] QoL improvements to the LoRA test suite (#10304)
* misc lora test improvements.

* updates

* fixes to tests
2024-12-23 13:59:55 +05:30
Dhruv Nair 5fcee4a447 [Single File] Fix loading (#10349)
update
2024-12-23 13:12:23 +05:30
Sayak Paul 76e2727b5c [SANA LoRA] sana lora training tests and misc. (#10296)
* sana lora training tests and misc.

* remove push to hub

* Update examples/dreambooth/train_dreambooth_lora_sana.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-23 12:35:13 +05:30
Aryan 02c777c065 [tests] Refactor TorchAO serialization fast tests (#10271)
refactor
2024-12-23 11:04:57 +05:30
Sayak Paul 6a970a45c5 [docs] fix: torchao example. (#10278)
fix: torchao example.
2024-12-23 11:03:50 +05:30
Aryan ffc0eaab6d Bump minimum TorchAO version to 0.7.0 (#10293)
* bump min torchao version to 0.7.0

* update
2024-12-23 11:03:04 +05:30
Thien Tran 3c2e2aa8a9 .from_single_file() - Add missing .shape (#10332)
Add missing `.shape`
2024-12-23 08:57:25 +05:30
Junsong Chen b58868e6f4 [Sana bug] bug fix for 2K model config (#10340)
* fix the Positinoal Embedding bug in 2K model;

* Change the default model to the BF16 one for more stable training and output

* make style

* substract buffer size

* add compute_module_persistent_sizes

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-12-23 08:56:25 +05:30
Dhruv Nair da21d590b5 [Single File] Add Single File support for HunYuan video (#10320)
* update

* Update src/diffusers/loaders/single_file_utils.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-23 08:44:58 +05:30
YiYi Xu 7c2f0afb1c update get_parameter_dtype (#10342)
add:
q
2024-12-23 08:14:13 +05:30
hlky f615f00f58 Fix enable_sequential_cpu_offload in test_kandinsky_combined (#10324)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-22 15:28:28 -10:00
Aryan 6aaa0518e3 Community hosted weights for diffusers format HunyuanVideo weights (#10344)
update docs and example to use community weights
2024-12-22 15:26:28 -10:00
107 changed files with 2817 additions and 1376 deletions
+2
View File
@@ -359,6 +359,8 @@ jobs:
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
runs-on:
group: aws-g6e-xlarge-plus
container:
+2 -2
View File
@@ -83,7 +83,7 @@ jobs:
python utils/print_env.py
- name: PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -137,7 +137,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
+1 -1
View File
@@ -68,7 +68,7 @@ jobs:
- name: Test installing diffusers and importing
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://testpypi.python.org/pypi diffusers
pip install -i https://test.pypi.org/simple/ diffusers
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
+1 -1
View File
@@ -429,7 +429,7 @@
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
title: LTXVideo
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanVideo
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
```
## AutoencoderKLHunyuanVideo
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTXVideo
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanVideoTransformer3DModel
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanVideoTransformer3DModel
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTXVideoTransformer3DModel
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## SanaTransformer2DModel
+4
View File
@@ -305,6 +305,10 @@ image = control_pipe(
image.save("output.png")
```
## Note about `unload_lora_weights()` when using Flux LoRAs
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
@@ -29,7 +29,7 @@ Recommendations for inference:
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
## HunyuanVideoPipeline
+40 -2
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX
# LTX Video
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
Available models:
| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
single_file_url, torch_dtype=torch.bfloat16
@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
<!-- TODO(aryan): Update this when official weights are supported -->
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=768,
height=512,
num_frames=161,
decode_timestep=0.03,
decode_noise_scale=0.025,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
## LTXPipeline
+1 -1
View File
@@ -32,9 +32,9 @@ Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
+66 -2
View File
@@ -25,9 +25,10 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/Flux.1-Dev"
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
@@ -44,8 +45,14 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")
# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```
@@ -86,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
## Serializing and Deserializing quantized models
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
class MarigoldDepthOutput(BaseOutput):
"""
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
def test_dreambooth_lora_sana(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# `self.transformer_layer_type` should be in the state dict.
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 166
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()
resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -943,7 +943,7 @@ def main(args):
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@ def main(args):
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
@@ -993,6 +984,15 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
@@ -1182,6 +1182,7 @@ def main(args):
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
prompt_embeds = prompt_embeds.to(transformer.dtype)
return prompt_embeds, prompt_attention_mask
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
@@ -1216,7 +1217,7 @@ def main(args):
vae_config_scaling_factor = vae.config.scaling_factor
if args.cache_latents:
latents_cache = []
vae = vae.to("cuda")
vae = vae.to(accelerator.device)
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -54,7 +54,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.32.0")
logger = get_logger(__name__, log_level="INFO")
+99 -11
View File
@@ -1,7 +1,9 @@
import argparse
from pathlib import Path
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"vae": remove_keys_,
}
VAE_KEYS_RENAME_DICT = {
# decoder
@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
@@ -80,13 +105,16 @@ def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -97,16 +125,21 @@ def convert_transformer(
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(ckpt_path: str, dtype: torch.dtype):
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "vae."
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
with init_empty_weights():
vae = AutoencoderKLLTXVideo(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_vae_config(version: str) -> Dict[str, Any]:
if version == "0.9.0":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"timestep_conditioning": False,
}
elif version == "0.9.1":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -139,6 +222,9 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
)
return parser.parse_args()
@@ -161,6 +247,7 @@ if __name__ == "__main__":
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
@@ -169,13 +256,14 @@ if __name__ == "__main__":
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
config = get_vae_config(args.version)
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
+6
View File
@@ -88,13 +88,18 @@ def main(args):
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
if is_accelerate_available():
+1 -1
View File
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.32.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.32.0.dev0"
__version__ = "0.32.2"
from typing import TYPE_CHECKING
+156 -21
View File
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
@@ -43,6 +50,8 @@ from ..utils import (
if is_transformers_available():
from transformers import PreTrainedModel
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
return weight_name
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
@@ -327,27 +482,7 @@ class LoraBaseMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
@@ -973,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
linear1_weight = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
linear1_bias = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
# sure that both follow the same initial format by stripping off the "transformer." prefix.
for key in list(converted_state_dict.keys()):
if key.startswith("transformer."):
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
if key.startswith("diffusion_model."):
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
# Rename and remap the state dict keys
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
# Add back the "transformer." prefix
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
File diff suppressed because it is too large Load Diff
+2 -27
View File
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
import safetensors
import torch
import torch.nn as nn
from ..utils import (
MIN_PEFT_VERSION,
@@ -30,20 +29,16 @@ from ..utils import (
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
@@ -140,27 +135,7 @@ class PeftAdapterMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
+7 -1
View File
@@ -28,6 +28,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
@@ -101,6 +102,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -220,6 +225,7 @@ class FromOriginalModelMixin:
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
@@ -297,7 +303,7 @@ class FromOriginalModelMixin:
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=revision,
revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+168 -5
View File
@@ -108,6 +108,7 @@ CHECKPOINT_KEY_NAMES = {
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -156,12 +157,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
}
# Use to configure model sample size when original config is provided
@@ -592,10 +595,14 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
if "model.diffusion_model.img_in.weight" in checkpoint:
key = "model.diffusion_model.img_in.weight"
else:
key = "img_in.weight"
elif checkpoint["img_in.weight"].shape[1] == 128:
if checkpoint[key].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint[key].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
@@ -603,7 +610,10 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
@@ -624,6 +634,9 @@ def infer_diffusers_model_type(checkpoint):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
model_type = "mochi-1-preview"
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
else:
model_type = "v1"
@@ -2333,12 +2346,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
@@ -2522,3 +2555,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
return new_state_dict
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
def update_state_dict_(state_dict, old_key, new_key):
state_dict[new_key] = state_dict.pop(old_key)
for key in list(checkpoint.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(checkpoint, key, new_key)
for key in list(checkpoint.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, checkpoint)
return checkpoint
@@ -177,3 +177,5 @@ class FluxTransformer2DLoadersMixin:
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
+2 -25
View File
@@ -21,7 +21,6 @@ import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import (
ImageProjection,
@@ -44,13 +43,11 @@ from ..utils import (
is_torch_version,
logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
@@ -400,27 +397,7 @@ class UNet2DConditionLoadersMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
self,
@@ -22,13 +22,14 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
class LTXCausalConv3d(nn.Module):
class LTXVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -79,9 +80,9 @@ class LTXCausalConv3d(nn.Module):
return hidden_states
class LTXResnetBlock3d(nn.Module):
class LTXVideoResnetBlock3d(nn.Module):
r"""
A 3D ResNet block used in the LTX model.
A 3D ResNet block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -109,7 +110,9 @@ class LTXResnetBlock3d(nn.Module):
elementwise_affine: bool = False,
non_linearity: str = "swish",
is_causal: bool = True,
):
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
@@ -117,13 +120,13 @@ class LTXResnetBlock3d(nn.Module):
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.conv1 = LTXCausalConv3d(
self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
self.conv2 = LTXCausalConv3d(
self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
@@ -131,22 +134,58 @@ class LTXResnetBlock3d(nn.Module):
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
self.conv_shortcut = LTXCausalConv3d(
self.conv_shortcut = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
self.per_channel_scale1 = None
self.per_channel_scale2 = None
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.scale_shift_table = None
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
def forward(
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
hidden_states = inputs
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale_1) + shift_1
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.per_channel_scale1 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
hidden_states = hidden_states * (1 + scale_2) + shift_2
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.per_channel_scale2 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
if self.norm3 is not None:
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
@@ -157,20 +196,24 @@ class LTXResnetBlock3d(nn.Module):
return hidden_states
class LTXUpsampler3d(nn.Module):
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.residual = residual
self.upscale_factor = upscale_factor
out_channels = in_channels * stride[0] * stride[1] * stride[2]
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTXCausalConv3d(
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
@@ -181,6 +224,15 @@ class LTXUpsampler3d(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.residual:
residual = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
)
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
residual = residual.repeat(1, repeats, 1, 1, 1)
residual = residual[:, :, self.stride[0] - 1 :]
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
@@ -188,12 +240,15 @@ class LTXUpsampler3d(nn.Module):
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
if self.residual:
hidden_states = hidden_states + residual
return hidden_states
class LTXDownBlock3D(nn.Module):
class LTXVideoDownBlock3D(nn.Module):
r"""
Down block used in the LTX model.
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -235,7 +290,7 @@ class LTXDownBlock3D(nn.Module):
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -250,7 +305,7 @@ class LTXDownBlock3D(nn.Module):
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
LTXCausalConv3d(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
@@ -262,7 +317,7 @@ class LTXDownBlock3D(nn.Module):
self.conv_out = None
if in_channels != out_channels:
self.conv_out = LTXResnetBlock3d(
self.conv_out = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -273,7 +328,12 @@ class LTXDownBlock3D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
@@ -285,24 +345,26 @@ class LTXDownBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.conv_out is not None:
hidden_states = self.conv_out(hidden_states)
hidden_states = self.conv_out(hidden_states, temb, generator)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXMidBlock3d(nn.Module):
class LTXVideoMidBlock3d(nn.Module):
r"""
A middle block used in the LTX model.
A middle block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -329,28 +391,51 @@ class LTXMidBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXMidBlock3D` class."""
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -360,16 +445,18 @@ class LTXMidBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
class LTXUpBlock3d(nn.Module):
class LTXVideoUpBlock3d(nn.Module):
r"""
Up block used in the LTX model.
Up block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -403,45 +490,82 @@ class LTXUpBlock3d(nn.Module):
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
upscale_factor: int = 1,
):
super().__init__()
out_channels = out_channels or in_channels
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
self.conv_in = None
if in_channels != out_channels:
self.conv_in = LTXResnetBlock3d(
self.conv_in = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
stride=(2, 2, 2),
is_causal=is_causal,
residual=upsample_residual,
upscale_factor=upscale_factor,
)
]
)
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
if self.conv_in is not None:
hidden_states = self.conv_in(hidden_states)
hidden_states = self.conv_in(hidden_states, temb, generator)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -456,16 +580,18 @@ class LTXUpBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
class LTXEncoder3d(nn.Module):
class LTXVideoEncoder3d(nn.Module):
r"""
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
representation.
Args:
@@ -509,7 +635,7 @@ class LTXEncoder3d(nn.Module):
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
@@ -524,7 +650,7 @@ class LTXEncoder3d(nn.Module):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
down_block = LTXDownBlock3D(
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
@@ -536,7 +662,7 @@ class LTXEncoder3d(nn.Module):
self.down_blocks.append(down_block)
# mid block
self.mid_block = LTXMidBlock3d(
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
@@ -546,14 +672,14 @@ class LTXEncoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
r"""The forward method of the `LTXVideoEncoder3d` class."""
p = self.patch_size
p_t = self.patch_size_t
@@ -599,9 +725,10 @@ class LTXEncoder3d(nn.Module):
return hidden_states
class LTXDecoder3d(nn.Module):
class LTXVideoDecoder3d(nn.Module):
r"""
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, defaults to 128):
@@ -622,6 +749,8 @@ class LTXDecoder3d(nn.Module):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `False`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
timestep_conditioning (`bool`, defaults to `False`):
Whether to condition the model on timesteps.
"""
def __init__(
@@ -635,6 +764,10 @@ class LTXDecoder3d(nn.Module):
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: Tuple[bool, ...] = (False, False, False, False),
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
) -> None:
super().__init__()
@@ -645,30 +778,42 @@ class LTXDecoder3d(nn.Module):
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
layers_per_block = tuple(reversed(layers_per_block))
inject_noise = tuple(reversed(inject_noise))
upsample_residual = tuple(reversed(upsample_residual))
upsample_factor = tuple(reversed(upsample_factor))
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)
self.mid_block = LTXMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[0],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
inject_noise=inject_noise[0],
timestep_conditioning=timestep_conditioning,
)
# up blocks
num_block_out_channels = len(block_out_channels)
self.up_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i]
input_channel = output_channel // upsample_factor[i]
output_channel = block_out_channels[i] // upsample_factor[i]
up_block = LTXUpBlock3d(
up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
upscale_factor=upsample_factor[i],
)
self.up_blocks.append(up_block)
@@ -676,13 +821,20 @@ class LTXDecoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -693,17 +845,33 @@ class LTXDecoder3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb
)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
else:
hidden_states = self.mid_block(hidden_states)
hidden_states = self.mid_block(hidden_states, temb)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = up_block(hidden_states, temb)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
temb = temb + self.scale_shift_table[None, ..., None, None, None]
shift, scale = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
@@ -766,8 +934,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -777,7 +952,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
) -> None:
super().__init__()
self.encoder = LTXEncoder3d(
self.encoder = LTXVideoEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
@@ -788,16 +963,20 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
self.decoder = LTXDecoder3d(
self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=decoder_causal,
timestep_conditioning=timestep_conditioning,
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
)
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
@@ -837,7 +1016,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = 448
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value
def enable_tiling(
@@ -936,13 +1115,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def _decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
return self.tiled_decode(z, temb, return_dict=return_dict)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
@@ -952,7 +1133,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z)
dec = self.decoder(z, temb)
if not return_dict:
return (dec,)
@@ -960,7 +1141,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -975,10 +1158,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
if temb is not None:
decoded_slices = [
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
]
else:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
decoded = self._decode(z, temb).sample
if not return_dict:
return (decoded,)
@@ -1060,7 +1248,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1101,7 +1291,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
row.append(time)
rows.append(row)
@@ -1129,6 +1321,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
@@ -1139,7 +1332,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return dec
+1 -1
View File
@@ -748,10 +748,10 @@ class CogVideoXPatchEmbed(nn.Module):
pos_embedding = self._get_positional_embeddings(
height, width, pre_time_compression_frames, device=embeds.device
)
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
return embeds
+1 -1
View File
@@ -228,7 +228,7 @@ def load_model_dict_into_meta(
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if is_quantized and (
+35 -17
View File
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
return next(parameter.parameters()).dtype
except StopIteration:
try:
return next(parameter.buffers()).dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
if param.is_floating_point():
return param.dtype
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
for buffer in parameter.buffers():
last_dtype = buffer.dtype
if buffer.is_floating_point():
return buffer.dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
class ModelMixin(torch.nn.Module, PushToHubMixin):
@@ -700,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
if is_bnb_quantization_method and device_map is not None:
if device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -802,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None and is_bnb_quantization_method:
# TODO: https://github.com/huggingface/diffusers/issues/10013
if hf_quantizer is not None:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
@@ -242,6 +242,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
super().__init__()
@@ -249,14 +250,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
interpolation_scale=interpolation_scale,
)
# 2. Additional condition embeddings
@@ -18,6 +18,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -500,7 +502,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -711,14 +713,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N, N]
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LTXAttentionProcessor2_0:
class LTXVideoAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -92,7 +92,7 @@ class LTXAttentionProcessor2_0:
return hidden_states
class LTXRotaryPosEmbed(nn.Module):
class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
@@ -164,7 +164,7 @@ class LTXRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
class LTXTransformerBlock(nn.Module):
class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -208,7 +208,7 @@ class LTXTransformerBlock(nn.Module):
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -221,7 +221,7 @@ class LTXTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.rope = LTXRotaryPosEmbed(
self.rope = LTXVideoRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.transformer_blocks = nn.ModuleList(
[
LTXTransformerBlock(
LTXVideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
@@ -39,7 +39,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import export_to_video
>>> model_id = "tencent/HunyuanVideo"
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
@@ -193,15 +193,15 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -419,8 +419,8 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id=0):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -660,8 +660,8 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
+25 -1
View File
@@ -511,6 +511,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -563,6 +565,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -753,7 +759,25 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
@@ -571,6 +571,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -625,6 +627,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -849,7 +855,25 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaPAGPipeline
>>> pipe = SanaPAGPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
... pag_applied_layers=["transformer_blocks.8"],
... torch_dtype=torch.float32,
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaPipeline
>>> pipe = SanaPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
... )
>>> pipe.to("cuda")
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.float16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ..base import DiffusersQuantizer
@@ -35,21 +35,28 @@ if is_torch_available():
import torch
import torch.nn as nn
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
if is_torch_version(">=", "2.5"):
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
else:
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
)
if is_torchao_available():
from torchao.quantization import quantize_
@@ -93,6 +100,11 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False
@@ -120,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int"):
if quant_type.startswith("int") or quant_type.startswith("uint"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+1 -1
View File
@@ -100,7 +100,7 @@ from .import_utils import (
is_xformers_available,
requires_backends,
)
from .loading_utils import get_module_from_name, load_image, load_video
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
+12
View File
@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
module = new_module
tensor_name = splits[-1]
return module, tensor_name
def get_submodule_by_name(root_module, module_path: str):
current = root_module
parts = module_path.split(".")
for part in parts:
if part.isdigit():
idx = int(part)
current = current[idx] # e.g., for nn.ModuleList or nn.Sequential
else:
current = getattr(current, part)
return current
+2 -2
View File
@@ -490,11 +490,11 @@ def require_gguf_version_greater_or_equal(gguf_version):
return decorator
def require_torchao_version_greater(torchao_version):
def require_torchao_version_greater_or_equal(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) > version.parse(torchao_version)
) >= version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
+1 -41
View File
@@ -15,8 +15,6 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -29,16 +27,13 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -123,41 +118,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+245 -73
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
import sys
@@ -36,7 +37,6 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_peft_backend,
require_peft_version_greater,
require_torch_gpu,
slow,
torch_device,
@@ -163,6 +163,105 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
def test_lora_expansion_works_for_absent_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)
def test_lora_expansion_works_for_extra_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
# Load state dict with `x_embedder`.
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -331,7 +430,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -340,85 +440,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
for name, module in pipe.transformer.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser_lora_config.lora_bias = False
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")
denoiser_lora_config.lora_bias = True
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
# for now this is flux control lora specific but can be generalized later and added to ./utils.py
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.delete_adapters("adapter-1")
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_rank
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
pipe.transformer.delete_adapters("adapter-1")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
"single_transformer_blocks.0.attn.to_k": updated_alpha
}
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_lora_expanding_shape_with_normal_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
# load shape expanded LoRA (such as Control LoRA).
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
@@ -611,6 +658,131 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
components["transformer"] = transformer
pipe = FluxPipeline(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
control_image = inputs.pop("control_image")
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
control_pipe = self.pipeline_class(**components)
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
rank = 4
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
components["transformer"] = transformer
pipe = FluxPipeline(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
control_image = inputs.pop("control_image")
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
control_pipe = self.pipeline_class(**components)
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
rank = 4
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
self.assertTrue(
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
)
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
+72 -43
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import sys
import unittest
@@ -28,16 +29,18 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_peft_backend,
require_torch_gpu,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -144,46 +147,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
prompt=inputs["prompt"],
height=inputs["height"],
width=inputs["width"],
num_frames=inputs["num_frames"],
num_inference_steps=inputs["num_inference_steps"],
max_sequence_length=inputs["max_sequence_length"],
output_type="np",
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
@@ -226,3 +189,69 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@nightly
@require_torch_gpu
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
assertions to pass.
"""
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling()
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
out = self.pipeline(
prompt=prompt,
height=320,
width=512,
num_frames=9,
num_inference_steps=self.num_inference_steps,
output_type="np",
generator=torch.manual_seed(self.seed),
).frames[0]
out = out.flatten()
out_slice = np.concatenate((out[:8], out[-8:]))
# fmt: off
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
+12 -46
View File
@@ -15,8 +15,6 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -26,18 +24,12 @@ from diffusers import (
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -60,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
transformer_cls = LTXVideoTransformer3DModel
vae_kwargs = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
@@ -107,41 +108,6 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+1 -39
View File
@@ -15,24 +15,20 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -103,40 +99,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+10 -10
View File
@@ -16,7 +16,7 @@ import sys
import unittest
import torch
from transformers import Gemma2ForCausalLM, GemmaTokenizer
from transformers import Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
vae_cls = AutoencoderDC
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
@property
def output_shape(self):
@@ -105,34 +105,34 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Sana.")
@unittest.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+112 -2
View File
@@ -1528,7 +1528,7 @@ class PeftLoraLoaderMixinTests:
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
@@ -1568,7 +1568,7 @@ class PeftLoraLoaderMixinTests:
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe("test", num_inference_steps=2, output_type="np")[0]
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
@@ -1988,3 +1988,113 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# keep track of the bias values of the base layers to perform checks later.
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser_lora_config.lora_bias = False
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")
denoiser_lora_config.lora_bias = True
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, _ in denoiser.named_modules():
if "to_k" in name and "attn" in name and "lora" not in name:
module_name_to_rank_update = name.replace(".base_layer.", ".")
break
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
if self.unet_kwargs is not None:
pipe.unet.delete_adapters("adapter-1")
else:
pipe.transformer.delete_adapters("adapter-1")
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
@@ -0,0 +1,169 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AutoencoderKLLTXVideo
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
return {"sample": image, "temb": timestep}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
+11
View File
@@ -2,10 +2,12 @@ import tempfile
import unittest
import numpy as np
import pytest
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
from diffusers.utils.testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
@@ -79,6 +81,15 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
class DeprecatedAttentionBlockTests(unittest.TestCase):
@pytest.fixture(scope="session")
def is_dist_enabled(pytestconfig):
return pytestconfig.getoption("dist") == "loadfile"
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
strict=True,
)
def test_conversion_when_using_device_map(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+78 -10
View File
@@ -22,12 +22,14 @@ import traceback
import unittest
import unittest.mock as mock
import uuid
from typing import Dict, List, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
out_queue.join()
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
@@ -1012,7 +1080,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
@@ -1076,7 +1144,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
@@ -1104,7 +1172,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ class ModelTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1164,7 +1232,7 @@ class ModelTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ class ModelTesterMixin:
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model_size = compute_module_sizes(model)[""]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -30,6 +30,8 @@ class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = MochiTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
# Overriding it because of the transformer size.
model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):

Some files were not shown because too many files have changed in this diff Show More