Compare commits

...

48 Commits

Author SHA1 Message Date
Kashif Rasul ccfaf0b75f initial script copied from the dpo trainer 2025-02-11 16:19:26 +01:00
hlky 7fb481f840 Add Self type hint to ModelMixin's from_pretrained (#10742) 2025-02-10 09:17:57 -10:00
Sayak Paul 9f5ad1db41 [LoRA] fix peft state dict parsing (#10532)
* fix peft state dict parsing

* updates
2025-02-10 18:47:20 +05:30
hlky 464374fb87 EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734) 2025-02-07 06:53:52 +00:00
hlky d43ce14e2d Quantized Flux with IP-Adapter (#10728) 2025-02-06 07:02:36 -10:00
Leo Jiang cd0a4a82cf [bugfix] NPU Adaption for Sana (#10724)
* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* [bugfix]NPU Adaption for Sanna

---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-02-06 19:29:58 +05:30
suzukimain 145522cbb7 [Community] Enhanced Model Search (#10417)
* Added `auto_load_textual_inversion` and `auto_load_lora_weights`

* update README.md

* fix

* make quality

* Fix and `make style`
2025-02-05 14:43:53 -10:00
xieofxie 23bc56a02d add provider_options in from_pretrained (#10719)
Co-authored-by: hualxie <hualxie@microsoft.com>
2025-02-05 09:41:41 -10:00
SahilCarterr 5b1dcd1584 [Fix] Type Hint in from_pretrained() to Ensure Correct Type Inference (#10714)
* Update pipeline_utils.py

Added Self in from_pretrained method so  inference will correctly recognize pipeline

* Use typing_extensions

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-02-04 08:59:31 -10:00
Parag Ekbote dbe0094e86 Notebooks for Community Scripts-6 (#10713)
* Fix Doc Tutorial.

* Add 4 Notebooks and improve their example
scripts.
2025-02-04 10:12:17 -08:00
Nicolas f63d32233f Fix train_text_to_image.py --help (#10711) 2025-02-04 11:26:23 +05:30
Sayak Paul 5e8e6cb44f [bitsandbytes] Simplify bnb int8 dequant (#10401)
* fix dequantization for latest bnb.

* smol fixes.

* fix type annotation

* update peft link

* updates
2025-02-04 11:17:14 +05:30
Parag Ekbote 3e35f56b00 Fix Documentation about Image-to-Image Pipeline (#10704)
Fix Doc Tutorial.
2025-02-03 09:54:00 -08:00
Ikpreet S Babra 537891e693 Fixed grammar in "write_own_pipeline" readme (#10706) 2025-02-03 09:53:30 -08:00
Vedat Baday 9f28f1abba feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling (#10699)
* feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling

* chore: update type hint

* refactor: use union for type hint

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-02-01 23:04:05 +05:30
Thanh Le 5d2d23986e Fix inconsistent random transform in instruct pix2pix (#10698)
* Update train_instruct_pix2pix.py

Fix inconsistent random transform in instruct_pix2pix

* Update train_instruct_pix2pix_sdxl.py
2025-01-31 08:29:29 -10:00
Max Podkorytov 1ae9b0595f Fix enable memory efficient attention on ROCm (#10564)
* fix enable memory efficient attention on ROCm

while calling CK implementation

* Update attention_processor.py

refactor of picking a set element
2025-01-31 17:15:49 +05:30
SahilCarterr aad69ac2f3 [FIX] check_inputs function in Auraflow Pipeline (#10678)
fix_shape_error
2025-01-29 13:11:54 -10:00
Vedat Baday ea76880bd7 fix(hunyuan-video): typo in height and width input check (#10684) 2025-01-30 04:16:05 +05:30
Teriks 33f936154d support StableDiffusionAdapterPipeline.from_single_file (#10552)
* support StableDiffusionAdapterPipeline.from_single_file

* make style

---------

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-01-29 07:18:47 -10:00
Sayak Paul e6037e8275 [tests] update llamatokenizer in hunyuanvideo tests (#10681)
update llamatokenizer in hunyuanvideo tests
2025-01-29 21:12:57 +05:30
Dimitri Barbot 196aef5a6f Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)
Fix pipeline dtype unexpected change when using SDXL reference community pipelines
2025-01-28 10:46:41 -03:00
Sayak Paul 7b100ce589 [Tests] conditionally check fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory (#10669)
* conditionally check if compute capability is met.

* log info.

* fix condition.

* updates

* updates

* updates

* updates
2025-01-28 12:00:14 +05:30
Aryan c4d4ac21e7 Refactor gradient checkpointing (#10611)
* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
2025-01-28 06:51:46 +05:30
Hanch Han f295e2eefc [fix] refer use_framewise_encoding on AutoencoderKLHunyuanVideo._encode (#10600)
* fix: refer to use_framewise_encoding on AutoencoderKLHunyuanVideo._encode

* fix: comment about tile_sample_min_num_frames

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-28 06:51:27 +05:30
Aryan 658e24e86c [core] Pyramid Attention Broadcast (#9562)
* start pyramid attention broadcast

* add coauthor

Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------

Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-01-28 05:09:04 +05:30
Giuseppe Catalano fb42066489 Revert RePaint scheduler 'fix' (#10644)
Co-authored-by: Giuseppe Catalano <giuseppelorenzo.catalano@unito.it>
2025-01-27 11:16:45 -10:00
Teriks e89ab5bc26 SDXL ControlNet Union pipelines, make control_image argument immutible (#10663)
controlnet union XL, make control_image immutible

when this argument is passed a list, __call__
modifies its content, since it is pass by reference
the list passed by the caller gets its content
modified unexpectedly

make a copy at method intro so this does not happen

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
2025-01-27 10:53:30 -10:00
victolee0 8ceec90d76 fix check_inputs func in LuminaText2ImgPipeline (#10651) 2025-01-27 09:47:01 -10:00
hlky 158c5c4d08 Add provider_options to OnnxRuntimeModel (#10661) 2025-01-27 09:46:17 -10:00
hlky 41571773d9 [training] Convert to ImageFolder script (#10664)
* [training] Convert to ImageFolder script

* make
2025-01-27 09:43:51 -10:00
hlky 18f7d1d937 ControlNet Union controlnet_conditioning_scale for multiple control inputs (#10666) 2025-01-27 08:15:25 -10:00
Marlon May f7f36c7d3d Add community pipeline for semantic guidance for FLUX (#10610)
* add community pipeline for semantic guidance for flux

* fix imports in community pipeline for semantic guidance for flux

* Update examples/community/pipeline_flux_semantic_guidance.py

Co-authored-by: hlky <hlky@hlky.ac>

* fix community pipeline for semantic guidance for flux

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-01-27 16:19:46 +02:00
Yuqian Hong 4fa24591a3 create a script to train autoencoderkl (#10605)
* create a script to train vae

* update main.py

* update train_autoencoderkl.py

* update train_autoencoderkl.py

* add a check of --pretrained_model_name_or_path and --model_config_name_or_path

* remove the comment, remove diffusers in requiremnets.txt, add validation_image ote

* update autoencoderkl.py

* quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-27 16:41:34 +05:30
Jacob Helwig 4f3ec5364e Add sigmoid scheduler in scheduling_ddpm.py docs (#10648)
Sigmoid scheduler in scheduling_ddpm.py docs
2025-01-26 15:37:20 -08:00
Leo Jiang 07860f9916 NPU Adaption for Sanna (#10409)
* NPU Adaption for Sanna


---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-24 09:08:52 -10:00
Wenhao Sun 87252d80c3 Add pipeline_stable_diffusion_xl_attentive_eraser (#10579)
* add pipeline_stable_diffusion_xl_attentive_eraser

* add pipeline_stable_diffusion_xl_attentive_eraser_make_style

* make style and add example output

* update Docs

Co-authored-by: Other Contributor <a457435687@126.com>

* add Oral

Co-authored-by: Other Contributor <a457435687@126.com>

* update_review

Co-authored-by: Other Contributor <a457435687@126.com>

* update_review_ms

Co-authored-by: Other Contributor <a457435687@126.com>

---------

Co-authored-by: Other Contributor <a457435687@126.com>
2025-01-24 13:52:45 +00:00
Sayak Paul 5897137397 [chore] add a script to extract loras from full fine-tuned models (#10631)
* feat: add a lora extraction script.

* updates
2025-01-24 11:50:36 +05:30
Yaniv Galron a451c0ed14 removing redundant requires_grad = False (#10628)
We already set the unet to requires grad false at line 506

Co-authored-by: Aryan <aryan@huggingface.co>
2025-01-24 03:25:33 +05:30
hlky 37c9697f5b Add IP-Adapter example to Flux docs (#10633)
* Add IP-Adapter example to Flux docs

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-23 22:15:33 +05:30
Raul Ciotescu 9684c52adf width and height are mixed-up (#10629)
vars mixed-up
2025-01-23 06:40:22 -10:00
Steven Liu 5483162d12 [docs] uv installation (#10622)
* uv

* feedback
2025-01-23 08:34:51 -08:00
Sayak Paul d77c53b6d2 [docs] fix image path in para attention docs (#10632)
fix image path in para attention docs
2025-01-23 08:22:42 -08:00
Sayak Paul 78bc824729 [Tests] modify the test slices for the failing flax test (#10630)
* fixes

* fixes

* fixes

* updates
2025-01-23 12:10:24 +05:30
kahmed10 04d40920a7 add onnxruntime-migraphx as part of check for onnxruntime in import_utils.py (#10624)
add onnxruntime-migraphx to import_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-01-23 07:49:51 +05:30
Dhruv Nair 8d6f6d6b66 [CI] Update HF_TOKEN in all workflows (#10613)
update
2025-01-22 20:03:41 +05:30
Aryan ca60ad8e55 Improve TorchAO error message (#10627)
improve error message
2025-01-22 19:50:02 +05:30
Aryan beacaa5528 [core] Layerwise Upcasting (#10347)
* update

* update

* make style

* remove dynamo disable

* add coauthor

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-01-22 19:49:37 +05:30
172 changed files with 9646 additions and 2155 deletions
+3 -3
@@ -265,7 +265,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -505,7 +505,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
@@ -561,7 +561,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
+5 -5
@@ -187,7 +187,7 @@ jobs:
- name: Run Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -235,7 +235,7 @@ jobs:
- name: Run ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -283,7 +283,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -326,7 +326,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -372,7 +372,7 @@ jobs:
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
+8 -8
@@ -81,7 +81,7 @@ jobs:
python utils/print_env.py
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -135,7 +135,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -186,7 +186,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -241,7 +241,7 @@ jobs:
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -289,7 +289,7 @@ jobs:
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -337,7 +337,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -380,7 +380,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -426,7 +426,7 @@ jobs:
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
+2
@@ -598,6 +598,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities
+49
@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
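The skip logic described above can be illustrated with a toy, framework-free sketch. It is not the diffusers implementation: the function name `pab_schedule` and the exact reuse condition are assumptions made for illustration. It only shows how a block skip range and a timestep skip range could decide, step by step, whether one attention block recomputes its output or reuses the cached one.

```python
from typing import List, Tuple


def pab_schedule(
    timesteps: List[int],
    block_skip_range: int,
    timestep_skip_range: Tuple[int, int],
) -> List[str]:
    """Label each denoising step 'compute' or 'reuse' for a single attention block.

    Hypothetical helper: inside the (low, high) timestep interval, attention is
    recomputed only every `block_skip_range`-th step and the cached output is
    reused otherwise. Outside the interval, attention is always recomputed.
    """
    low, high = timestep_skip_range
    decisions = []
    has_cache = False
    iteration = 0  # number of consecutive steps spent inside the active interval
    for t in timesteps:
        active = low < t < high
        if active and has_cache and iteration % block_skip_range != 0:
            decisions.append("reuse")
        else:
            decisions.append("compute")
            has_cache = True
        iteration = iteration + 1 if active else 0
    return decisions


# Ten evenly spaced timesteps, from high noise to low noise (assumed values).
timesteps = list(range(999, -1, -100))
schedule = pab_schedule(timesteps, block_skip_range=2, timestep_skip_range=(100, 800))
narrow_schedule = pab_schedule(timesteps, block_skip_range=2, timestep_skip_range=(300, 600))

for t, decision in zip(timesteps, schedule):
    print(t, decision)
```

With the narrower `(300, 600)` interval, fewer steps reuse the cache, which matches the comment in the example above: shrinking the interval trades speed for quality.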
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
+47
@@ -309,6 +309,53 @@ image.save("output.png")
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
## IP-Adapter
<Tip>
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
</Tip>
An IP-Adapter lets you prompt Flux with images in addition to the text prompt. This is especially useful when you have reference images for complex concepts that are difficult to articulate through text alone.
```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))
pipe.load_ip_adapter(
"XLabs-AI/flux-ip-adapter",
weight_name="ip_adapter.safetensors",
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
)
pipe.set_ip_adapter_scale(1.0)
image = pipe(
width=1024,
height=1024,
prompt="wearing sunglasses",
negative_prompt="",
true_cfg=4.0,
generator=torch.Generator().manual_seed(4444),
ip_adapter_image=image,
).images[0]
image.save('flux_ip_adapter_output.jpg')
```
<div class="justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_output.jpg"/>
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "wearing sunglasses"</figcaption>
</div>
## Running FP16 inference
Flux can generate high-quality images with FP16 (for example, to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
+4
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## randn_tensor
[[autodoc]] utils.torch_utils.randn_tensor
## apply_layerwise_casting
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
+34 -6
@@ -23,32 +23,60 @@ You should install 🤗 Diffusers in a [virtual environment](https://docs.python
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
Start by creating a virtual environment in your project directory:
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
<hfoptions id="install">
<hfoption id="uv">
```bash
python -m venv .env
uv venv my-env
source my-env/bin/activate
```
Activate the virtual environment:
</hfoption>
<hfoption id="Python">
```bash
source .env/bin/activate
python -m venv my-env
source my-env/bin/activate
```
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
</hfoption>
</hfoptions>
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
<frameworkcontent>
<pt>
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
```bash
uv pip install diffusers["torch"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["torch"] transformers
```
</pt>
<jax>
Install Diffusers with uv.
```bash
uv pip install diffusers["flax"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["flax"] transformers
```
</jax>
</frameworkcontent>
+37
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
</Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video
model_id = "THUDM/CogVideoX-5b"
# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
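To make the "approximately half" figure above concrete, here is a small back-of-the-envelope calculation. The parameter count is an assumed round number for illustration, not a measured value for any specific checkpoint, and `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params: int, bytes_per_param: int) -> float:
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bytes_per_param / 1024**3


num_params = 5_000_000_000  # assumption: a model with roughly 5B parameters

bf16_gib = weight_memory_gib(num_params, bytes_per_param=2)  # torch.bfloat16 / torch.float16
fp8_gib = weight_memory_gib(num_params, bytes_per_param=1)   # torch.float8_e4m3fn / torch.float8_e5m2

print(f"bf16 weights: {bf16_gib:.2f} GiB")
print(f"fp8 weights:  {fp8_gib:.2f} GiB")
print(f"ratio:        {fp8_gib / bf16_gib:.2f}")
```

In practice the saving is somewhat less than half, because the layers skipped by default stay in the higher-precision dtype and activation memory is unaffected by the storage dtype.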
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it and see if it works for your model.
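As a framework-free illustration of what "channels become the densest dimension" means, the sketch below computes element strides for both layouts by hand. The helper names are hypothetical; the formulas are the standard strides for a contiguous NCHW tensor and for the same tensor stored channels-last.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # (N, C, H, W)


def contiguous_strides(shape: Shape) -> Tuple[int, int, int, int]:
    """Strides (in elements) for the default NCHW layout."""
    _, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape: Shape) -> Tuple[int, int, int, int]:
    """Strides (in elements) when the same NCHW tensor is stored as NHWC in memory."""
    _, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def offset(index: Shape, strides: Tuple[int, int, int, int]) -> int:
    """Position of one element in the flat memory buffer."""
    return sum(i * s for i, s in zip(index, strides))


shape = (1, 3, 2, 2)  # one tiny 2x2 RGB image
nchw = contiguous_strides(shape)
nhwc = channels_last_strides(shape)

# Memory offsets of the three channel values of the top-left pixel.
nchw_offsets = [offset((0, c, 0, 0), nchw) for c in range(3)]
nhwc_offsets = [offset((0, c, 0, 0), nhwc) for c in range(3)]

print("contiguous strides:   ", nchw, "pixel channel offsets:", nchw_offsets)
print("channels-last strides:", nhwc, "pixel channel offsets:", nhwc_offsets)
```

In the channels-last layout the three channel values of a pixel sit next to each other in memory, whereas in the default layout they are a full image plane apart. In PyTorch, the same strides can be inspected with `tensor.stride()` after converting with `tensor.to(memory_format=torch.channels_last)`.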
+1 -1
@@ -29,7 +29,7 @@ However, it is hard to decide when to reuse the cache to ensure quality generate
This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
<figure>
<img src="https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png" alt="Cache in Diffusion Transformer" />
<figcaption>How AdaCache works, First Block Cache is a variant of it</figcaption>
</figure>
+2 -2
@@ -461,12 +461,12 @@ Chain it to an upscaler pipeline to increase the image resolution:
from diffusers import StableDiffusionLatentUpscalePipeline
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, use_safetensors=True
)
upscaler.enable_model_cpu_offload()
upscaler.enable_xformers_memory_efficient_attention()
image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
image_2 = upscaler(prompt, image=image_1).images[0]
```
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
@@ -106,7 +106,7 @@ Let's try it out!
## Deconstruct the Stable Diffusion pipeline
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder to convert the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
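To give a rough sense of why working in latent space is more memory efficient, the following arithmetic compares the number of values in pixel space and in latent space. The figures are assumptions for illustration: a 512x512 RGB image, a VAE that downsamples each spatial dimension by a factor of 8, and 4 latent channels, which is a common configuration for Stable Diffusion v1 checkpoints.

```python
# Assumed configuration (illustrative, not read from any checkpoint).
image_height, image_width, image_channels = 512, 512, 3
vae_scale_factor = 8
latent_channels = 4

pixel_values = image_channels * image_height * image_width
latent_values = latent_channels * (image_height // vae_scale_factor) * (image_width // vae_scale_factor)
compression = pixel_values / latent_values

print(f"pixel space:  {pixel_values} values")
print(f"latent space: {latent_values} values")
print(f"the UNet operates on {compression:.0f}x fewer values per image")
```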
Executable → Regular
+191 -30
@@ -24,8 +24,8 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
@@ -37,7 +37,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
@@ -57,7 +57,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/sde_drag.ipynb) | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline.from_pretrained`, set to the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
@@ -947,10 +948,15 @@ image.save('./imagic/imagic_image_alpha_2.png')
Test seed resizing. First, generate an image at 512 by 512, then generate an image with the same seed at 512 by 592 using seed resizing. Finally, generate an image at 512 by 592 using the original Stable Diffusion pipeline for comparison.
```python
import os
import torch as th
import numpy as np
from diffusers import DiffusionPipeline
# Ensure the save directory exists or create it
save_dir = './seed_resize/'
os.makedirs(save_dir, exist_ok=True)
has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')
@@ -964,7 +970,6 @@ def dummy(images, **kwargs):
pipe.safety_checker = dummy
images = []
th.manual_seed(0)
generator = th.Generator("cuda").manual_seed(0)
@@ -983,15 +988,14 @@ res = pipe(
width=width,
generator=generator)
image = res.images[0]
image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
th.manual_seed(0)
generator = th.Generator("cuda").manual_seed(0)
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
custom_pipeline="seed_resize_stable_diffusion"
).to(device)
width = 512
@@ -1005,11 +1009,11 @@ res = pipe(
width=width,
generator=generator)
image = res.images[0]
image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
pipe_compare = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
custom_pipeline="seed_resize_stable_diffusion"
).to(device)
res = pipe_compare(
@@ -1022,7 +1026,7 @@ res = pipe_compare(
)
image = res.images[0]
image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height)))
```
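The save-path handling used in the example above can be isolated into a small helper. This is only a sketch: `output_path` is a hypothetical name that is not part of the pipeline, and a temporary directory is used here so that nothing is left on disk.

```python
import os
import tempfile


def output_path(save_dir: str, width: int, height: int, suffix: str = "image") -> str:
    """Create save_dir if needed and return the file path for one result."""
    os.makedirs(save_dir, exist_ok=True)  # no error if the directory already exists
    filename = "seed_resize_{w}_{h}_{s}.png".format(w=width, h=height, s=suffix)
    return os.path.join(save_dir, filename)


# Demonstrate inside a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    save_dir = os.path.join(tmp, "seed_resize")
    first = output_path(save_dir, 512, 512)
    second = output_path(save_dir, 512, 592, suffix="image_compare")
    dir_created = os.path.isdir(save_dir)
    print(os.path.basename(first), os.path.basename(second))
```

Calling `os.makedirs(..., exist_ok=True)` before every save is cheap and means a missing directory does not make `image.save` fail.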
### Multilingual Stable Diffusion Pipeline
@@ -1542,6 +1546,8 @@ This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import requests
from io import BytesIO
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
@@ -1553,13 +1559,25 @@ pipe = DiffusionPipeline.from_pretrained(
)
pipe.to(device)
images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
# List of image URLs
image_urls = [
'https://camo.githubusercontent.com/ef13c8059b12947c0d5e8d3ea88900de6bf1cd76bbf61ace3928e824c491290e/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f7374617272795f6e696768742e6a7067',
'https://camo.githubusercontent.com/d1947ab7c49ae3f550c28409d5e8b120df48e456559cf4557306c0848337702c/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f666c6f776572732e6a7067'
]
# Open images from URLs
images = []
for url in image_urls:
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    images.append(img)
# For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
generator = torch.Generator(device=device).manual_seed(42)
output = pipe(image=images, steps=6, generator=generator)
for i,image in enumerate(output.images):
for i, image in enumerate(output.images):
image.save('starry_to_flowers_%s.jpg' % i)
```
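The pipeline produces its intermediate frames by interpolating between the two image embeddings. Spherical linear interpolation (slerp) is one common way to interpolate between embedding vectors; whether the pipeline uses exactly this formula internally is an assumption here. The sketch below is a plain-Python illustration on 2D vectors, and `slerp` is a hypothetical helper, not an API of the pipeline.

```python
import math
from typing import List


def slerp(v0: List[float], v1: List[float], t: float) -> List[float]:
    """Spherical linear interpolation between v0 and v1 at fraction t in [0, 1]."""
    norm0 = math.sqrt(sum(x * x for x in v0))
    norm1 = math.sqrt(sum(x * x for x in v1))
    dot = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    dot = max(-1.0, min(1.0, dot))  # guard against rounding outside [-1, 1]
    theta = math.acos(dot)
    if theta < 1e-6:
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    s0 = math.sin((1 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


# Six frames, mirroring steps=6 in the example above.
start, end = [1.0, 0.0], [0.0, 1.0]
steps = 6
frames = [slerp(start, end, i / (steps - 1)) for i in range(steps)]
```

For unit-length inputs every interpolated frame also has unit length, which is the main reason to prefer slerp over a straight linear blend. Exactly opposite vectors are not handled in this sketch.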
@@ -3908,33 +3926,89 @@ This pipeline provides drag-and-drop image editing using stochastic differential
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
```py
import PIL
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image
import requests
from io import BytesIO
import numpy as np
# Load the pipeline
model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
pipe.to('cuda')
# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
# If not training LoRA, please avoid using torch.float16
# pipe.to(torch.float16)
# Ensure the model is moved to the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)
# Provide prompt, image, mask image, and the starting and target points for drag editing.
prompt = "prompt of the image"
image = PIL.Image.open('/path/to/image')
mask_image = PIL.Image.open('/path/to/mask_image')
source_points = [[123, 456]]
target_points = [[234, 567]]
# Function to load image from URL
def load_image_from_url(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")
# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
pipe.train_lora(prompt, image)
# Function to prepare mask
def prepare_mask(mask_image):
    # Convert to grayscale
    mask = mask_image.convert("L")
    return mask
output = pipe(prompt, image, mask_image, source_points, target_points)
output_image = PIL.Image.fromarray(output)
# Function to convert numpy array to PIL Image
def array_to_pil(array):
    # Ensure the array is in uint8 format
    if array.dtype != np.uint8:
        if array.max() <= 1.0:
            array = (array * 255).astype(np.uint8)
        else:
            array = array.astype(np.uint8)
    # Handle different array shapes
    if len(array.shape) == 3:
        if array.shape[0] == 3:  # If channels first
            array = array.transpose(1, 2, 0)
        return Image.fromarray(array)
    elif len(array.shape) == 4:  # If batch dimension
        array = array[0]
        if array.shape[0] == 3:  # If channels first
            array = array.transpose(1, 2, 0)
        return Image.fromarray(array)
    else:
        raise ValueError(f"Unexpected array shape: {array.shape}")
# Image and mask URLs
image_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'
# Load the images
image = load_image_from_url(image_url)
mask_image = load_image_from_url(mask_url)
# Resize images to a size that's compatible with the model's latent space
image = image.resize((512, 512))
mask_image = mask_image.resize((512, 512))
# Prepare the mask (keep as PIL Image)
mask = prepare_mask(mask_image)
# Provide the prompt and points for drag editing
prompt = "A cute dog"
source_points = [[32, 32]] # Adjusted for 512x512 image
target_points = [[64, 64]] # Adjusted for 512x512 image
# Generate the output image
output_array = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask,
    source_points=source_points,
    target_points=target_points
)
# Convert output array to PIL Image and save
output_image = array_to_pil(output_array)
output_image.save("./output.png")
print("Output image saved as './output.png'")
```
### Instaflow Pipeline
@@ -4585,8 +4659,8 @@ image = pipe(
```
| ![Gradient](https://github.com/user-attachments/assets/e38ce4d5-1ae6-4df0-ab43-adc1b45716b5) | ![Input](https://github.com/user-attachments/assets/9c95679c-e9d7-4f5a-90d6-560203acd6b3) | ![Output](https://github.com/user-attachments/assets/5313ff64-a0c4-418b-8b55-a38f1a5e7532) |
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
| Gradient | Input | Output |
| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
| Gradient | Input | Output |
A Colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth maps have also been added to the same notebook.
@@ -4634,6 +4708,93 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
### Stable Diffusion XL Attentive Eraser Pipeline
<img src="https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png" width="600" />
**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model's self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. For more details, or if you have questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
#### Key features
- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
#### Usage example
To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
```py
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import load_image
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, gaussian_blur
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
    scheduler=scheduler,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)
def preprocess_image(image_path, device):
    image = to_tensor(load_image(image_path))
    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
    if image.shape[1] != 3:
        image = image.expand(-1, 3, -1, -1)
    image = F.interpolate(image, (1024, 1024))
    image = image.to(dtype).to(device)
    return image

def preprocess_mask(mask_path, device):
    mask = to_tensor(load_image(mask_path, convert_method=lambda img: img.convert('L')))
    mask = mask.unsqueeze_(0).float()  # 0 or 1
    mask = F.interpolate(mask, (1024, 1024))
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask[mask < 0.1] = 0
    mask[mask >= 0.1] = 1
    mask = mask.to(dtype).to(device)
    return mask
prompt = ""  # Use an empty prompt
seed = 123
generator = torch.Generator(device=device).manual_seed(seed)
source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
source_image = preprocess_image(source_image_path, device)
mask = preprocess_mask(mask_path, device)
image = pipeline(
    prompt=prompt,
    image=source_image,
    mask_image=mask,
    height=1024,
    width=1024,
    AAS=True,  # enable AAS
    strength=0.8,  # inpainting strength
    rm_guidance_scale=9,  # removal guidance scale
    ss_steps=9,  # similarity suppression steps
    ss_scale=0.3,  # similarity suppression scale
    AAS_start_step=0,  # AAS start step
    AAS_start_layer=34,  # AAS start layer
    AAS_end_layer=70,  # AAS end layer
    num_inference_steps=50,  # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
    generator=generator,
    guidance_scale=1,
).images[0]
image.save('./removed_img.png')
print("Object removal completed")
```
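The inline comment on `num_inference_steps` gives the AAS end step as `int(strength * num_inference_steps)`. The short sketch below only evaluates that formula for the values used in the example; `aas_end_step` is a hypothetical helper name, and how the pipeline applies the result internally is not shown here.

```python
def aas_end_step(strength: float, num_inference_steps: int) -> int:
    """Evaluate the formula from the inline comment: int(strength * num_inference_steps)."""
    return int(strength * num_inference_steps)


# Values taken from the usage example above.
end_step = aas_end_step(0.8, 50)
print(end_step)
```

With `strength=0.8` and 50 inference steps, the formula yields step 40.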
| Source Image | Mask | Output |
| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| ![Source Image](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png) | ![Mask](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png) | ![Output](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/AE_step40_layer34.png) |
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
+5 -74
@@ -80,7 +80,6 @@ from diffusers.utils import (
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -869,23 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1192,23 +1159,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
]
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -1365,19 +1312,8 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
# Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1385,7 +1321,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2724,10 +2659,6 @@ class MatryoshkaUNet2DConditionModel(
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(
@@ -139,7 +139,8 @@ def retrieve_timesteps(
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(
+26
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
## Dataset
We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
We need to create a file `metadata.jsonl` in the directory with our images:
```
{"file_name": "01.jpg", "prompt": "prompt 01"}
{"file_name": "02.jpg", "prompt": "prompt 02"}
```
If we have a directory with image-text pairs, e.g. `01.jpg` and `01.txt`, then `convert_to_imagefolder.py` can create `metadata.jsonl` for us.
```sh
python convert_to_imagefolder.py --path my_dataset/
```
We then pass `--dataset_name` and `--caption_column` to the training scripts.
```
--dataset_name=my_dataset/
--caption_column=prompt
```
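Before launching training, it can help to confirm that every row of `metadata.jsonl` actually contains the caption column you pass via `--caption_column`. The following is a minimal sketch using only the standard library; `missing_captions` is a hypothetical helper, not part of the training scripts, and the sample file is written to a temporary directory purely for illustration.

```python
import json
import pathlib
import tempfile


def missing_captions(metadata_path: pathlib.Path, caption_column: str = "prompt") -> list:
    """Return the file_name of every row that lacks a non-empty caption."""
    missing = []
    with metadata_path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            row = json.loads(line)
            if not row.get(caption_column):
                missing.append(row.get("file_name"))
    return missing


# Write a small sample file and check it.
with tempfile.TemporaryDirectory() as tmp:
    metadata = pathlib.Path(tmp) / "metadata.jsonl"
    metadata.write_text(
        '{"file_name": "01.jpg", "prompt": "prompt 01"}\n'
        '{"file_name": "02.jpg"}\n',
        encoding="utf-8",
    )
    result = missing_captions(metadata)
    print(result)
```

An empty result means every image has a caption under the chosen column name.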
@@ -0,0 +1,32 @@
import argparse
import json
import pathlib


parser = argparse.ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    required=True,
    help="Path to folder with image-text pairs.",
)
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
args = parser.parse_args()

path = pathlib.Path(args.path)
if not path.exists():
    raise RuntimeError(f"`--path` '{args.path}' does not exist.")

all_files = list(path.glob("*"))
captions = list(path.glob("*.txt"))
images = set(all_files) - set(captions)
images = {image.stem: image for image in images}
caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}

metadata = path.joinpath("metadata.jsonl")

with metadata.open("w", encoding="utf-8") as f:
    for caption, image in caption_image.items():
        caption_text = caption.read_text(encoding="utf-8")
        json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
        f.write("\n")
@@ -63,6 +63,7 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -74,6 +75,9 @@ check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
torch.npu.config.allow_internal_format = False
def save_model_card(
repo_id: str,
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable VAE tiling in log validation")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -924,8 +929,7 @@ def main(args):
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -988,6 +992,14 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
for block in transformer.transformer_blocks:
block.attn2.set_use_npu_flash_attention(True)
else:
raise ValueError("NPU flash attention requires the torch_npu extension and is only supported on NPU devices.")
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -695,7 +695,7 @@ def main():
)
# We need to ensure that the original and the edited images undergo the same
# augmentation transforms.
images = np.concatenate([original_images, edited_images])
images = np.stack([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1
return train_transforms(images)
@@ -706,7 +706,7 @@ def main():
# Since the original and edited images were stacked before
# applying the transformations, we need to separate them and reshape
# them accordingly.
original_images, edited_images = preprocessed_images.chunk(2)
original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
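To illustrate the difference between the `np.concatenate` / `chunk(2)` pairing and the `np.stack` / unpacking pairing in this hunk, here is a small NumPy-only sketch. The array shapes and values are made up for the example and are not taken from the training script.

```python
import numpy as np

# Made-up shapes: two RGB images merged along the channel axis, 4x4 pixels each.
original_images = np.zeros((6, 4, 4), dtype=np.float32)
edited_images = np.ones((6, 4, 4), dtype=np.float32)

# np.concatenate merges both groups along the existing first axis.
concatenated = np.concatenate([original_images, edited_images])

# np.stack adds a new leading axis of size 2, one slot per group.
stacked = np.stack([original_images, edited_images])

# Unpacking iterates over that leading axis, so each group comes back intact.
recovered_original, recovered_edited = stacked
```

With the stacked layout, the two groups stay addressable by index on the leading axis, so separating them again does not depend on splitting a merged axis in half.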
@@ -766,7 +766,7 @@ def main():
)
# We need to ensure that the original and the edited images undergo the same
# augmentation transforms.
images = np.concatenate([original_images, edited_images])
images = np.stack([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1
return train_transforms(images)
@@ -906,7 +906,7 @@ def main():
# Since the original and edited images were stacked before
# applying the transformations, we need to separate them and reshape
# them accordingly.
original_images, edited_images = preprocessed_images.chunk(2)
original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+2 -22
@@ -82,31 +82,11 @@ pipeline = EasyPipelineForInpainting.from_huggingface(
## Search Civitai and Huggingface
```python
from pipeline_easy import (
search_huggingface,
search_civitai,
)
# Search Lora
Lora = search_civitai(
"Keyword_to_search_Lora",
model_type="LORA",
base_model = "SD 1.5",
download=True,
)
# Load Lora into the pipeline.
pipeline.load_lora_weights(Lora)
pipeline.auto_load_lora_weights("Detail Tweaker")
# Search TextualInversion
TextualInversion = search_civitai(
"EasyNegative",
model_type="TextualInversion",
base_model = "SD 1.5",
download=True
)
# Load TextualInversion into the pipeline.
pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
```
### Search Civitai
+482 -110
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 suzukimain
# Copyright 2025 suzukimain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,13 @@
import os
import re
import types
from collections import OrderedDict
from dataclasses import asdict, dataclass
from typing import Union
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Union
import requests
import torch
from huggingface_hub import hf_api, hf_hub_download
from huggingface_hub.file_download import http_get
from huggingface_hub.utils import validate_hf_hub_args
@@ -30,6 +32,7 @@ from diffusers.loaders.single_file_utils import (
infer_diffusers_model_type,
load_single_file_checkpoint,
)
from diffusers.pipelines.animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline
from diffusers.pipelines.auto_pipeline import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
@@ -39,13 +42,18 @@ from diffusers.pipelines.controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetPipeline,
)
from diffusers.pipelines.flux import FluxImg2ImgPipeline, FluxPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
from diffusers.pipelines.stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
@@ -59,46 +67,133 @@ logger = logging.get_logger(__name__)
SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict(
[
("xl_base", StableDiffusionXLPipeline),
("xl_refiner", StableDiffusionXLPipeline),
("xl_inpaint", None),
("playground-v2-5", StableDiffusionXLPipeline),
("upscale", None),
("animatediff_rgb", AnimateDiffPipeline),
("animatediff_scribble", AnimateDiffPipeline),
("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
("animatediff_v1", AnimateDiffPipeline),
("animatediff_v2", AnimateDiffPipeline),
("animatediff_v3", AnimateDiffPipeline),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetPipeline),
("controlnet_xl", StableDiffusionXLControlNetPipeline),
("controlnet_xl_large", StableDiffusionXLControlNetPipeline),
("controlnet_xl_mid", StableDiffusionXLControlNetPipeline),
("controlnet_xl_small", StableDiffusionXLControlNetPipeline),
("flux-depth", FluxPipeline),
("flux-dev", FluxPipeline),
("flux-fill", FluxPipeline),
("flux-schnell", FluxPipeline),
("hunyuan-video", None),
("inpainting", None),
("inpainting_v2", None),
("controlnet", StableDiffusionControlNetPipeline),
("v2", StableDiffusionPipeline),
("ltx-video", None),
("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", StableDiffusionXLPipeline),
("sd3", StableDiffusion3Pipeline),
("sd35_large", StableDiffusion3Pipeline),
("sd35_medium", StableDiffusion3Pipeline),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", StableDiffusionPipeline),
("v2", StableDiffusionPipeline),
("xl_base", StableDiffusionXLPipeline),
("xl_inpaint", None),
("xl_refiner", StableDiffusionXLPipeline),
]
)
SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict(
[
("xl_base", StableDiffusionXLImg2ImgPipeline),
("xl_refiner", StableDiffusionXLImg2ImgPipeline),
("xl_inpaint", None),
("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
("upscale", None),
("animatediff_rgb", AnimateDiffPipeline),
("animatediff_scribble", AnimateDiffPipeline),
("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
("animatediff_v1", AnimateDiffPipeline),
("animatediff_v2", AnimateDiffPipeline),
("animatediff_v3", AnimateDiffPipeline),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetImg2ImgPipeline),
("controlnet_xl", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_large", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_mid", StableDiffusionXLControlNetImg2ImgPipeline),
("controlnet_xl_small", StableDiffusionXLControlNetImg2ImgPipeline),
("flux-depth", FluxImg2ImgPipeline),
("flux-dev", FluxImg2ImgPipeline),
("flux-fill", FluxImg2ImgPipeline),
("flux-schnell", FluxImg2ImgPipeline),
("hunyuan-video", None),
("inpainting", None),
("inpainting_v2", None),
("controlnet", StableDiffusionControlNetImg2ImgPipeline),
("v2", StableDiffusionImg2ImgPipeline),
("ltx-video", None),
("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
("sd3", StableDiffusion3Img2ImgPipeline),
("sd35_large", StableDiffusion3Img2ImgPipeline),
("sd35_medium", StableDiffusion3Img2ImgPipeline),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", StableDiffusionImg2ImgPipeline),
("v2", StableDiffusionImg2ImgPipeline),
("xl_base", StableDiffusionXLImg2ImgPipeline),
("xl_inpaint", None),
("xl_refiner", StableDiffusionXLImg2ImgPipeline),
]
)
SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict(
[
("xl_base", None),
("xl_refiner", None),
("xl_inpaint", StableDiffusionXLInpaintPipeline),
("playground-v2-5", None),
("upscale", None),
("animatediff_rgb", None),
("animatediff_scribble", None),
("animatediff_sdxl_beta", None),
("animatediff_v1", None),
("animatediff_v2", None),
("animatediff_v3", None),
("autoencoder-dc-f128c512", None),
("autoencoder-dc-f32c32", None),
("autoencoder-dc-f32c32-sana", None),
("autoencoder-dc-f64c128", None),
("controlnet", StableDiffusionControlNetInpaintPipeline),
("controlnet_xl", None),
("controlnet_xl_large", None),
("controlnet_xl_mid", None),
("controlnet_xl_small", None),
("flux-depth", None),
("flux-dev", None),
("flux-fill", None),
("flux-schnell", None),
("hunyuan-video", None),
("inpainting", StableDiffusionInpaintPipeline),
("inpainting_v2", StableDiffusionInpaintPipeline),
("controlnet", StableDiffusionControlNetInpaintPipeline),
("v2", None),
("ltx-video", None),
("ltx-video-0.9.1", None),
("mochi-1-preview", None),
("playground-v2-5", None),
("sd3", None),
("sd35_large", None),
("sd35_medium", None),
("stable_cascade_stage_b", None),
("stable_cascade_stage_b_lite", None),
("stable_cascade_stage_c", None),
("stable_cascade_stage_c_lite", None),
("upscale", StableDiffusionUpscalePipeline),
("v1", None),
("v2", None),
("xl_base", None),
("xl_inpaint", StableDiffusionXLInpaintPipeline),
("xl_refiner", None),
]
)
@@ -116,14 +211,33 @@ CONFIG_FILE_LIST = [
"diffusion_pytorch_model.non_ema.safetensors",
]
DIFFUSERS_CONFIG_DIR = ["safety_checker", "unet", "vae", "text_encoder", "text_encoder_2"]
INPAINT_PIPELINE_KEYS = [
"xl_inpaint",
"inpainting",
"inpainting_v2",
DIFFUSERS_CONFIG_DIR = [
"safety_checker",
"unet",
"vae",
"text_encoder",
"text_encoder_2",
]
TOKENIZER_SHAPE_MAP = {
768: [
"SD 1.4",
"SD 1.5",
"SD 1.5 LCM",
"SDXL 0.9",
"SDXL 1.0",
"SDXL 1.0 LCM",
"SDXL Distilled",
"SDXL Turbo",
"SDXL Lightning",
"PixArt a",
"Playground v2",
"Pony",
],
1024: ["SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "SD 2.1 Unclip"],
}
EXTENSION = [".safetensors", ".ckpt", ".bin"]
CACHE_HOME = os.path.expanduser("~/.cache")
@@ -162,12 +276,28 @@ class ModelStatus:
The name of the model file.
local (`bool`):
Whether the model exists locally
site_url (`str`):
The URL of the site where the model is hosted.
"""
search_word: str = ""
download_url: str = ""
file_name: str = ""
local: bool = False
site_url: str = ""
@dataclass
class ExtraStatus:
r"""
Data class for storing extra status information.
Attributes:
trained_words (`List[str]`, *optional*):
The words used to trigger the model.
"""
trained_words: Union[List[str], None] = None
@dataclass
@@ -191,8 +321,9 @@ class SearchResult:
model_path: str = ""
loading_method: Union[str, None] = None
checkpoint_format: Union[str, None] = None
repo_status: RepoStatus = RepoStatus()
model_status: ModelStatus = ModelStatus()
repo_status: RepoStatus = field(default_factory=RepoStatus)
model_status: ModelStatus = field(default_factory=ModelStatus)
extra_status: ExtraStatus = field(default_factory=ExtraStatus)
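The switch to `field(default_factory=...)` can be illustrated with a minimal sketch. The two dataclasses below are simplified stand-ins with the same names as in the diff, not the real definitions. A default written as `ModelStatus()` would be a single instance shared by every `SearchResult`, and Python 3.11 and later reject such unhashable defaults with a `ValueError`; a factory builds a new object per instance instead.

```python
from dataclasses import dataclass, field


@dataclass
class ModelStatus:
    # Simplified stand-in: only two of the fields are reproduced here.
    file_name: str = ""
    local: bool = False


@dataclass
class SearchResult:
    model_path: str = ""
    # default_factory is called once per SearchResult, so nothing is shared.
    model_status: ModelStatus = field(default_factory=ModelStatus)


first = SearchResult()
second = SearchResult()

# Mutating one result's nested status leaves the other untouched.
first.model_status.file_name = "model.safetensors"
```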
@validate_hf_hub_args
@@ -385,6 +516,7 @@ def file_downloader(
proxies = kwargs.pop("proxies", None)
force_download = kwargs.pop("force_download", False)
displayed_filename = kwargs.pop("displayed_filename", None)
# Default mode for file writing and initial file size
mode = "wb"
file_size = 0
@@ -396,7 +528,7 @@ def file_downloader(
if os.path.exists(save_path):
if not force_download:
# If the file exists and force_download is False, skip the download
logger.warning(f"File already exists: {save_path}, skipping download.")
logger.info(f"File already exists: {save_path}, skipping download.")
return None
elif resume:
# If resuming, set mode to append binary and get current file size
@@ -457,10 +589,18 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
gated = kwargs.pop("gated", False)
skip_error = kwargs.pop("skip_error", False)
file_list = []
hf_repo_info = {}
hf_security_info = {}
model_path = ""
repo_id, file_name = "", ""
diffusers_model_exists = False
# Get the type and loading method for the keyword
search_word_status = get_keyword_types(search_word)
if search_word_status["type"]["hf_repo"]:
hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)
if download:
model_path = DiffusionPipeline.download(
search_word,
@@ -503,13 +643,6 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
)
model_dicts = [asdict(value) for value in list(hf_models)]
file_list = []
hf_repo_info = {}
hf_security_info = {}
model_path = ""
repo_id, file_name = "", ""
diffusers_model_exists = False
# Loop through models to find a suitable candidate
for repo_info in model_dicts:
repo_id = repo_info["id"]
@@ -523,7 +656,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
if hf_security_info["scansDone"]:
for info in repo_info["siblings"]:
file_path = info["rfilename"]
if "model_index.json" == file_path and checkpoint_format in ["diffusers", "all"]:
if "model_index.json" == file_path and checkpoint_format in [
"diffusers",
"all",
]:
diffusers_model_exists = True
break
@@ -571,6 +707,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
force_download=force_download,
)
# `pathlib.PosixPath` may be returned
if model_path:
model_path = str(model_path)
if file_name:
download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
else:
@@ -586,10 +726,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision),
model_status=ModelStatus(
search_word=search_word,
site_url=download_url,
download_url=download_url,
file_name=file_name,
local=download,
),
extra_status=ExtraStatus(trained_words=None),
)
else:
@@ -605,6 +747,8 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
The search query string.
model_type (`str`, *optional*, defaults to `Checkpoint`):
The type of model to search for.
sort (`str`, *optional*):
The order in which to sort the results (for example, `Highest Rated`, `Most Downloaded`, `Newest`).
base_model (`str`, *optional*):
The base model to filter by.
download (`bool`, *optional*, defaults to `False`):
@@ -628,6 +772,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
# Extract additional parameters from kwargs
model_type = kwargs.pop("model_type", "Checkpoint")
sort = kwargs.pop("sort", None)
download = kwargs.pop("download", False)
base_model = kwargs.pop("base_model", None)
force_download = kwargs.pop("force_download", False)
@@ -642,6 +787,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
repo_name = ""
repo_id = ""
version_id = ""
trainedWords = ""
models_list = []
selected_repo = {}
selected_model = {}
@@ -652,12 +798,16 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
params = {
"query": search_word,
"types": model_type,
"sort": "Most Downloaded",
"limit": 20,
}
if base_model is not None:
if not isinstance(base_model, list):
base_model = [base_model]
params["baseModel"] = base_model
if sort is not None:
params["sort"] = sort
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
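The query construction in this hunk can be sketched as a standalone helper. `build_civitai_query` is a hypothetical name introduced for illustration; it only assembles the dictionaries shown in the diff and performs no network request.

```python
def build_civitai_query(search_word, model_type="Checkpoint", base_model=None, sort=None, token=None):
    """Assemble request params and headers the way the hunk above does."""
    params = {
        "query": search_word,
        "types": model_type,
        "sort": "Most Downloaded",  # default ordering unless the caller overrides it
        "limit": 20,
    }
    if base_model is not None:
        # A single base model is wrapped in a list so both input forms look the same.
        params["baseModel"] = base_model if isinstance(base_model, list) else [base_model]
    if sort is not None:
        params["sort"] = sort
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    return params, headers


params, headers = build_civitai_query("EasyNegative", model_type="TextualInversion", base_model="SD 1.5")
sorted_params, auth_headers = build_civitai_query("EasyNegative", sort="Newest", token="abc")
```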
@@ -686,25 +836,30 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
# Sort versions within the selected repo by download count
sorted_versions = sorted(
selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True
selected_repo["modelVersions"],
key=lambda x: x["stats"]["downloadCount"],
reverse=True,
)
for selected_version in sorted_versions:
version_id = selected_version["id"]
trainedWords = selected_version["trainedWords"]
models_list = []
for model_data in selected_version["files"]:
# Check if the file passes security scans and has a valid extension
file_name = model_data["name"]
if (
model_data["pickleScanResult"] == "Success"
and model_data["virusScanResult"] == "Success"
and any(file_name.endswith(ext) for ext in EXTENSION)
and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
):
file_status = {
"filename": file_name,
"download_url": model_data["downloadUrl"],
}
models_list.append(file_status)
# When searching for textual inversion, results other than the values entered for the base model may come up, so check again.
if base_model is None or selected_version["baseModel"] in base_model:
for model_data in selected_version["files"]:
# Check if the file passes security scans and has a valid extension
file_name = model_data["name"]
if (
model_data["pickleScanResult"] == "Success"
and model_data["virusScanResult"] == "Success"
and any(file_name.endswith(ext) for ext in EXTENSION)
and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
):
file_status = {
"filename": file_name,
"download_url": model_data["downloadUrl"],
}
models_list.append(file_status)
if models_list:
# Sort the models list by filename and find the safest model
@@ -764,19 +919,229 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id),
model_status=ModelStatus(
search_word=search_word,
site_url=f"https://civitai.com/models/{repo_id}?modelVersionId={version_id}",
download_url=download_url,
file_name=file_name,
local=output_info["type"]["local"],
),
extra_status=ExtraStatus(trained_words=trainedWords or None),
)
def add_methods(pipeline):
r"""
Add methods from `AutoConfig` to the pipeline.
Parameters:
pipeline (`Pipeline`):
The pipeline to which the methods will be added.
"""
for attr_name in dir(AutoConfig):
attr_value = getattr(AutoConfig, attr_name)
if callable(attr_value) and not attr_name.startswith("__"):
setattr(pipeline, attr_name, types.MethodType(attr_value, pipeline))
return pipeline
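The binding technique used by `add_methods` can be tried in isolation. In the sketch below, `Helpers`, `Pipeline`, and `attach_methods` are hypothetical stand-ins; only the use of `types.MethodType` to bind plain functions to one existing instance follows the diff.

```python
import types


class Helpers:
    """Hypothetical stand-in for a class whose methods get attached to instances."""

    def describe(self):
        return f"{type(self).__name__} with extra helpers"


class Pipeline:
    """Hypothetical stand-in for an already-constructed pipeline object."""


def attach_methods(obj, source=Helpers):
    for name in dir(source):
        value = getattr(source, name)
        # Skip dunder attributes; bind every other callable to this one instance.
        if callable(value) and not name.startswith("__"):
            setattr(obj, name, types.MethodType(value, obj))
    return obj


pipe = attach_methods(Pipeline())
description = pipe.describe()
```

Because the methods are set on the instance rather than on its class, other `Pipeline` objects are unaffected, which is presumably why the loaders in the diff call `add_methods` on each pipeline they return.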
class AutoConfig:
def auto_load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str]],
token: Optional[Union[str, List[str]]] = None,
base_model: Optional[Union[str, List[str]]] = None,
tokenizer=None,
text_encoder=None,
**kwargs,
):
r"""
Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
Automatic1111 formats are supported).
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either one of the following or a list of them:
- Search keywords for pretrained model (for example `EasyNegative`).
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
pretrained model hosted on the Hub.
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
inversion weights.
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
If not specified, the function uses `self.text_encoder`.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
A `CLIPTokenizer` to tokenize text. If not specified, the function uses `self.tokenizer`.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used when:
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
name such as `text_inv.bin`.
- The saved textual inversion file is in the Automatic1111 format.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
Examples:
```py
>>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
>>> image = pipeline(prompt).images[0]
```
"""
# 1. Set tokenizer and text encoder
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
# Check if tokenizer and text encoder are provided
if tokenizer is None or text_encoder is None:
raise ValueError("Tokenizer and text encoder must be provided.")
# 2. Normalize inputs
pretrained_model_name_or_paths = (
[pretrained_model_name_or_path]
if not isinstance(pretrained_model_name_or_path, list)
else pretrained_model_name_or_path
)
# 2.1 Normalize tokens
tokens = [token] if not isinstance(token, list) else token
if tokens[0] is None:
tokens = tokens * len(pretrained_model_name_or_paths)
for check_token in tokens:
# Check if token is already in tokenizer vocabulary
if check_token in tokenizer.get_vocab():
raise ValueError(
f"Token {check_token} already in tokenizer vocabulary. Please choose a different token name or remove {check_token} and embedding from the tokenizer and text encoder."
)
expected_shape = text_encoder.get_input_embeddings().weight.shape[-1] # Embedding dimension of the text encoder
for search_word in pretrained_model_name_or_paths:
if isinstance(search_word, str):
# Update kwargs to ensure the model is downloaded and parameters are included
_status = {
"download": True,
"include_params": True,
"skip_error": False,
"model_type": "TextualInversion",
}
# Get tags for the base model of textual inversion compatible with tokenizer.
# If the tokenizer is 768-dimensional, set tags for SD 1.x and SDXL.
# If the tokenizer is 1024-dimensional, set tags for SD 2.x.
if expected_shape in TOKENIZER_SHAPE_MAP:
# Retrieve the appropriate tags from the TOKENIZER_SHAPE_MAP based on the expected shape.
# Copy the list so that extending it below does not mutate the module-level map.
tags = list(TOKENIZER_SHAPE_MAP[expected_shape])
if base_model is not None:
if isinstance(base_model, list):
tags.extend(base_model)
else:
tags.append(base_model)
_status["base_model"] = tags
kwargs.update(_status)
# Search for the model on Civitai and get the model status
textual_inversion_path = search_civitai(search_word, **kwargs)
logger.warning(
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
)
pretrained_model_name_or_paths[
pretrained_model_name_or_paths.index(search_word)
] = textual_inversion_path.model_path
self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
)
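The base-model tag selection inside `auto_load_textual_inversion` can be sketched on its own. `SHAPE_TO_TAGS` is a hypothetical, shortened copy of the embedding-width mapping and `select_base_model_tags` is a hypothetical helper; the sketch copies the list before extending it so that repeated calls never grow the shared mapping.

```python
# Hypothetical, shortened mapping from embedding width to base-model tags.
SHAPE_TO_TAGS = {
    768: ["SD 1.5", "SDXL 1.0"],
    1024: ["SD 2.0", "SD 2.1"],
}


def select_base_model_tags(embedding_dim, base_model=None):
    """Return the tags for an embedding width plus any user-supplied base models."""
    if embedding_dim not in SHAPE_TO_TAGS:
        return None
    # Copy so that the module-level list is never modified in place.
    tags = list(SHAPE_TO_TAGS[embedding_dim])
    if base_model is not None:
        tags.extend(base_model if isinstance(base_model, list) else [base_model])
    return tags


with_extra = select_base_model_tags(768, "Pony")
unknown_width = select_base_model_tags(512)
```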
def auto_load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
r"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
loaded into `self.unet`.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
dict is loaded into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if isinstance(pretrained_model_name_or_path_or_dict, str):
# Update kwargs to ensure the model is downloaded and parameters are included
_status = {
"download": True,
"include_params": True,
"skip_error": False,
"model_type": "LORA",
}
kwargs.update(_status)
# Search for the model on Civitai and get the model status
lora_path = search_civitai(pretrained_model_name_or_path_or_dict, **kwargs)
logger.warning(f"lora_path: {lora_path.model_status.site_url}")
logger.warning(f"trained_words: {lora_path.extra_status.trained_words}")
pretrained_model_name_or_path_or_dict = lora_path.model_path
self.load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)
class EasyPipelineForText2Image(AutoPipelineForText2Image):
r"""
[`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
[`EasyPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods.
[`~EasyPipelineForText2Image.from_pretrained`], [`~EasyPipelineForText2Image.from_pipe`], [`~EasyPipelineForText2Image.from_huggingface`] or [`~EasyPipelineForText2Image.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
@@ -891,9 +1256,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
```
"""
@@ -907,20 +1272,21 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
kwargs.update(_status)
# Search for the model on Hugging Face and get the model status
hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}")
checkpoint_path = hf_model_status.model_path
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint
if hf_model_status.checkpoint_format == "single_file":
if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
**kwargs,
)
else:
return cls.from_pretrained(checkpoint_path, **kwargs)
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
@@ -999,9 +1365,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForText2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
```
"""
@@ -1015,24 +1381,25 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
kwargs.update(_status)
# Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
checkpoint_path = model_status.model_path
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
**kwargs,
)
return add_methods(pipeline)
class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
r"""
[`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
[`EasyPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods.
[`~EasyPipelineForImage2Image.from_pretrained`], [`~EasyPipelineForImage2Image.from_pipe`], [`~EasyPipelineForImage2Image.from_huggingface`] or [`~EasyPipelineForImage2Image.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
@@ -1147,10 +1514,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForImage2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
>>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt, image).images[0]
```
"""
# Update kwargs to ensure the model is downloaded and parameters are included
@@ -1163,20 +1530,22 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
kwargs.update(_parmas)
# Search for the model on Hugging Face and get the model status
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
checkpoint_path = model_status.model_path
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint
if model_status.checkpoint_format == "single_file":
if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
**kwargs,
)
else:
return cls.from_pretrained(checkpoint_path, **kwargs)
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
@@ -1255,10 +1624,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForImage2Image
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
>>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt, image).images[0]
```
"""
# Update kwargs to ensure the model is downloaded and parameters are included
@@ -1271,24 +1640,25 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
kwargs.update(_status)
# Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
checkpoint_path = model_status.model_path
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
**kwargs,
)
return add_methods(pipeline)
class EasyPipelineForInpainting(AutoPipelineForInpainting):
r"""
[`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
[`EasyPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.
[`~EasyPipelineForInpainting.from_pretrained`], [`~EasyPipelineForInpainting.from_pipe`], [`~EasyPipelineForInpainting.from_huggingface`] or [`~EasyPipelineForInpainting.from_civitai`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
@@ -1403,10 +1773,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForInpainting
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
>>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
# Update kwargs to ensure the model is downloaded and parameters are included
@@ -1419,20 +1789,21 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
kwargs.update(_status)
# Search for the model on Hugging Face and get the model status
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
checkpoint_path = model_status.model_path
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
checkpoint_path = hf_checkpoint_status.model_path
# Check the format of the model checkpoint
if model_status.checkpoint_format == "single_file":
if hf_checkpoint_status.loading_method == "from_single_file":
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
**kwargs,
)
else:
return cls.from_pretrained(checkpoint_path, **kwargs)
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
return add_methods(pipeline)
@classmethod
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
@@ -1511,10 +1882,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> from auto_diffusers import EasyPipelineForInpainting
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
>>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
# Update kwargs to ensure the model is downloaded and parameters are included
@@ -1527,13 +1898,14 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
kwargs.update(_status)
# Search for the model on Civitai and get the model status
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
checkpoint_path = model_status.model_path
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
checkpoint_path = checkpoint_status.model_path
# Load the pipeline from a single file checkpoint
return load_pipeline_from_single_file(
pipeline = load_pipeline_from_single_file(
pretrained_model_or_path=checkpoint_path,
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
**kwargs,
)
return add_methods(pipeline)
@@ -0,0 +1,59 @@
# AutoencoderKL training example
## Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then `cd` into the example folder and run:
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
## Training on CIFAR10
Please replace the validation image with your own image.
```bash
accelerate launch train_autoencoderkl.py \
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--dataset_name=cifar10 \
--image_column=img \
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
--num_train_epochs 100 \
--gradient_accumulation_steps 2 \
--learning_rate 4.5e-6 \
--lr_scheduler cosine \
--report_to wandb
```
## Training on ImageNet
```bash
accelerate launch train_autoencoderkl.py \
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--num_train_epochs 100 \
--gradient_accumulation_steps 2 \
--learning_rate 4.5e-6 \
--lr_scheduler cosine \
--report_to wandb \
--mixed_precision bf16 \
--train_data_dir /path/to/ImageNet/train \
--validation_image ./image.png \
--decoder_only
```
@@ -0,0 +1,15 @@
accelerate>=0.16.0
bitsandbytes
datasets
huggingface_hub
lpips
numpy
packaging
Pillow
taming_transformers
torch
torchvision
tqdm
transformers
wandb
xformers
File diff suppressed because it is too large
@@ -0,0 +1,30 @@
# Diffusion Model Alignment Using GRPO
This directory provides LoRA implementations of Diffusion [GRPO](https://arxiv.org/abs/2402.03300), an RL-based alignment method that is a variant of Proximal Policy Optimization (PPO), applied in the diffusion model setting.
## SDXL training command
```bash
accelerate launch train_diffusion_grpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--push_to_hub
```
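The command above combines `--train_batch_size=8` with `--gradient_accumulation_steps=2`. As a rough sketch of what that implies, and assuming (as in other diffusers training scripts) that `--train_batch_size` is a per-device value and that `--max_train_steps` counts optimizer updates, the effective batch size can be worked out as follows. The helper name `effective_batch_size` and the `num_processes` values are illustrative only and are not part of the script.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    # Samples contributing to one optimizer update across all processes.
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the SDXL command above; num_processes is an assumption.
single_gpu = effective_batch_size(8, 2)
four_gpus = effective_batch_size(8, 2, num_processes=4)

# Samples consumed over the whole run on one GPU with --max_train_steps=2000.
samples_seen = 2000 * single_gpu

print(single_gpu, four_gpus, samples_seen)
```

If memory is tight, lowering `--train_batch_size` while raising `--gradient_accumulation_steps` by the same factor keeps this product unchanged.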
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft
wandb
File diff suppressed because it is too large
@@ -8,7 +8,6 @@ from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
self.transformer = transformer
self.controlnet = controlnet
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -220,18 +215,8 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -239,7 +224,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
@@ -365,8 +365,8 @@ def parse_args():
"--dream_training",
action="store_true",
help=(
"Use the DREAM training method, which makes training more efficient and accurate at the ",
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
"Use the DREAM training method, which makes training more efficient and accurate at the "
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
),
)
parser.add_argument(
@@ -515,10 +515,6 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
@@ -0,0 +1,151 @@
"""
This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
To make it work for other models:
* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
for example. (TODO: more reason to add `AutoModel`).
* Supply the path to the base checkpoint via `--base_ckpt_path`.
* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
* Change the `--rank` as needed.
Example usage:
```bash
python extract_lora_from_model.py \
--base_ckpt_path=THUDM/CogVideoX-5b \
--finetune_ckpt_path=finetrainers/cakeify-v0 \
--lora_out_path=cakeify_lora.safetensors
```
Script is adapted from
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
"""
import argparse
import torch
from safetensors.torch import save_file
from tqdm.auto import tqdm
from diffusers import CogVideoXTransformer3DModel
RANK = 64
CLAMP_QUANTILE = 0.99
# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
# Use CUDA when available; the SVD is very slow on CPU.
if torch.cuda.is_available():
diff = diff.to("cuda")
is_conv2d = len(diff.shape) == 4
kernel_size = None if not is_conv2d else diff.size()[2:4]
is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if is_conv2d:
if is_conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if is_conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U.cpu(), Vh.cpu())
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_ckpt_path",
default=None,
type=str,
required=True,
help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
)
parser.add_argument(
"--base_subfolder",
default="transformer",
type=str,
help="subfolder to load the base checkpoint from if any.",
)
parser.add_argument(
"--finetune_ckpt_path",
default=None,
type=str,
required=True,
help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
)
parser.add_argument(
"--finetune_subfolder",
default=None,
type=str,
help="Subfolder to load the fully fine-tuned checkpoint from, if any.",
)
parser.add_argument("--rank", default=64, type=int)
parser.add_argument("--lora_out_path", default=None, type=str, required=True)
args = parser.parse_args()
if not args.lora_out_path.endswith(".safetensors"):
raise ValueError("`lora_out_path` must end with `.safetensors`.")
return args
@torch.no_grad()
def main(args):
model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
)
state_dict_ft = model_finetuned.state_dict()
# Change the `subfolder` as needed.
base_model = CogVideoXTransformer3DModel.from_pretrained(
args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
)
state_dict = base_model.state_dict()
output_dict = {}
for k in tqdm(state_dict, desc="Extracting LoRA..."):
original_param = state_dict[k]
finetuned_param = state_dict_ft[k]
if len(original_param.shape) >= 2:
diff = finetuned_param.float() - original_param.float()
out = extract_lora(diff, args.rank)  # honor the `--rank` CLI argument instead of the module-level default
name = k
if name.endswith(".weight"):
name = name[: -len(".weight")]
down_key = "{}.lora_A.weight".format(name)
up_key = "{}.lora_B.weight".format(name)
output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)
prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
save_file(output_dict, args.lora_out_path)
print(f"LoRA saved and it contains {len(output_dict)} keys.")
if __name__ == "__main__":
args = parse_args()
main(args)
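The core of `extract_lora` is a truncated SVD of the weight difference. A minimal sketch of that step, using NumPy instead of torch and leaving out the quantile clamping and the conv reshaping, might look like the following. The function name `extract_low_rank` and the toy matrix are assumptions made for illustration.

```python
import numpy as np


def extract_low_rank(diff, rank):
    # Factor `diff` (out_dim x in_dim) into up (out_dim x rank) and down (rank x in_dim).
    out_dim, in_dim = diff.shape
    rank = min(rank, in_dim, out_dim)
    U, S, Vh = np.linalg.svd(diff, full_matrices=False)
    up = U[:, :rank] * S[:rank]  # equivalent to U[:, :rank] @ diag(S[:rank])
    down = Vh[:rank, :]
    return up, down


# Toy "finetuned minus base" difference that is exactly rank 2.
rng = np.random.default_rng(0)
diff = rng.standard_normal((6, 2)) @ rng.standard_normal((2, 4))

up, down = extract_low_rank(diff, rank=2)
error = np.abs(up @ down - diff).max()
print(up.shape, down.shape)
```

In this sketch `up` plays the role of the `lora_B` weight and `down` the role of `lora_A`. Because the toy difference has rank 2, the product `up @ down` recovers it up to floating-point error; for a real fine-tune the truncation discards the smaller singular values.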
@@ -28,6 +28,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -90,6 +98,7 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsisIDTransformer3DModel",
@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsisIDTransformer3DModel,
@@ -0,0 +1,7 @@
from ..utils import is_torch_available
if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
@@ -0,0 +1,236 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
"""
_is_stateful = False
def __init__(self):
self.fn_ref: "HookFunctionReference" = None
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is deinitialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[str, Any]`):
The keyword arguments passed to the module.
Returns:
`Tuple[Tuple[Any], Dict[str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass has been executed just before this event.
output (`Any`):
The output of the module.
Returns:
`Any`: The processed `output`.
"""
return output
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.
Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module
def reset_state(self, module: torch.nn.Module):
if self._is_stateful:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
class HookFunctionReference:
def __init__(self) -> None:
"""A container class that maintains mutable references to forward pass functions in a hook chain.
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
entire forward pass structure.
Attributes:
pre_forward: A callable that processes inputs before the main forward pass.
post_forward: A callable that processes outputs after the main forward pass.
forward: The current forward function in the hook chain.
original_forward: The original forward function, stored when a hook provides a custom new_forward.
The class enables hook removal by allowing updates to the forward chain through reference modification rather
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
be updated, preserving the execution order of the remaining hooks.
"""
self.pre_forward = None
self.post_forward = None
self.forward = None
self.original_forward = None
class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
self.hooks: Dict[str, ModelHook] = {}
self._module_ref = module_ref
self._hook_order = []
self._fn_refs = []
def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys():
raise ValueError(
f"Hook with name {name} already exists in the registry. Please use a different name or "
f"first remove the existing hook and then add a new one."
)
self._module_ref = hook.initialize_hook(self._module_ref)
def create_new_forward(function_reference: HookFunctionReference):
def new_forward(module, *args, **kwargs):
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
output = function_reference.forward(*args, **kwargs)
return function_reference.post_forward(module, output)
return new_forward
forward = self._module_ref.forward
fn_ref = HookFunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.forward = forward
if hasattr(hook, "new_forward"):
fn_ref.original_forward = forward
fn_ref.forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)
rewritten_forward = create_new_forward(fn_ref)
self._module_ref.forward = functools.update_wrapper(
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
)
hook.fn_ref = fn_ref
self.hooks[name] = hook
self._hook_order.append(name)
self._fn_refs.append(fn_ref)
def get_hook(self, name: str) -> Optional[ModelHook]:
return self.hooks.get(name, None)
def remove_hook(self, name: str, recurse: bool = True) -> None:
if name in self.hooks.keys():
num_hooks = len(self._hook_order)
hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]
old_forward = fn_ref.forward
if fn_ref.original_forward is not None:
old_forward = fn_ref.original_forward
if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].forward = old_forward
self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.pop(index)
self._fn_refs.pop(index)
if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.remove_hook(name, recurse=False)
def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@classmethod
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
if not hasattr(module, "_diffusers_hook"):
module._diffusers_hook = cls(module)
return module._diffusers_hook
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
hook_repr = self.hooks[hook_name].__repr__()
else:
hook_repr = self.hooks[hook_name].__class__.__name__
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
if i < len(self._hook_order) - 1:
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"
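The way `register_hook` wraps `forward` can be illustrated without torch. The following is a simplified sketch, not the registry itself: `TinyModule`, `AddOneBefore`, `NegateAfter` and `attach` are hypothetical names, and the mutable `HookFunctionReference` indirection that makes hook removal possible is left out.

```python
class TinyModule:
    def forward(self, x):
        return x * 2


class AddOneBefore:
    def pre_forward(self, module, *args, **kwargs):
        # Modify the inputs before the wrapped forward runs.
        return tuple(a + 1 for a in args), kwargs

    def post_forward(self, module, output):
        return output


class NegateAfter:
    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        # Modify the output after the wrapped forward has run.
        return -output


def attach(module, hook):
    inner = module.forward  # whatever forward is currently installed

    def new_forward(*args, **kwargs):
        args, kwargs = hook.pre_forward(module, *args, **kwargs)
        output = inner(*args, **kwargs)
        return hook.post_forward(module, output)

    module.forward = new_forward


m = TinyModule()
attach(m, AddOneBefore())
attach(m, NegateAfter())  # registered last, so it runs outermost
result = m.forward(3)
print(result)
```

Each newly attached hook wraps the forward that was installed before it, so the most recently registered hook sees the inputs first and the output last.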
@@ -0,0 +1,191 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Optional, Tuple, Type, Union
import torch
from ..utils import get_logger
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
class LayerwiseCastingHook(ModelHook):
r"""
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
footprint.
"""
_is_stateful = False
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
self.storage_dtype = storage_dtype
self.compute_dtype = compute_dtype
self.non_blocking = non_blocking
def initialize_hook(self, module: torch.nn.Module):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return module
def deinitalize_hook(self, module: torch.nn.Module):
raise NotImplementedError(
"LayerwiseCastingHook does not support deinitialization. A model once enabled with layerwise casting will "
"have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
"will lead to precision loss, which might have an impact on the model's generation quality. The model should "
"be re-initialized and loaded in the original dtype."
)
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
return args, kwargs
def post_forward(self, module: torch.nn.Module, output):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return output
def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
nn.Module using diffusers layers or pytorch primitives.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> from diffusers.hooks import apply_layerwise_casting
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_layerwise_casting(
... transformer,
... storage_dtype=torch.float8_e4m3fn,
... compute_dtype=torch.bfloat16,
... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
... non_blocking=True,
... )
```
Args:
module (`torch.nn.Module`):
The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
precision dtype for storage.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before/after the forward pass for storage.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
if skip_modules_pattern == "auto":
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
if skip_modules_classes is None and skip_modules_pattern is None:
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
_apply_layerwise_casting(
module,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
)
def _apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
_prefix: str = "",
) -> None:
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
)
if should_skip:
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return
if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
for name, submodule in module.named_children():
layer_name = f"{_prefix}.{name}" if _prefix else name
_apply_layerwise_casting(
submodule,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
_prefix=layer_name,
)
def apply_layerwise_casting_hook(
module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
r"""
Applies a `LayerwiseCastingHook` to a given module.
Args:
module (`torch.nn.Module`):
The module to attach the hook to.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before the forward pass.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass.
non_blocking (`bool`):
If `True`, the weight casting operations are non-blocking.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
registry.register_hook(hook, "layerwise_casting")
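The recursive skip logic in `_apply_layerwise_casting` can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: the nested-dict model, the `layers_to_cast` name, and the convention that every leaf (`None`) counts as a supported layer are all assumptions made for illustration, and class-based skipping is left out.

```python
import re
from typing import List, Optional, Tuple


def layers_to_cast(tree, skip_patterns: Optional[Tuple[str, ...]] = None, prefix: str = "") -> List[str]:
    """Return the dotted names of the layers that would receive a casting hook.

    `tree` is a hypothetical stand-in for a module hierarchy: a dict maps child
    names to subtrees, and `None` marks a leaf treated as a supported layer.
    """
    # As in the diff, the patterns are searched in the accumulated prefix, so a
    # match on a parent name skips the whole subtree.
    if skip_patterns is not None and any(re.search(p, prefix) for p in skip_patterns):
        return []
    if tree is None:
        return [prefix]
    selected: List[str] = []
    for name, child in tree.items():
        child_prefix = f"{prefix}.{name}" if prefix else name
        selected.extend(layers_to_cast(child, skip_patterns, child_prefix))
    return selected


model = {"encoder": {"conv_in": None, "norm": None}, "decoder": {"conv_out": None}}
print(layers_to_cast(model, ("norm", "^decoder")))
```

Because the pattern is tested with `re.search`, an unanchored pattern such as `"norm"` matches anywhere in the dotted name, while `"^decoder"` only matches names that start with `decoder`.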
@@ -0,0 +1,314 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
class PyramidAttentionBroadcastConfig:
r"""
Configuration for Pyramid Attention Broadcast.
Args:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific spatial attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific temporal attention broadcast is skipped before computing the attention
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
(i.e., old attention states will be re-used) before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks", "single_transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
current_timestep_callback (`Callable[[], int]`, defaults to `None`):
A callback function that returns the current inference timestep. It must be set before applying PAB.
"""
spatial_attention_block_skip_range: Optional[int] = None
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
current_timestep_callback: Optional[Callable[[], int]] = None
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
# so not added for now)
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
f" current_timestep_callback={self.current_timestep_callback}\n"
")"
)
class PyramidAttentionBroadcastState:
r"""
State for Pyramid Attention Broadcast.
Attributes:
iteration (`int`):
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
called before starting a new inference forward pass for PAB to work correctly.
cache (`Any`):
The cached output from the previous forward pass. This is used to re-use the attention states when the
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
"""
def __init__(self) -> None:
self.iteration = 0
self.cache = None
def reset(self):
self.iteration = 0
self.cache = None
def __repr__(self):
cache_repr = ""
if self.cache is None:
cache_repr = "None"
else:
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
class PyramidAttentionBroadcastHook(ModelHook):
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
_is_stateful = True
def __init__(
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
self.timestep_skip_range = timestep_skip_range
self.block_skip_range = block_skip_range
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = PyramidAttentionBroadcastState()
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
should_compute_attention = (
self.state.cache is None
or self.state.iteration == 0
or not is_within_timestep_range
or self.state.iteration % self.block_skip_range == 0
)
if should_compute_attention:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
output = self.state.cache
self.state.cache = output
self.state.iteration += 1
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.state.reset()
return module
def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given module.
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
than in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
config (`PyramidAttentionBroadcastConfig`):
The configuration to use for Pyramid Attention Broadcast.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
```
"""
if config.current_timestep_callback is None:
raise ValueError(
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
)
if (
config.spatial_attention_block_skip_range is None
and config.temporal_attention_block_skip_range is None
and config.cross_attention_block_skip_range is None
):
logger.warning(
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
"To avoid this warning, please set one of the above parameters."
)
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
continue
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
def _apply_pyramid_attention_broadcast_on_attention_class(
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_cross_attention = (
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
and config.cross_attention_block_skip_range is not None
and getattr(module, "is_cross_attention", False)
)
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
elif is_cross_attention:
block_skip_range = config.cross_attention_block_skip_range
timestep_skip_range = config.cross_attention_timestep_skip_range
block_type = "cross"
if block_skip_range is None or timestep_skip_range is None:
logger.info(
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration."
)
return False
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
_apply_pyramid_attention_broadcast_hook(
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
)
return True
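The layer classification above can be sketched in isolation. This is an illustrative reduction under stated assumptions: `classify_attention_layer` is a hypothetical helper, the identifier tuples are copied from the module-level defaults in this file, and the `is_cross_attention` attribute is passed in as a plain boolean.

```python
import re
from typing import Optional

SPATIAL = ("blocks", "transformer_blocks", "single_transformer_blocks")
TEMPORAL = ("temporal_transformer_blocks",)
CROSS = ("blocks", "transformer_blocks")


def classify_attention_layer(
    name: str,
    is_cross_attention: bool,
    spatial_skip: Optional[int] = None,
    temporal_skip: Optional[int] = None,
    cross_skip: Optional[int] = None,
) -> Optional[str]:
    """Return "spatial", "temporal", "cross", or None, checked in that order."""

    def matches(identifiers) -> bool:
        return any(re.search(identifier, name) is not None for identifier in identifiers)

    # A block type is only eligible when its skip range has been configured.
    if matches(SPATIAL) and spatial_skip is not None and not is_cross_attention:
        return "spatial"
    if matches(TEMPORAL) and temporal_skip is not None and not is_cross_attention:
        return "temporal"
    if matches(CROSS) and cross_skip is not None and is_cross_attention:
        return "cross"
    return None


print(classify_attention_layer("transformer_blocks.0.attn1", False, spatial_skip=2))
print(classify_attention_layer("temporal_transformer_blocks.0.attn1", False, spatial_skip=2, temporal_skip=4))
```

Since `re.search` matches substrings, the identifier `"blocks"` also matches a name such as `temporal_transformer_blocks.0.attn1`. With the default identifiers, such a layer is therefore treated as spatial whenever a spatial skip range is configured, because the spatial branch is checked first.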
def _apply_pyramid_attention_broadcast_hook(
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
attention states will be re-used) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, "pyramid_attention_broadcast")
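To make the interaction between `block_skip_range` and `timestep_skip_range` concrete, the decision made in `PyramidAttentionBroadcastHook.new_forward` can be replayed over a list of timesteps. This is a sketch only: `pab_compute_schedule` is a hypothetical helper, the timesteps are made-up values, and no attention is actually computed.

```python
from typing import List, Sequence, Tuple


def pab_compute_schedule(
    timesteps: Sequence[int], timestep_skip_range: Tuple[int, int], block_skip_range: int
) -> List[bool]:
    """Return, per call, True if attention is recomputed and False if the cache is reused."""
    has_cache = False
    iteration = 0
    schedule: List[bool] = []
    for timestep in timesteps:
        # Both bounds are exclusive, mirroring the chained comparison in the hook.
        within_range = timestep_skip_range[0] < timestep < timestep_skip_range[1]
        compute = (
            not has_cache
            or iteration == 0
            or not within_range
            or iteration % block_skip_range == 0
        )
        schedule.append(compute)
        has_cache = True  # the hook stores the output after every call
        iteration += 1
    return schedule


print(pab_compute_schedule([900, 700, 500, 300, 50], (100, 800), 2))
```

With `block_skip_range=2`, every other call inside the timestep range reuses the cached output, and calls outside the range always recompute.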
+83 -1
@@ -519,7 +519,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
if not all(k.startswith("lora_te") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -558,6 +558,88 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict
def _convert_mixture_state_dict_to_diffusers(state_dict):
new_state_dict = {}
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
down_key = f"{original_key}.lora_down.weight"
down_weight = state_dict.pop(down_key)
lora_rank = down_weight.shape[0]
up_weight_key = f"{original_key}.lora_up.weight"
up_weight = state_dict.pop(up_weight_key)
alpha_key = f"{original_key}.alpha"
alpha = state_dict.pop(alpha_key)
# scale weight by alpha and dim
scale = alpha / lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
new_state_dict[diffusers_down_key] = down_weight
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
all_unique_keys = {
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
}
all_unique_keys = sorted(all_unique_keys)
assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
for k in all_unique_keys:
if k.startswith("lora_transformer_single_transformer_blocks_"):
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"single_transformer_blocks.{i}"
elif k.startswith("lora_transformer_transformer_blocks_"):
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"transformer_blocks.{i}"
else:
raise NotImplementedError
if "attn_" in k:
if "_to_out_0" in k:
diffusers_key += ".attn.to_out.0"
elif "_to_add_out" in k:
diffusers_key += ".attn.to_add_out"
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"
if diffusers_key == f"transformer_blocks.{i}":
print(k, diffusers_key)
_convert(k, diffusers_key, state_dict, new_state_dict)
if len(state_dict) > 0:
raise ValueError(
f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
)
new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
return new_state_dict
# This is weird.
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict
# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
return _convert_sd_scripts_to_ai_toolkit(state_dict)
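The alpha handling inside `_convert` can be followed with plain numbers. The sketch below isolates only the scale-splitting loop; `split_lora_scale` is a hypothetical name, and the inputs are ordinary floats rather than tensors.

```python
from typing import Tuple


def split_lora_scale(alpha: float, lora_rank: int) -> Tuple[float, float]:
    """Split alpha / rank into a factor for the down weight and one for the up weight."""
    scale = alpha / lora_rank
    scale_down = scale
    scale_up = 1.0
    # Move factors of two from the up side to the down side while the down
    # factor is still less than half of the up factor.
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up


scale_down, scale_up = split_lora_scale(alpha=1.0, lora_rank=16)
print(scale_down, scale_up, scale_down * scale_up)
```

Each loop step doubles one factor and halves the other, so the product `scale_down * scale_up` stays equal to `alpha / lora_rank`. The effective LoRA update is unchanged; only its distribution between the two matrices differs.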
@@ -177,5 +177,3 @@ class FluxTransformer2DLoadersMixin:
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
+2
@@ -39,6 +39,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
@@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
VQModel,
)
from .cache_utils import CacheMixin
from .controlnets import (
ControlNetModel,
ControlNetUnionModel,
+11 -5
@@ -405,11 +405,12 @@ class Attention(nn.Module):
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
dtype = None
if attention_op is not None:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xformers.ops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
@@ -3154,6 +3155,11 @@ class AttnProcessorNPU:
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
if attention_mask.dtype == torch.bool:
attention_mask = torch.logical_not(attention_mask.bool())
else:
attention_mask = attention_mask.bool()
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
_skip_layerwise_casting_patterns = ["decoder"]
@register_to_config
def __init__(
self,
@@ -138,10 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -507,19 +507,12 @@ class AllegroEncoder3D(nn.Module):
sample = sample + residual
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Down blocks
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
sample = self._gradient_checkpointing_func(down_block, sample)
# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = self._gradient_checkpointing_func(self.mid_block, sample)
else:
# Down blocks
for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ class AllegroDecoder3D(nn.Module):
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = self._gradient_checkpointing_func(self.mid_block, sample)
# Up blocks
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
sample = self._gradient_checkpointing_func(up_block, sample)
else:
# Mid block
@@ -809,10 +795,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_size - self.tile_overlap_w,
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -421,15 +421,8 @@ class CogVideoXDownBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
zq,
@@ -523,15 +516,8 @@ class CogVideoXMidBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +623,8 @@ class CogVideoXUpBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
zq,
@@ -774,18 +753,11 @@ class CogVideoXEncoder3D(nn.Module):
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
down_block,
hidden_states,
temb,
None,
@@ -793,8 +765,8 @@ class CogVideoXEncoder3D(nn.Module):
)
# 2. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
self.mid_block,
hidden_states,
temb,
None,
@@ -940,16 +912,9 @@ class CogVideoXDecoder3D(nn.Module):
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 1. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
self.mid_block,
hidden_states,
temb,
sample,
@@ -959,8 +924,8 @@ class CogVideoXDecoder3D(nn.Module):
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block,
hidden_states,
temb,
sample,
@@ -1122,10 +1087,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
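The hunks above replace a per-call `create_custom_forward` closure with a single `self._gradient_checkpointing_func(module, *args)` call. The pattern can be sketched without PyTorch as follows. Everything here is an assumption for illustration: `checkpoint` is a stand-in that only calls the function (the real `torch.utils.checkpoint.checkpoint` also recomputes activations during the backward pass), and how the library actually assigns `_gradient_checkpointing_func` is not shown in this part of the diff.

```python
import functools


def checkpoint(function, *args, use_reentrant=False):
    # Stand-in for torch.utils.checkpoint.checkpoint: it just calls the function.
    return function(*args)


class Block:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, hidden_states, temb):
        return hidden_states * self.scale + temb


class Encoder:
    def __init__(self):
        self.blocks = [Block(2), Block(3)]
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = None

    def enable_gradient_checkpointing(self):
        # One shared callable replaces the closure that was rebuilt on every forward.
        self._gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)
        self.gradient_checkpointing = True

    def forward(self, hidden_states, temb):
        for block in self.blocks:
            if self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb)
        return hidden_states


encoder = Encoder()
plain = encoder.forward(1, 1)
encoder.enable_gradient_checkpointing()
checkpointed = encoder.forward(1, 1)
print(plain, checkpointed)
```

The old wrapper only forwarded `*inputs` to the module, so passing the module itself as the callable yields the same result with less code at every call site.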
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -21,7 +21,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..attention_processor import Attention
@@ -252,21 +252,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
@@ -278,9 +264,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
hidden_states = attn(hidden_states, attention_mask=attention_mask)
hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
hidden_states = self.resnets[0](hidden_states)
@@ -350,22 +334,8 @@ class HunyuanVideoDownBlock3D(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for resnet in self.resnets:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
@@ -426,22 +396,8 @@ class HunyuanVideoUpBlock3D(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for resnet in self.resnets:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
@@ -545,26 +501,10 @@ class HunyuanVideoEncoder3D(nn.Module):
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
@@ -667,26 +607,10 @@ class HunyuanVideoDecoder3D(nn.Module):
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), hidden_states, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
@@ -786,7 +710,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
self.use_tiling = False
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
# at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
self.use_framewise_encoding = True
self.use_framewise_decoding = True
@@ -800,10 +724,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = 192
self.tile_sample_stride_num_frames = 12
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -868,7 +788,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
@@ -338,16 +338,7 @@ class LTXVideoDownBlock3D(nn.Module):
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
@@ -438,16 +429,7 @@ class LTXVideoMidBlock3d(nn.Module):
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
@@ -573,16 +555,7 @@ class LTXVideoUpBlock3d(nn.Module):
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
@@ -697,17 +670,10 @@ class LTXVideoEncoder3d(nn.Module):
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
@@ -838,19 +804,10 @@ class LTXVideoDecoder3d(nn.Module):
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
else:
hidden_states = self.mid_block(hidden_states, temb)
@@ -1017,10 +974,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -207,15 +207,8 @@ class MochiDownBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
)
@@ -312,15 +305,8 @@ class MochiMidBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -393,15 +379,8 @@ class MochiUpBlock3D(nn.Module):
conv_cache_key = f"resnet_{i}"
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
)
@@ -531,21 +510,14 @@ class MochiEncoder3D(nn.Module):
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
)
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -648,21 +620,14 @@ class MochiDecoder3D(nn.Module):
# 1. Mid
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
)
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -819,10 +784,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (MochiEncoder3D, MochiDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -18,7 +18,6 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
@@ -97,47 +96,21 @@ class TemporalDecoder(nn.Module):
upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(
self.mid_block,
sample,
image_only_indicator,
)
sample = sample.to(upscale_dtype)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
# up
for up_block in self.up_blocks:
sample = self._gradient_checkpointing_func(
up_block,
sample,
image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
)
else:
# middle
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
@@ -229,10 +202,6 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -154,10 +154,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
self.register_to_config(block_out_channels=decoder_block_out_channels)
self.register_to_config(force_upcast=False)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
+31 -141
@@ -18,7 +18,7 @@ import numpy as np
import torch
import torch.nn as nn
from ...utils import BaseOutput, is_torch_version
from ...utils import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
@@ -156,28 +156,11 @@ class Encoder(nn.Module):
sample = self.conv_in(sample)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
if is_torch_version(">=", "1.11.0"):
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, use_reentrant=False
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, use_reentrant=False
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
for down_block in self.down_blocks:
sample = self._gradient_checkpointing_func(down_block, sample)
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample)
else:
# down
@@ -305,41 +288,13 @@ class Decoder(nn.Module):
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
# up
for up_block in self.up_blocks:
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -558,72 +513,28 @@ class MaskConditionDecoder(nn.Module):
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
# condition encoder
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = self._gradient_checkpointing_func(
self.condition_encoder,
masked_image,
mask,
)
sample = sample.to(upscale_dtype)
# condition encoder
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
# condition encoder
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -890,17 +801,7 @@ class EncoderTiny(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `EncoderTiny` class."""
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = self._gradient_checkpointing_func(self.layers, x)
else:
# scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -976,18 +877,7 @@ class DecoderTiny(nn.Module):
x = torch.tanh(x / 3) * 3
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = self._gradient_checkpointing_func(self.layers, x)
else:
x = self.layers(x)
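Review note: the hunks above all make the same change. Each `forward` used to build its own `create_custom_forward` closure and branch on `is_torch_version(">=", "1.11.0")`; now it calls one stored callable, `self._gradient_checkpointing_func`. A minimal sketch of that dispatch pattern follows. It is an illustration under stated assumptions, not the library's implementation: `fake_checkpoint` is a stand-in for `torch.utils.checkpoint.checkpoint`, and `Block`, `Encoder`, and the `checkpoint_fn` parameter are hypothetical names. The real code also checks `torch.is_grad_enabled()`, which is omitted here.

```python
from functools import partial


def fake_checkpoint(function, *args, use_reentrant=True, **kwargs):
    """Stand-in for torch.utils.checkpoint.checkpoint: simply calls the function."""
    return function(*args, **kwargs)


class Block:
    """Hypothetical sub-block: doubles its input and adds an optional bias."""

    def __call__(self, x, bias=0):
        return 2 * x + bias


class Encoder:
    """Hypothetical module that dispatches through one stored checkpoint callable."""

    def __init__(self):
        self.blocks = [Block(), Block()]
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = None

    def enable_gradient_checkpointing(self, checkpoint_fn=None):
        # Bind the checkpoint options once, instead of at every call site.
        if checkpoint_fn is None:
            checkpoint_fn = partial(fake_checkpoint, use_reentrant=False)
        self._gradient_checkpointing_func = checkpoint_fn
        self.gradient_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing:
                # One line per call site; no closure factory, no version branch.
                x = self._gradient_checkpointing_func(block, x)
            else:
                x = block(x)
        return x


encoder = Encoder()
plain = encoder.forward(3)
encoder.enable_gradient_checkpointing()
checkpointed = encoder.forward(3)
```

Both paths should produce the same value, since checkpointing only changes how the computation is scheduled and not what it computes.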
@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
_skip_layerwise_casting_patterns = ["quantize"]
@register_to_config
def __init__(
self,
+89
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.logging import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
class CacheMixin:
r"""
A class for enabling/disabling caching techniques on diffusion models.
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
"""
_cache_config = None
@property
def is_cache_enabled(self) -> bool:
return self._cache_config is not None
def enable_cache(self, config) -> None:
r"""
Enable caching techniques on the model.
Args:
config (`Union[PyramidAttentionBroadcastConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
```
"""
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
self._cache_config = None
def _reset_stateful_cache(self, recurse: bool = True) -> None:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
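Review note: the new `CacheMixin` keeps a single `_cache_config` attribute as the source of truth for whether caching is on. `enable_cache` validates the config type before storing it, and `disable_cache` is a no-op (with a warning) when nothing was enabled. The sketch below mirrors only that state handling. `DemoCacheConfig`, `DemoCacheMixin`, `DemoModel`, and the `_hooks` list are hypothetical stand-ins for `PyramidAttentionBroadcastConfig`, the mixin, a model class, and the hook registry; no real hooks are applied.

```python
class DemoCacheConfig:
    """Hypothetical config; stands in for PyramidAttentionBroadcastConfig."""

    def __init__(self, skip_range=2):
        self.skip_range = skip_range


class DemoCacheMixin:
    # Class-level default, so instances start with caching disabled.
    _cache_config = None

    @property
    def is_cache_enabled(self):
        return self._cache_config is not None

    def enable_cache(self, config):
        # Reject unknown config types before touching any state.
        if not isinstance(config, DemoCacheConfig):
            raise ValueError(f"Cache config {type(config)} is not supported.")
        self._hooks = ["demo_cache"]  # stand-in for registering hooks
        self._cache_config = config

    def disable_cache(self):
        if self._cache_config is None:
            return  # nothing to disable; the real mixin logs a warning here
        self._hooks = []  # stand-in for removing the registered hooks
        self._cache_config = None


class DemoModel(DemoCacheMixin):
    pass


model = DemoModel()
before = model.is_cache_enabled
model.enable_cache(DemoCacheConfig(skip_range=2))
during = model.is_cache_enabled
model.disable_cache()
after = model.is_cache_enabled
```

Because the config is stored only after validation succeeds, a rejected config leaves the model in its previous state.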
@@ -31,8 +31,6 @@ from ..attention_processor import (
from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
@@ -659,10 +657,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention_processor import AttentionProcessor
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -178,10 +178,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
@@ -330,24 +326,12 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -364,23 +348,11 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import JointTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -262,10 +262,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
# we should have handled this in conversion script
def _get_pos_embed_from_transformer(self, transformer):
@@ -382,30 +378,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
)
hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
else:
if self.context_embedder is not None:
@@ -590,10 +590,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
@@ -29,8 +29,6 @@ from ..attention_processor import (
from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
@@ -599,10 +597,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, is_torch_version, logging
from ...utils import BaseOutput, logging
from ...utils.torch_utils import apply_freeu
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -864,10 +864,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for u in self.up_blocks:
u.freeze_base_params()
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -1450,15 +1446,6 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
base_blocks = list(zip(self.base_resnets, self.base_attentions))
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
):
@@ -1468,13 +1455,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
# apply base subblock
if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_base = torch.utils.checkpoint.checkpoint(
create_custom_forward(b_res),
h_base,
temb,
**ckpt_kwargs,
)
h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
else:
h_base = b_res(h_base, temb)
@@ -1491,13 +1472,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
# apply ctrl subblock
if apply_control:
if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_ctrl = torch.utils.checkpoint.checkpoint(
create_custom_forward(c_res),
h_ctrl,
temb,
**ckpt_kwargs,
)
h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
else:
h_ctrl = c_res(h_ctrl, temb)
if c_attn is not None:
@@ -1862,15 +1837,6 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
and getattr(self, "b2", None)
)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
@@ -1900,13 +1866,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
+1 -1
@@ -1787,7 +1787,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def forward(self, timestep, caption_feat, caption_mask):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
# caption condition embedding:
caption_mask_float = caption_mask.float().unsqueeze(-1)
+148 -7
@@ -21,17 +21,20 @@ import json
import os
import re
from collections import OrderedDict
from functools import partial, wraps
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors
import torch
import torch.utils.checkpoint
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
from typing_extensions import Self
from .. import __version__
from ..hooks import apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -48,6 +51,7 @@ from ..utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_peft_available,
is_torch_version,
logging,
)
@@ -102,6 +106,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
# 1. Check if we have attached any dtype modifying hooks (e.g. layerwise casting)
if isinstance(parameter, nn.Module):
for name, submodule in parameter.named_modules():
if not hasattr(submodule, "_diffusers_hook"):
continue
registry = submodule._diffusers_hook
hook = registry.get_hook("layerwise_casting")
if hook is not None:
return hook.compute_dtype
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
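Review note: the lookup order in this hunk (a dtype-modifying hook first, parameters second) can be sketched without PyTorch. This is only an illustration under stated assumptions, not the library code: `FakeHook`, `FakeRegistry`, `FakeModule`, and `resolve_dtype` are hypothetical stand-ins, and dtypes are represented as plain strings.

```python
from typing import Dict, Iterator, List, Optional


class FakeHook:
    """Hypothetical stand-in for a layerwise-casting hook."""

    def __init__(self, compute_dtype: str) -> None:
        self.compute_dtype = compute_dtype


class FakeRegistry:
    """Hypothetical stand-in for the `_diffusers_hook` registry."""

    def __init__(self) -> None:
        self._hooks: Dict[str, FakeHook] = {}

    def register(self, name: str, hook: FakeHook) -> None:
        self._hooks[name] = hook

    def get_hook(self, name: str) -> Optional[FakeHook]:
        return self._hooks.get(name)


class FakeModule:
    """Hypothetical module: a list of parameter dtypes plus child modules."""

    def __init__(self, param_dtypes: List[str], children: Optional[List["FakeModule"]] = None) -> None:
        self.param_dtypes = param_dtypes
        self.children = children or []

    def modules(self) -> Iterator["FakeModule"]:
        yield self
        for child in self.children:
            yield from child.modules()


def resolve_dtype(model: FakeModule) -> Optional[str]:
    # 1. A dtype-modifying hook wins: weights may be stored in fp8,
    #    but callers need the dtype used for computation.
    for module in model.modules():
        registry = getattr(module, "_diffusers_hook", None)
        if registry is None:
            continue
        hook = registry.get_hook("layerwise_casting")
        if hook is not None:
            return hook.compute_dtype
    # 2. Otherwise return the first floating dtype, else the last dtype seen.
    last = None
    for module in model.modules():
        for dtype in module.param_dtypes:
            last = dtype
            if dtype.startswith(("float", "bfloat")):
                return dtype
    return last
```

Without the first step, a model whose weights are stored in a float8 dtype would report that storage dtype to callers, which is presumably what this hunk is meant to avoid.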
@@ -150,10 +165,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
def __init__(self):
super().__init__()
self._gradient_checkpointing_func = None
def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,14 +197,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self) -> None:
def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
"""
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
Args:
gradient_checkpointing_func (`Callable`, *optional*):
The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
is used (`torch.utils.checkpoint.checkpoint`).
"""
if not self._supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
raise ValueError(
f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
f"`_supports_gradient_checkpointing` to `True` in the class definition."
)
if gradient_checkpointing_func is None:
def _gradient_checkpointing_func(module, *args):
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
return torch.utils.checkpoint.checkpoint(
module.__call__,
*args,
**ckpt_kwargs,
)
gradient_checkpointing_func = _gradient_checkpointing_func
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
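Review note: the calling convention introduced here is `func(module, *args)`, so a caller-supplied function receives the block itself followed by its positional inputs. The sketch below shows that contract in plain Python. `Block`, `TinyModel`, and `logging_checkpoint` are hypothetical names, and the default function simply calls the module, whereas the diff wraps `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` on torch >= 1.11.0.

```python
from typing import Any, Callable, List, Optional


class Block:
    """Hypothetical checkpointable block that doubles its input."""

    def __init__(self) -> None:
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func: Optional[Callable[..., Any]] = None

    def __call__(self, x: int) -> int:
        return 2 * x


class TinyModel:
    """Hypothetical model mirroring the enable/dispatch pattern in the diff."""

    def __init__(self, num_blocks: int = 3) -> None:
        self.blocks: List[Block] = [Block() for _ in range(num_blocks)]
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func: Optional[Callable[..., Any]] = None

    def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable[..., Any]] = None) -> None:
        if gradient_checkpointing_func is None:
            # Stand-in default: call the module directly (no real checkpointing here).
            def _default_func(module: Any, *args: Any) -> Any:
                return module(*args)

            gradient_checkpointing_func = _default_func

        # Every participating module gets the same function and the same flag.
        for module in [self, *self.blocks]:
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = True

    def forward(self, x: int) -> int:
        for block in self.blocks:
            if self.gradient_checkpointing:
                # New-style call site: pass the block, no closure wrapper needed.
                x = self._gradient_checkpointing_func(block, x)
            else:
                x = block(x)
        return x


calls: List[str] = []


def logging_checkpoint(module: Any, *args: Any) -> Any:
    # A user-supplied function: record which module was routed through it.
    calls.append(type(module).__name__)
    return module(*args)
```

Because the function is stored once on each module, the per-call-site `create_custom_forward` closures and `ckpt_kwargs` boilerplate removed throughout this diff are no longer needed.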
def disable_gradient_checkpointing(self) -> None:
"""
@@ -194,7 +233,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
*checkpoint activations* in other frameworks).
"""
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
self._set_gradient_checkpointing(enable=False)
def set_use_npu_flash_attention(self, valid: bool) -> None:
r"""
@@ -314,6 +353,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
def enable_layerwise_casting(
self,
storage_dtype: torch.dtype = torch.float8_e4m3fn,
compute_dtype: Optional[torch.dtype] = None,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
Activates layerwise casting for the current model.
Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
memory footprint from model weights, but may lead to some quality degradation in the outputs. The degradation is
usually negligible and stems mostly from weight casting in normalization and modulation layers.
By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
embedding, positional embedding and normalization layers. This is because these layers are most likely
precision-critical for quality. If you wish to change this behavior, you can set the
`_skip_layerwise_casting_patterns` attribute to `None`, or call
[`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
Example:
Using [`~models.ModelMixin.enable_layerwise_casting`]:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> # Enable layerwise casting via the model, which ignores certain modules by default
>>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```
Args:
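storage_dtype (`torch.dtype`, defaults to `torch.float8_e4m3fn`):
The dtype to which the model should be cast for storage.
compute_dtype (`torch.dtype`, *optional*):
The dtype to which the model weights should be cast during the forward pass. If `None`, the current dtype
of the model (`self.dtype`) is used.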
skip_modules_pattern (`Tuple[str, ...]`, *optional*):
A tuple of patterns to match the names of the modules to skip during the layerwise casting process. If
set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
layers.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
A tuple of module classes to skip during the layerwise casting process.
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
user_provided_patterns = True
if skip_modules_pattern is None:
from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
user_provided_patterns = False
if self._keep_in_fp32_modules is not None:
skip_modules_pattern += tuple(self._keep_in_fp32_modules)
if self._skip_layerwise_casting_patterns is not None:
skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
skip_modules_pattern = tuple(set(skip_modules_pattern))
if is_peft_available() and not user_provided_patterns:
# By default, we want to skip all peft layers because they have a very low memory footprint.
# If users want to apply layerwise casting on peft layers as well, they can utilize the
# `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
# them with more flexibility and control.
from peft.tuners.loha.layer import LoHaLayer
from peft.tuners.lokr.layer import LoKrLayer
from peft.tuners.lora.layer import LoraLayer
for layer in (LoHaLayer, LoKrLayer, LoraLayer):
skip_modules_pattern += tuple(layer.adapter_layer_names)
if compute_dtype is None:
logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
compute_dtype = self.dtype
apply_layerwise_casting(
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -483,7 +606,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration.
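Review note: the effect of the `-> Self` annotation is that a type checker infers the subclass, not the base class, as the return type of the classmethod. A minimal sketch follows; `BaseModel` and `Transformer` are hypothetical names, and the import is guarded with `TYPE_CHECKING` only so that the sketch runs without `typing_extensions` installed (the diff imports it unconditionally).

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed only by type checkers in this sketch.
    from typing_extensions import Self


class BaseModel:
    def __init__(self, name: str) -> None:
        self.name = name

    @classmethod
    def from_pretrained(cls, name: str) -> Self:
        # `cls` is the class the method was called on, so subclasses get their own type back.
        return cls(name)


class Transformer(BaseModel):
    def num_layers(self) -> int:
        return 12


# A type checker now treats `model` as `Transformer`, so `num_layers()` resolves.
model = Transformer.from_pretrained("demo")
```

Runtime behaviour is unchanged by the annotation; only static inference and editor completion improve.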
@@ -1354,6 +1477,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mem = mem + mem_bufs
return mem
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False
for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
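Review note: the new `_set_gradient_checkpointing` walks `named_modules()` once and raises if nothing opted in, which is why the per-model overrides are deleted further down in this diff. A torch-free sketch of that walk is shown below; `Node` and `set_gradient_checkpointing` are hypothetical names and the tree is built by hand.

```python
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple


class Node:
    """Hypothetical module tree node; only some nodes opt in to checkpointing."""

    def __init__(self, children: Optional[Dict[str, "Node"]] = None, supports_checkpointing: bool = False) -> None:
        self.children: Dict[str, "Node"] = dict(children or {})
        if supports_checkpointing:
            # Opting in means defining the boolean attribute, as in the diff.
            self.gradient_checkpointing = False

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "Node"]]:
        yield prefix, self
        for name, child in self.children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


def set_gradient_checkpointing(root: Node, enable: bool, func: Optional[Callable[..., Any]]) -> List[str]:
    touched: List[str] = []
    for name, module in root.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = func
            module.gradient_checkpointing = enable
            touched.append(name)
    if not touched:
        # Mirrors the diff: fail loudly instead of silently doing nothing.
        raise ValueError("no submodule defines `gradient_checkpointing`")
    return touched
```

Returning the list of touched names is an addition for illustration only; the method in the diff returns `None` and logs each name at debug level.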
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Union
from typing import Dict, Union
import torch
import torch.nn as nn
@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"""
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
@register_to_config
@@ -443,10 +444,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -468,23 +465,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
@@ -499,22 +484,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
combined_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
combined_hidden_states = self._gradient_checkpointing_func(
block,
combined_hidden_states,
temb,
**ckpt_kwargs,
)
else:
@@ -20,10 +20,11 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -156,7 +157,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -212,6 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@@ -329,9 +331,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -487,22 +486,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
attention_kwargs,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
@@ -595,9 +595,6 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def _init_face_inputs(self):
self.local_facial_extractor = LocalFacialExtractor(
id_dim=self.LFE_id_dim,
@@ -745,22 +742,13 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 3. Transformer blocks
ca_idx = 0
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils import logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
A small constant added to the denominator in normalization layers to prevent division by zero.
"""
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
@register_to_config
@@ -143,10 +144,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -185,19 +182,8 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
None,
None,
@@ -205,7 +191,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
@register_to_config
def __init__(
self,
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
@@ -19,13 +20,14 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""
@@ -65,6 +67,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
The number of frames in the video-like data.
"""
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
self,
@@ -162,9 +166,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -239,7 +240,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
hidden_states = self._gradient_checkpointing_func(
spatial_block,
hidden_states,
None, # attention_mask
@@ -248,7 +249,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
timestep_spatial,
None, # cross_attention_kwargs
None, # class_labels
use_reentrant=False,
)
else:
hidden_states = spatial_block(
@@ -272,7 +272,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states + self.temp_pos_embed
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
hidden_states = self._gradient_checkpointing_func(
temp_block,
hidden_states,
None, # attention_mask
@@ -281,7 +281,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
timestep_temp,
None, # cross_attention_kwargs
None, # class_labels
use_reentrant=False,
)
else:
hidden_states = temp_block(
@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
overall scale of the model's operations.
"""
_skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
@register_to_config
def __init__(
self,
@@ -17,7 +17,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils import logging
from ..attention import BasicTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
@register_to_config
def __init__(
@@ -183,10 +184,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -387,19 +384,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -407,7 +393,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -19,7 +19,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
AttentionProcessor,
@@ -236,6 +236,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
@register_to_config
def __init__(
@@ -307,10 +308,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -437,21 +434,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 2. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for block in self.transformer_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -459,7 +444,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
timestep,
post_patch_height,
post_patch_width,
**ckpt_kwargs,
)
else:
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing import Dict, Optional, Union
import numpy as np
import torch
@@ -29,7 +29,7 @@ from ...models.attention_processor import (
)
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import is_torch_version, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
@register_to_config
def __init__(
@@ -345,10 +346,6 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
self.set_attn_processor(StableAudioAttnProcessor2_0())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -415,25 +412,13 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
cross_attention_hidden_states,
encoder_attention_mask,
rotary_embedding,
**ckpt_kwargs,
)
else:
@@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch import nn
from ...configuration_utils import LegacyConfigMixin, register_to_config
from ...utils import deprecate, is_torch_version, logging
from ...utils import deprecate, logging
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]
_skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
@register_to_config
def __init__(
@@ -320,10 +321,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
in_features=self.caption_channels, hidden_size=self.inner_dim
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -416,19 +413,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -436,7 +422,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -13,17 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
return hidden_states
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""
@@ -222,6 +223,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
@register_to_config
def __init__(
@@ -302,9 +304,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -374,23 +373,14 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
for i, block in enumerate(self.transformer_blocks):
# TODO(aryan): Implement gradient checkpointing
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep,
attention_mask,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Union
from typing import Dict, Union
import torch
import torch.nn as nn
@@ -27,7 +27,7 @@ from ...models.attention_processor import (
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ...utils import logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
@@ -288,10 +289,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -343,20 +340,11 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -32,9 +32,10 @@ from ...models.attention_processor import (
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -227,7 +228,7 @@ class FluxTransformerBlock(nn.Module):
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model introduced in Flux.
@@ -262,6 +263,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
@@ -421,10 +423,6 @@ class FluxTransformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -519,24 +517,12 @@ class FluxTransformer2DModel(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -563,23 +549,11 @@ class FluxTransformer2DModel(
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -22,9 +22,10 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
@@ -502,7 +503,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -542,6 +543,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
@@ -670,10 +672,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -732,38 +730,24 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
@@ -360,10 +361,6 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -416,25 +413,13 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -21,10 +21,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
@@ -336,6 +337,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
_supports_gradient_checkpointing = True
_no_split_modules = ["MochiTransformerBlock"]
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
@register_to_config
def __init__(
@@ -402,10 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -458,22 +456,13 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -28,7 +28,7 @@ from ...models.attention_processor import (
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -127,6 +127,7 @@ class SD3Transformer2DModel(
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
@@ -328,10 +329,6 @@ class SD3Transformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -403,24 +400,12 @@ class SD3Transformer2DModel(
is_skip = True if skip_layers is not None and index_block in skip_layers else False
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
joint_attention_kwargs,
**ckpt_kwargs,
)
elif not is_skip:
encoder_hidden_states, hidden_states = block(
@@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
The maximum length of the sequence over which to apply positional embeddings.
"""
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
self,
@@ -341,19 +343,11 @@ class TransformerSpatioTemporalModel(nn.Module):
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
use_reentrant=False,
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, None, encoder_hidden_states, None
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
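
Note on the pattern in the hunks above: each call site drops its local `create_custom_forward` closure and `ckpt_kwargs` lookup in favour of a single call to `self._gradient_checkpointing_func(block, *args)`. The diff does not show where that attribute is defined, so the following is only a minimal sketch of the idea, not the library's implementation. `fake_checkpoint` stands in for `torch.utils.checkpoint.checkpoint`, and `CheckpointingMixin`, `Block`, and `TinyModel` are hypothetical names used for illustration.

```python
import functools


def fake_checkpoint(fn, *args, use_reentrant=True):
    """Stand-in for torch.utils.checkpoint.checkpoint (assumed to share this call shape).

    It only calls fn and records the flag; a real checkpoint would also
    recompute activations during the backward pass.
    """
    fake_checkpoint.calls.append(use_reentrant)
    return fn(*args)


fake_checkpoint.calls = []


class CheckpointingMixin:
    """Hypothetical mixin: binds the checkpoint options once, not at every call site."""

    gradient_checkpointing = False
    _gradient_checkpointing_func = None

    def enable_gradient_checkpointing(self, checkpoint_fn=fake_checkpoint):
        # Bind use_reentrant=False a single time; this replaces the repeated ckpt_kwargs dict.
        self._gradient_checkpointing_func = functools.partial(checkpoint_fn, use_reentrant=False)
        self.gradient_checkpointing = True


class Block:
    """Toy block: a callable taking positional inputs, like the modules in the diff."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, hidden_states, temb):
        return hidden_states * self.scale + temb


class TinyModel(CheckpointingMixin):
    def __init__(self):
        self.blocks = [Block(2), Block(3)]

    def forward(self, hidden_states, temb):
        for block in self.blocks:
            if self.gradient_checkpointing:
                # The block is passed directly; no create_custom_forward wrapper is needed.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb)
        return hidden_states


model = TinyModel()
plain = model.forward(1, 1)
model.enable_gradient_checkpointing()
checkpointed = model.forward(1, 1)
```

Both paths return the same value in this sketch; the only difference is that the checkpointed path routes each block through the bound function, which is why the per-model `_set_gradient_checkpointing` overrides and `is_torch_version` imports can be removed from the individual files.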
+3 -1
@@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
Experimental feature for using a UNet without upsampling.
"""
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
self,
@@ -223,7 +225,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
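
Note on `_skip_layerwise_casting_patterns`: the diff adds this class attribute to many models but does not show the code that consumes it. One plausible reading, stated here as an assumption, is that each entry is matched against submodule names and matching modules are left out of layerwise dtype casting. The sketch below illustrates that reading only; `should_skip_casting` and the module names are hypothetical.

```python
import re


def should_skip_casting(module_name, skip_patterns):
    """Return True if any pattern matches somewhere in the module name (assumed semantics)."""
    return any(re.search(pattern, module_name) for pattern in skip_patterns)


# Hypothetical submodule names, loosely modelled on a UNet layout.
module_names = [
    "time_mlp.linear_1",
    "down_blocks.0.resnets.0.norm1",
    "down_blocks.0.resnets.0.conv1",
    "out_block.final_norm",
]
skip_patterns = ["norm"]  # the value the diff assigns on UNet1DModel

skipped = [name for name in module_names if should_skip_casting(name, skip_patterns)]
```

Under this assumption, a short pattern such as `"norm"` covers every normalization layer regardless of nesting depth, which would explain why the lists in the diff stay small.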
+1 -4
@@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
@@ -247,10 +248,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
+19 -243
@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...utils import deprecate, is_torch_version, logging
from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -737,25 +737,9 @@ class UNetMidBlock2D(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
@@ -883,17 +867,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -902,12 +875,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1156,23 +1124,7 @@ class AttnDownBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
else:
@@ -1304,23 +1256,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1418,21 +1354,7 @@ class DownBlock2D(nn.Module):
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1906,21 +1828,7 @@ class ResnetDownsampleBlock2D(nn.Module):
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2058,17 +1966,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2153,21 +2051,7 @@ class KDownBlock2D(nn.Module):
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2262,22 +2146,10 @@ class KCrossAttnDownBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
@@ -2423,23 +2295,7 @@ class AttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(hidden_states)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2588,23 +2444,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2721,21 +2561,7 @@ class UpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3251,21 +3077,7 @@ class ResnetUpsampleBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3409,17 +3221,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -3512,21 +3314,7 @@ class KUpBlock2D(nn.Module):
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3640,22 +3428,10 @@ class KCrossAttnUpBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
@@ -166,6 +166,7 @@ class UNet2DConditionModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
@@ -833,10 +834,6 @@ class UNet2DConditionModel(
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+14 -127
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...utils import deprecate, is_torch_version, logging
from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..resnet import (
@@ -1078,31 +1078,14 @@ class UNetMidBlockSpatioTemporal(nn.Module):
)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
hidden_states = attn(
hidden_states,
@@ -1110,11 +1093,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
return hidden_states
@@ -1169,34 +1148,9 @@ class DownBlockSpatioTemporal(nn.Module):
output_states = ()
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
output_states = output_states + (hidden_states,)
@@ -1281,25 +1235,8 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
blocks = list(zip(self.resnets, self.attentions))
for resnet, attn in blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
hidden_states = attn(
hidden_states,
@@ -1308,11 +1245,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1385,34 +1318,9 @@ class UpBlockSpatioTemporal(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1495,25 +1403,8 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1521,11 +1412,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -37,11 +37,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
@@ -97,6 +93,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
"""
_supports_gradient_checkpointing = False
_skip_layerwise_casting_patterns = ["norm", "time_embedding"]
@register_to_config
def __init__(
@@ -471,10 +468,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -35,11 +35,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
@@ -436,11 +432,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -205,10 +205,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
"""
self.set_attn_processor(AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+12 -110
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
@@ -324,25 +324,7 @@ class DownBlockMotion(nn.Module):
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -514,23 +496,7 @@ class CrossAttnDownBlockMotion(nn.Module):
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -543,10 +509,7 @@ class CrossAttnDownBlockMotion(nn.Module):
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
@@ -733,23 +696,7 @@ class CrossAttnUpBlockMotion(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -762,10 +709,7 @@ class CrossAttnUpBlockMotion(nn.Module):
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -896,24 +840,7 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -1080,34 +1007,12 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
)[0]
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
hidden_states = self._gradient_checkpointing_func(
motion_module, hidden_states, None, None, None, num_frames, None
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states
@@ -1301,6 +1206,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
@@ -1965,10 +1871,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -320,10 +320,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
@@ -387,9 +387,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, value=False):
self.gradient_checkpointing = value
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
torch.nn.init.xavier_uniform_(m.weight)
@@ -456,29 +453,18 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, SDCascadeResBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x = self._gradient_checkpointing_func(block, x)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
x = self._gradient_checkpointing_func(block)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
@@ -505,13 +491,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
@@ -523,19 +502,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, skip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, skip)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x = self._gradient_checkpointing_func(block, x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
-3
@@ -148,9 +148,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
pass
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
@@ -683,6 +683,10 @@ class AllegroPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -815,6 +819,7 @@ class AllegroPipeline(DiffusionPipeline):
negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -892,6 +897,7 @@ class AllegroPipeline(DiffusionPipeline):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -933,6 +939,8 @@ class AllegroPipeline(DiffusionPipeline):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging
from ...utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -673,11 +673,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
@@ -1114,23 +1109,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i in range(num_layers):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1141,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1150,7 +1129,6 @@ class CrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
@@ -1292,17 +1270,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for i in range(len(self.resnets[1:])):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1313,8 +1280,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1322,14 +1289,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i + 1]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
else:
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
@@ -1466,23 +1427,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1493,8 +1438,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1502,7 +1447,6 @@ class CrossAttnUpBlock2D(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
@@ -160,8 +160,10 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -174,19 +174,16 @@ class Blip2QFormerEncoder(nn.Module):
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self._gradient_checkpointing_func(
layer_module,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
else:
layer_outputs = layer_module(
@@ -494,6 +494,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@@ -627,6 +631,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -705,6 +710,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -763,6 +769,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
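
Review note: the `current_timestep` hunks above (AllegroPipeline, CogVideoXPipeline) share one pattern. The private attribute is reset to `None` at the start of the call, set to `t` on each denoising step after the `interrupt` check, and reset to `None` once the loop finishes. The sketch below reproduces only that pattern without any framework code. `ToyPipeline`, its `timesteps` argument and the `on_step` callback are hypothetical names chosen for illustration and are not part of diffusers.

```python
from typing import Callable, Iterable, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for a pipeline; not a diffusers class."""

    def __init__(self) -> None:
        self._current_timestep: Optional[int] = None
        self._interrupt = False

    @property
    def current_timestep(self) -> Optional[int]:
        # Read-only view of the step being processed, None outside the loop.
        return self._current_timestep

    @property
    def interrupt(self) -> bool:
        return self._interrupt

    def __call__(
        self,
        timesteps: Iterable[int],
        on_step: Optional[Callable[["ToyPipeline"], None]] = None,
    ) -> List[int]:
        # Reset per-call state before the loop, mirroring the diff.
        self._current_timestep = None
        self._interrupt = False

        visited: List[int] = []
        for t in timesteps:
            if self.interrupt:
                continue
            # Publish the current step so callbacks can inspect it.
            self._current_timestep = t
            visited.append(t)
            if on_step is not None:
                on_step(self)

        # Clear the attribute once the loop has finished.
        self._current_timestep = None
        return visited
```

A callback passed as `on_step` can read `pipe.current_timestep` during the loop; after the call returns, the property is `None` again.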

Some files were not shown because too many files have changed in this diff.