Compare commits

...

58 Commits

Author SHA1 Message Date
Sayak Paul 6ca6fbd614 Merge branch 'main' into add-caching-note 2024-06-24 22:27:18 +05:30
Sayak Paul 3e3d102f20 Update docs/source/en/tutorials/fast_diffusion.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-06-24 22:27:13 +05:30
Tolga Cangöz 468ae09ed8 Errata - Trim trailing white space in the whole repo (#8575)
* Trim all the trailing white space in the whole repo

* Remove unnecessary empty places

* make style && make quality

* Trim trailing white space

* trim

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 18:39:15 +05:30
Dong 3fca52022f 🎨 fix xl playground device (#8550)
* 🎨 fix xl playground device

* 🎨 run `make fix-copies`

* 🎨 run `make fix-copies`

* edit xl_controlnet_img2img file

* edit playground img2img test slow

* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 16:49:55 +05:30
Tolga Cangöz c375903db5 Errata - Fix typos & improve contributing page (#8572)
* Fix typos & improve contributing page

* `make style && make quality`

* fix typos

* Fix typo

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 14:13:03 +05:30
Vinh H. Pham b9d52fca1d [train_lcm_distill_lora_sdxl.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8446)
fix num_train_epochs

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 14:09:28 +05:30
drhead 2ada094bff Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading (#7685)
* Add support for _foreach operations and non-blocking to EMAModel

* default foreach to false

* add non-blocking EMA offloading to SD1.5 T2I example script

* fix whitespace

* move foreach to cli argument

* linting

* Update README.md re: EMA weight training

* correct args.foreach_ema

* add tests for foreach ema

* code quality

* add foreach to from_pretrained

* default foreach false

* fix linting

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: drhead <a@a.a>
2024-06-24 14:03:47 +05:30
sayakpaul 1b4c4d4614 formatting 2024-06-24 11:45:34 +05:30
sayakpaul 28ef949cf6 add note on caching in fast diffusion 2024-06-24 11:19:32 +05:30
Haofan Wang f1f542bdd4 Update pipeline_stable_diffusion_3_controlnet.py (#8660)
Co-authored-by: YiYi Xu <yixu310@gmail,com>
2024-06-23 15:27:59 +05:30
Sayak Paul a9c403c001 [LoRA] refactor lora conversion utility. (#8295)
* refactor lora conversion utility.

* remove error raises.

* add onetrainer support too.
2024-06-22 08:29:12 +05:30
Álvaro Somoza e7b9a0762b [SD3 LoRA] Fix list index out of range (#8584)
* fix

* add check

* key present is checked before

* test case draft

* aply suggestions

* changed testing repo, back to old class

* forgot docstring

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-21 21:17:34 +05:30
Sayak Paul 8eb17315c8 [LoRA] get rid of the legacy lora remnants and make our codebase lighter (#8623)
* get rid of the legacy lora remnants and make our codebase lighter

* fix depcrecated lora argument

* fix

* empty commit to trigger ci

* remove print

* empty
2024-06-21 16:36:05 +05:30
YiYi Xu c71c19c5e6 a few fix for shard checkpoints (#8656)
fix

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-21 12:50:58 +05:30
Steaunk adc31940a9 Fix Typo in StableDiffusion3 (#8642)
* fix typo in __call__ of pipeline_stable_diffusion_3.py

* fix typo in __call__ of pipeline_stable_diffusion_3_img2img.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-21 08:45:48 +05:30
satani99 963ee05d16 Update train_dreambooth_lora_sd3.py (#8600)
* Update train_dreambooth_lora_sd3.py

* Update train_dreambooth_lora_sd3.py

* Update train_dreambooth_sd3.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-20 17:42:24 +05:30
Sayak Paul 668e34c6e0 [LoRA SD3] add support for lora fusion in sd3 (#8616)
* add support for lora fusion in sd3

* add test to ensure fused lora and effective lora produce same outpouts
2024-06-20 14:25:51 +05:30
Sayak Paul 25d7bb3ea6 [Flax tests] reduce tolerance for a flax test (#8640)
reduce tolerance for a flax test
2024-06-20 00:48:08 +04:00
YiYi Xu 394b8fb996 fix from_single_file for checkpoints with t5 (#8631)
fix single file
2024-06-19 08:23:35 -10:00
Sayak Paul a1d55e14ba Change the default weighting_scheme in the SD3 scripts (#8639)
* change to logit_normal as the weighting scheme

* sensible default mote
2024-06-19 13:05:26 +01:00
王奇勋 e5564d45bf Support SD3 ControlNet and Multi-ControlNet. (#8566)
* sd3 controlnet



---------

Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
2024-06-18 14:59:22 -10:00
Nan 2921a20194 [SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5 (text_encoder_3=None) (#8558)
* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-18 12:41:18 -10:00
Carolinabanana 3376252d71 Fix gradient checkpointing issue for Stable Diffusion 3 (#8542)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-06-18 11:36:23 -10:00
Yongsen Mao 16170c69ae add sd1.5 compatibility to controlnet-xs and fix unused_parameters error during training (#8606)
* add sd1.5 compatibility to controlnet-xs

* set use_linear_projection by base_block

* refine code style
2024-06-18 11:35:34 -10:00
kkj15dk 4408047ac5 self.upsample = Upsample1D (#8580)
Making self.upsample actually be Upsample1D
2024-06-18 11:34:07 -10:00
Vasco Ramos 34fab8b511 [SD3 Docs] Corrected title about loading model with T5 "without" -> "with" (#8602)
[SD3 Docs] Corrected title about loading model with T5

Corrected the documentation title to "Loading the single file checkpoint with T5" Previously, it incorrectly stated "Loading the single file checkpoint without T5" which contradicted the code snippet showing how to load the SD3 checkpoint with the T5 model
2024-06-18 11:33:43 -10:00
Gæros 298ce67999 [LoRA] text encoder: read the ranks for all the attn modules (#8324)
* [LoRA] text encoder: read the ranks for all the attn modules

 * In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and  v_proj

 * Allow missing adapters (UNet already supports this)

* ruff format loaders.lora

* [LoRA] add tests for partial text encoders LoRAs

* [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic

* [LoRA] comment justifying test_simple_inference_with_partial_text_lora

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-18 21:10:50 +01:00
Andrew Hong d2e7a19fd5 Remove underlines between badges (#8484)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-18 10:40:12 -07:00
Sayak Paul cd3082008e [Core] Add shift_factor to SD3 tiny autoencoder (#8618)
* shift factor argument to tiny

* remove shift factor rejigging from the sd3 docs
2024-06-18 18:28:02 +01:00
Álvaro Somoza f3209b5b55 [SD3 Inference] T5 Token limit (#8506)
* max_sequence_length for the T5

* updated img2img

* apply suggestions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-06-18 06:46:38 -10:00
Marc Sun 96399c3ec6 Fix sharding when no device_map is passed (#8531)
* Fix sharding when no device_map is passed

* style

* add tests

* align

* add docstring

* format

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-18 05:47:23 -10:00
MaoXianXin 10d3220abe A backslash is missing from the run command (#8471) 2024-06-18 16:44:34 +01:00
Dhruv Nair f69511ecc6 [Single File Loading] Handle unexpected keys in CLIP models when accelerate isn't installed. (#8462)
* update

* update

* update

* update

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-18 16:39:30 +01:00
Álvaro Somoza d2b10b1f4f [SD3] TAESD3 docs (#8607)
* tased3 docs

* apply suggestion

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-18 15:56:38 +01:00
Sayak Paul 23a2cd3337 [LoRA] training fix the position of param casting when loading them (#8460)
fix the position of param casting when loading them
2024-06-18 14:57:34 +01:00
Sayak Paul 4edde134f6 [SD3 training] refactor the density and weighting utilities. (#8591)
refactor the density and weighting utilities.
2024-06-18 14:44:38 +01:00
Bagheera 074a7cc3c5 SD3: update default training timestep / loss weighting distribution to logit_normal (#8592)
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-18 14:15:19 +01:00
Álvaro Somoza 6bfd13f07a [SD3 Training] T5 token limit (#8564)
* initial commit

* default back to 77

* better text

* text correction

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-17 16:32:56 -04:00
AmosDinh eeb70033a6 Syntax error in readme example "pipe" -> "pipeline" (#8601)
Update controlnet.md

Syntax error pipe -> pipeline
2024-06-17 11:02:07 -07:00
Dhruv Nair c4a4750cb3 Temporarily pin Numpy in the CI (#8603)
temp pin numpy
2024-06-17 19:32:38 +05:30
YiYi Xu a6375d4101 Image processor latent (#8513)
* fix

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-16 22:34:55 -10:00
spacepxl 8e1b7a084a Fix the deletion of SD3 text encoders for Dreambooth/LoRA training if the text encoders are not being trained (#8536)
* Update train_dreambooth_sd3.py to fix TE garbage collection

* Update train_dreambooth_lora_sd3.py to fix TE garbage collection

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-16 20:52:33 +01:00
Rafie Walker 6946facf69 Implement SD3 loss weighting (#8528)
* Add lognorm and cosmap weighting

* Implement mode sampling

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* keep timestamp sampling fully on cpu

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-16 20:15:50 +01:00
Sayak Paul 130dd936bb pin accelerate to 0.31.0 (#8563)
* pin accelerate to 0.31.0

* update dep table

* empty
2024-06-16 08:37:00 -10:00
Jonathan Rahn a899e42fc7 add sentencepiece to requirements.txt for SD3 dreambooth (#8538)
* add `sentencepiece` requirement for SD3

add `sentencepiece` requirement

* Empty-Commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-14 22:48:36 +01:00
Sayak Paul f96e4a16ad pin transformers to the latest (#8522)
thanks!
2024-06-13 07:39:24 -10:00
Tolga Cangöz 9c6e9684a2 Refactor StableDiffusion3Img2ImgPipeline to remove redundant code (#8533) 2024-06-13 07:36:46 -10:00
Sayak Paul 2e4841ef1e post release 0.29.0 (#8492)
post release
2024-06-13 06:14:20 -10:00
Haofan Wang 8bea943714 Update requirements_sd3.txt (#8521) 2024-06-13 17:02:17 +01:00
YiYi Xu 614d0c64e9 remove the deprecated prepare_mask_and_masked_image function (#8512)
remove prepare mask fn

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-13 14:59:21 +01:00
Dhruv Nair b1a2c0d577 Expand Single File support in SD3 Pipeline (#8517)
* update

* update
2024-06-13 18:29:19 +05:30
Lucain 06ee907b73 Fix PATH_IN_REPO on new release in mirror_community_pipeline.yaml (#8519)
Fix PATH_IN_REPO in mirror workflow
2024-06-13 10:25:24 +02:00
ちくわぶ 896fb6d8d7 Fix duplicate variable assignments in SD3's JointAttnProcessor (#8516)
* Fix duplicate variable assignments.

* Fix duplicate variable assignments.
2024-06-12 21:52:35 -10:00
Beinsezii 7f51f286a5 Add Hunyuan AutoPipe mapping (#8505) 2024-06-12 16:11:55 -10:00
kkj15dk 829f6defa4 Fix spelling in scheduling_flow_match_euler_discrete.py (#8497)
Update scheduling_flow_match_euler_discrete.py

Spelling:
Foward -> Forward

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-06-12 12:37:47 -10:00
Beinsezii 24bdf4b215 Add SD3 AutoPipeline mappings (#8489) 2024-06-12 12:31:36 -10:00
Radamés Ajna 95e0c3757d Fix small typo (#8498) 2024-06-12 15:30:58 -07:00
Sayak Paul 6cf0be5d3d fix warning log for Transformer SD3 (#8496)
fix warning log
2024-06-12 12:25:18 -10:00
217 changed files with 3711 additions and 2218 deletions
@@ -54,7 +54,7 @@ jobs:
else else
# e.g. refs/tags/v0.28.1 -> v0.28.1 # e.g. refs/tags/v0.28.1 -> v0.28.1
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=${${{ github.ref }}#refs/tags/}" >> $GITHUB_ENV echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi fi
- name: Print env vars - name: Print env vars
run: | run: |
-1
View File
@@ -33,4 +33,3 @@ jobs:
run: | run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
pytest tests/others/test_dependencies.py pytest tests/others/test_dependencies.py
+3 -2
View File
@@ -245,7 +245,7 @@ The official training examples are maintained by the Diffusers' core maintainers
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models. This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author. If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
Both official training and research examples consist of a directory that contains one or more training scripts, a requirements.txt file, and a README.md file. In order for the user to make use of the Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
training examples, it is required to clone the repository: training examples, it is required to clone the repository:
```bash ```bash
@@ -255,7 +255,8 @@ git clone https://github.com/huggingface/diffusers
as well as to install all additional dependencies required for training: as well as to install all additional dependencies required for training:
```bash ```bash
pip install -r /examples/<your-example-folder>/requirements.txt cd diffusers
pip install -r examples/<your-example-folder>/requirements.txt
``` ```
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt). Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
+5 -15
View File
@@ -20,21 +20,11 @@ limitations under the License.
<br> <br>
<p> <p>
<p align="center"> <p align="center">
<a href="https://github.com/huggingface/diffusers/blob/main/LICENSE"> <a href="https://github.com/huggingface/diffusers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/datasets.svg?color=blue"></a>
<img alt="GitHub" src="https://img.shields.io/github/license/huggingface/datasets.svg?color=blue"> <a href="https://github.com/huggingface/diffusers/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg"></a>
</a> <a href="https://pepy.tech/project/diffusers"><img alt="GitHub release" src="https://static.pepy.tech/badge/diffusers/month"></a>
<a href="https://github.com/huggingface/diffusers/releases"> <a href="CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg"></a>
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg"> <a href="https://twitter.com/diffuserslib"><img alt="X account" src="https://img.shields.io/twitter/url/https/twitter.com/diffuserslib.svg?style=social&label=Follow%20%40diffuserslib"></a>
</a>
<a href="https://pepy.tech/project/diffusers">
<img alt="GitHub release" src="https://static.pepy.tech/badge/diffusers/month">
</a>
<a href="CODE_OF_CONDUCT.md">
<img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg">
</a>
<a href="https://twitter.com/diffuserslib">
<img alt="X account" src="https://img.shields.io/twitter/url/https/twitter.com/diffuserslib.svg?style=social&label=Follow%20%40diffuserslib">
</a>
</p> </p>
🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction). 🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+1 -1
View File
@@ -42,7 +42,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers \ transformers \
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers transformers
+1 -1
View File
@@ -42,7 +42,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers transformers
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers matplotlib transformers matplotlib
+1 -1
View File
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers \ transformers \
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \ huggingface-hub \
Jinja2 \ Jinja2 \
librosa \ librosa \
numpy \ numpy==1.26.4 \
scipy \ scipy \
tensorboard \ tensorboard \
transformers \ transformers \
+4
View File
@@ -253,6 +253,8 @@
title: PriorTransformer title: PriorTransformer
- local: api/models/controlnet - local: api/models/controlnet
title: ControlNetModel title: ControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
title: Models title: Models
- isExpanded: false - isExpanded: false
sections: sections:
@@ -276,6 +278,8 @@
title: Consistency Models title: Consistency Models
- local: api/pipelines/controlnet - local: api/pipelines/controlnet
title: ControlNet title: ControlNet
- local: api/pipelines/controlnet_sd3
title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl - local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnetxs - local: api/pipelines/controlnetxs
-6
View File
@@ -41,12 +41,6 @@ An attention processor is a class for applying different types of attention mech
## FusedAttnProcessor2_0 ## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0 [[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## LoRAAttnAddedKVProcessor
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
## LoRAXFormersAttnProcessor
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
## SlicedAttnProcessor ## SlicedAttnProcessor
[[autodoc]] models.attention_processor.SlicedAttnProcessor [[autodoc]] models.attention_processor.SlicedAttnProcessor
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableDiffusionXLInstructPix2PixPipeline`] - [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`] - [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`] - [`StableDiffusionXLKDiffusionPipeline`]
- [`StableDiffusion3Pipeline`]
- [`LatentConsistencyModelPipeline`] - [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`] - [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`] - [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableCascadeUNet`] - [`StableCascadeUNet`]
- [`AutoencoderKL`] - [`AutoencoderKL`]
- [`ControlNetModel`] - [`ControlNetModel`]
- [`SD3Transformer2DModel`]
## FromSingleFileMixin ## FromSingleFileMixin
@@ -0,0 +1,42 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SD3ControlNetModel
SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3.
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
## Loading from the original format
By default the [`SD3ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].
```py
from diffusers import StableDiffusion3ControlNetPipeline
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet)
```
## SD3ControlNetModel
[[autodoc]] SD3ControlNetModel
## SD3ControlNetOutput
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
@@ -78,7 +78,6 @@ output = pipe(
) )
frames = output.frames[0] frames = output.frames[0]
export_to_gif(frames, "animation.gif") export_to_gif(frames, "animation.gif")
``` ```
Here are some sample outputs: Here are some sample outputs:
@@ -303,7 +302,6 @@ output = pipe(
) )
frames = output.frames[0] frames = output.frames[0]
export_to_gif(frames, "animation.gif") export_to_gif(frames, "animation.gif")
``` ```
<table> <table>
@@ -378,7 +376,6 @@ output = pipe(
) )
frames = output.frames[0] frames = output.frames[0]
export_to_gif(frames, "animation.gif") export_to_gif(frames, "animation.gif")
``` ```
<table> <table>
@@ -0,0 +1,39 @@
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet with Stable Diffusion 3
StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3.
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusion3ControlNetPipeline
[[autodoc]] StableDiffusion3ControlNetPipeline
- all
- __call__
## StableDiffusion3PipelineOutput
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
@@ -186,7 +186,7 @@ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgrap
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True) pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
# Warm Up # Warm Up
prompt = "a photo of a cat holding a sign that says hello world", prompt = "a photo of a cat holding a sign that says hello world"
for _ in range(3): for _ in range(3):
_ = pipe(prompt=prompt, generator=torch.manual_seed(1)) _ = pipe(prompt=prompt, generator=torch.manual_seed(1))
@@ -197,6 +197,27 @@ image.save("sd3_hello_world.png")
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97). Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
## Tiny AutoEncoder for Stable Diffusion 3
Tiny AutoEncoder for Stable Diffusion (TAESD3) is a tiny distilled version of Stable Diffusion 3's VAE by [Ollin Boer Bohan](https://github.com/madebyollin/taesd) that can decode [`StableDiffusion3Pipeline`] latents almost instantly.
To use with Stable Diffusion 3:
```python
import torch
from diffusers import StableDiffusion3Pipeline, AutoencoderTiny
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")
```
## Loading the original checkpoints via `from_single_file` ## Loading the original checkpoints via `from_single_file`
The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. This method allows you to load the original checkpoint files that were used to train the models. The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. This method allows you to load the original checkpoint files that were used to train the models.
@@ -211,17 +232,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
## Loading the single checkpoint for the `StableDiffusion3Pipeline` ## Loading the single checkpoint for the `StableDiffusion3Pipeline`
```python ### Loading the single file checkpoint without T5
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16) ```python
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3) import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
torch_dtype=torch.float16,
text_encoder_3=None
)
pipe.enable_model_cpu_offload()
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file.png')
``` ```
<Tip> ### Loading the single file checkpoint with T5
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
</Tip> ```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')
```
## StableDiffusion3Pipeline ## StableDiffusion3Pipeline
+10 -10
View File
@@ -22,14 +22,13 @@ We enormously value feedback from the community, so please do not be afraid to s
## Overview ## Overview
You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to You can contribute in many ways ranging from answering questions on issues and discussions to adding new diffusion models to the core library.
the core library.
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community. In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR). * 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose). * 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose) or new discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions/new/choose).
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues). * 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues) or discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions).
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). * 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source). * 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples). * 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
@@ -63,7 +62,7 @@ In the same spirit, you are of immense help to the community by answering such q
**Please** keep in mind that the more effort you put into asking or answering a question, the higher **Please** keep in mind that the more effort you put into asking or answering a question, the higher
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
**NOTE about channels**: **NOTE about channels**:
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. [*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
@@ -99,7 +98,7 @@ This means in more detail:
- Format your code. - Format your code.
- Do not include any external libraries except for Diffusers depending on them. - Do not include any external libraries except for Diffusers depending on them.
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue. - **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it. - Explain the issue. If the reader doesn't know what the issue is and why it is an issue, (s)he cannot solve it.
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell. - **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible. - If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
@@ -288,7 +287,7 @@ The official training examples are maintained by the Diffusers' core maintainers
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models. This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author. If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
Both official training and research examples consist of a directory that contains one or more training scripts, a requirements.txt file, and a README.md file. In order for the user to make use of the Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
training examples, it is required to clone the repository: training examples, it is required to clone the repository:
```bash ```bash
@@ -298,7 +297,8 @@ git clone https://github.com/huggingface/diffusers
as well as to install all additional dependencies required for training: as well as to install all additional dependencies required for training:
```bash ```bash
pip install -r /examples/<your-example-folder>/requirements.txt cd diffusers
pip install -r examples/<your-example-folder>/requirements.txt
``` ```
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt). Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
@@ -316,7 +316,7 @@ Once an example script works, please make sure to add a comprehensive `README.md
- A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5). - A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations). - If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples. If you are contributing to the official training examples, please also make sure to add a test to its folder such as [examples/dreambooth/test_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/test_dreambooth.py). This is not necessary for non-official training examples.
### 8. Fixing a "Good second issue" ### 8. Fixing a "Good second issue"
@@ -418,7 +418,7 @@ You will need basic `git` proficiency to be able to contribute to
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference. Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)): Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/83bc6c94eaeb6f7704a2a428931cf2d9ad973ae9/setup.py#L270)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by 1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+1 -1
View File
@@ -349,7 +349,7 @@ control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background" prompt = "pale golden rod circle with old lace background"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0] image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
image.save("./output.png") image.save("./output.png")
``` ```
+1 -1
View File
@@ -181,7 +181,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--max_train_steps=15000 \ --max_train_steps=15000 \
--learning_rate=1e-05 \ --learning_rate=1e-05 \
--max_grad_norm=1 \ --max_grad_norm=1 \
--enable_xformers_memory_efficient_attention --enable_xformers_memory_efficient_attention \
--lr_scheduler="constant" --lr_warmup_steps=0 \ --lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-naruto-model" \ --output_dir="sd-naruto-model" \
--push_to_hub --push_to_hub
+6 -6
View File
@@ -34,13 +34,10 @@ Install [PyTorch nightly](https://pytorch.org/) to benefit from the latest and f
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
``` ```
<Tip> > [!TIP]
> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum.
> If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. <br>
If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
</Tip>
## Baseline ## Baseline
@@ -170,6 +167,9 @@ Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500> <img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
</div> </div>
> [!TIP]
> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
### Prevent graph breaks ### Prevent graph breaks
Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables. Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
@@ -472,7 +472,6 @@ my_local_config_path = snapshot_download(
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
) )
print("My local config: ", my_local_config_path) print("My local config: ", my_local_config_path)
``` ```
Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters. Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
@@ -436,7 +436,7 @@ lora_path = "lora-library/B-LoRA-pen_sketch"
state_dict = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"]) state_dict = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"])
# Load traine dlora layers into the unet # Load trained lora layers into the unet
pipeline.load_lora_into_unet(state_dict, None, pipeline.unet) pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)
#generate #generate
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -326,7 +326,7 @@ def parse_args(input_args=None):
type=str, type=str,
default="TOK", default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. " "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
"'TOK,TOK2,TOK3' etc.", "'TOK,TOK2,TOK3' etc.",
) )
@@ -559,7 +559,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -736,7 +736,7 @@ class TokenEmbeddingsHandler:
# random initialization of new tokens # random initialization of new tokens
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}") print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
@@ -948,7 +948,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -1967,7 +1967,7 @@ def main(args):
} }
) )
# Conver to WebUI format # Convert to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
@@ -78,7 +78,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -348,7 +348,7 @@ def parse_args(input_args=None):
type=str, type=str,
default="TOK", default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. " "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
"'TOK,TOK2,TOK3' etc.", "'TOK,TOK2,TOK3' etc.",
) )
@@ -591,7 +591,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -824,7 +824,7 @@ class TokenEmbeddingsHandler:
# random initialization of new tokens # random initialization of new tokens
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}") print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
@@ -1097,7 +1097,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -1794,7 +1794,7 @@ def main(args):
if args.with_prior_preservation: if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps # batch prompts on all training steps
else: else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens) tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
@@ -2411,7 +2411,7 @@ def main(args):
} }
) )
# Conver to WebUI format # Convert to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+1 -1
View File
@@ -3595,7 +3595,7 @@ This pipeline provides drag-and-drop image editing using stochastic differential
![SDE Drag Image](https://github.com/huggingface/diffusers/assets/75928535/bd54f52f-f002-4951-9934-b2a4592771a5) ![SDE Drag Image](https://github.com/huggingface/diffusers/assets/75928535/bd54f52f-f002-4951-9934-b2a4592771a5)
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more infomation. See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
```py ```py
import PIL import PIL
+2 -12
View File
@@ -24,12 +24,7 @@ from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -1292,12 +1287,7 @@ class SDXLLongPromptWeightingPipeline(
self.vae.to(dtype=torch.float32) self.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance( use_torch_2_0_or_xformers = isinstance(
self.vae.decoder.mid_block.attentions[0].processor, self.vae.decoder.mid_block.attentions[0].processor,
( (AttnProcessor2_0, XFormersAttnProcessor),
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory # to be in float32 which can save lots of memory
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
class MarigoldDepthOutput(BaseOutput): class MarigoldDepthOutput(BaseOutput):
+4 -14
View File
@@ -16,12 +16,7 @@ from diffusers.loaders import (
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -612,12 +607,7 @@ class DemoFusionSDXLPipeline(
self.vae.to(dtype=torch.float32) self.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance( use_torch_2_0_or_xformers = isinstance(
self.vae.decoder.mid_block.attentions[0].processor, self.vae.decoder.mid_block.attentions[0].processor,
( (AttnProcessor2_0, XFormersAttnProcessor),
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory # to be in float32 which can save lots of memory
@@ -805,10 +795,10 @@ class DemoFusionSDXLPipeline(
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C Control the strength of dilated sampling. For specific impacts, please refer to Appendix C
in the DemoFusion paper. in the DemoFusion paper.
cosine_scale_3 (`float`, defaults to 1): cosine_scale_3 (`float`, defaults to 1):
Control the strength of the gaussion filter. For specific impacts, please refer to Appendix C Control the strength of the gaussian filter. For specific impacts, please refer to Appendix C
in the DemoFusion paper. in the DemoFusion paper.
sigma (`float`, defaults to 1): sigma (`float`, defaults to 1):
The standerd value of the gaussian filter. The standard value of the gaussian filter.
show_image (`bool`, defaults to False): show_image (`bool`, defaults to False):
Determine whether to show intermediate results during generation. Determine whether to show intermediate results during generation.
@@ -46,8 +46,6 @@ from diffusers.models.attention_processor import (
Attention, Attention,
AttnProcessor2_0, AttnProcessor2_0,
FusedAttnProcessor2_0, FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -1153,8 +1151,6 @@ class StyleAlignedSDXLPipeline(
( (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0, FusedAttnProcessor2_0,
), ),
) )
@@ -25,12 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -797,12 +792,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
self.vae.to(dtype=torch.float32) self.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance( use_torch_2_0_or_xformers = isinstance(
self.vae.decoder.mid_block.attentions[0].processor, self.vae.decoder.mid_block.attentions[0].processor,
( (AttnProcessor2_0, XFormersAttnProcessor),
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory # to be in float32 which can save lots of memory
@@ -44,12 +44,7 @@ from diffusers.models import (
T2IAdapter, T2IAdapter,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
@@ -1135,12 +1130,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
self.vae.to(dtype=torch.float32) self.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance( use_torch_2_0_or_xformers = isinstance(
self.vae.decoder.mid_block.attentions[0].processor, self.vae.decoder.mid_block.attentions[0].processor,
( (AttnProcessor2_0, XFormersAttnProcessor),
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory # to be in float32 which can save lots of memory
@@ -37,8 +37,6 @@ from diffusers.loaders import (
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -854,8 +852,6 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
( (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
), ),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
@@ -34,8 +34,6 @@ from diffusers.loaders import (
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -662,8 +660,6 @@ class StableDiffusionXLPipelineIpex(
( (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
), ),
) )
# if xformers or torch_2_0 is used attention block does not need # if xformers or torch_2_0 is used attention block does not need
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -1111,11 +1111,16 @@ def main(args):
# 15. LR Scheduler creation # 15. LR Scheduler creation
# Scheduler and math around the number of training steps. # Scheduler and math around the number of training steps.
overrode_max_train_steps = False # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None: if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
overrode_max_train_steps = True num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
if args.scale_lr: if args.scale_lr:
args.learning_rate = ( args.learning_rate = (
@@ -1130,8 +1135,8 @@ def main(args):
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
args.lr_scheduler, args.lr_scheduler,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=args.max_train_steps * accelerator.num_processes, num_training_steps=num_training_steps_for_scheduler,
) )
# 16. Prepare for training # 16. Prepare for training
@@ -1142,8 +1147,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps: if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs # Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
+11
View File
@@ -11,6 +11,8 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
huggingface-cli login huggingface-cli login
``` ```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch ## Running locally with PyTorch
### Installing the dependencies ### Installing the dependencies
@@ -106,6 +108,9 @@ To better track our training experiments, we're using the following flags in the
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. * `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> [!TIP] > [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. > You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
@@ -113,6 +118,8 @@ To better track our training experiments, we're using the following flags in the
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
To perform DreamBooth with LoRA, run: To perform DreamBooth with LoRA, run:
```bash ```bash
@@ -139,3 +146,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
--seed="0" \ --seed="0" \
--push_to_hub --push_to_hub
``` ```
## Other notes
We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
+2 -1
View File
@@ -4,4 +4,5 @@ transformers>=4.41.2
ftfy ftfy
tensorboard tensorboard
Jinja2 Jinja2
peft== 0.11.1 peft==0.11.1
sentencepiece
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
# Cache compiled models across invocations of this script. # Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -53,7 +53,11 @@ from diffusers import (
StableDiffusion3Pipeline, StableDiffusion3Pipeline,
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
convert_unet_state_dict_to_peft, convert_unet_state_dict_to_peft,
@@ -67,7 +71,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -298,6 +302,12 @@ def parse_args(input_args=None):
default=None, default=None,
help="The prompt to specify images in the same class as provided instance images.", help="The prompt to specify images in the same class as provided instance images.",
) )
parser.add_argument(
"--max_sequence_length",
type=int,
default=77,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument( parser.add_argument(
"--validation_prompt", "--validation_prompt",
type=str, type=str,
@@ -467,11 +477,23 @@ def parse_args(input_args=None):
), ),
) )
parser.add_argument( parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] "--weighting_scheme",
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
) )
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
parser.add_argument("--mode_scale", type=float, default=1.29)
parser.add_argument( parser.add_argument(
"--optimizer", "--optimizer",
type=str, type=str,
@@ -495,7 +517,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -766,7 +788,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -830,6 +852,7 @@ def tokenize_prompt(tokenizer, prompt):
def _encode_prompt_with_t5( def _encode_prompt_with_t5(
text_encoder, text_encoder,
tokenizer, tokenizer,
max_sequence_length,
prompt=None, prompt=None,
num_images_per_prompt=1, num_images_per_prompt=1,
device=None, device=None,
@@ -840,7 +863,7 @@ def _encode_prompt_with_t5(
text_inputs = tokenizer( text_inputs = tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
max_length=77, max_length=max_sequence_length,
truncation=True, truncation=True,
add_special_tokens=True, add_special_tokens=True,
return_tensors="pt", return_tensors="pt",
@@ -897,6 +920,7 @@ def encode_prompt(
text_encoders, text_encoders,
tokenizers, tokenizers,
prompt: str, prompt: str,
max_sequence_length,
device=None, device=None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
): ):
@@ -924,6 +948,7 @@ def encode_prompt(
t5_prompt_embed = _encode_prompt_with_t5( t5_prompt_embed = _encode_prompt_with_t5(
text_encoders[-1], text_encoders[-1],
tokenizers[-1], tokenizers[-1],
max_sequence_length,
prompt=prompt, prompt=prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[-1].device, device=device if device is not None else text_encoders[-1].device,
@@ -1297,7 +1322,9 @@ def main(args):
def compute_text_embeddings(prompt, text_encoders, tokenizers): def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad(): with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device) prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds return prompt_embeds, pooled_prompt_embeds
@@ -1316,6 +1343,8 @@ def main(args):
# Clear the memory here # Clear the memory here
if not train_dataset.custom_instance_prompts: if not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect() gc.collect()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -1330,7 +1359,7 @@ def main(args):
if args.with_prior_preservation: if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps # batch prompts on all training steps
else: else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
@@ -1462,7 +1491,15 @@ def main(args):
bsz = model_input.shape[0] bsz = model_input.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) # for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching. # Add noise according to flow matching.
@@ -1482,20 +1519,11 @@ def main(args):
# Preconditioning of the model outputs. # Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input model_pred = model_pred * (-sigmas) + noisy_model_input
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) # these weighting schemes use a uniform timestep sampling
if args.weighting_scheme == "sigma_sqrt": # and instead post-weight the loss
weighting = (sigmas**-2.0).float() weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# simplified flow matching aka 0-rectified flow matching loss # flow matching loss
# target = model_input - noise
target = model_input target = model_input
if args.with_prior_preservation: if args.with_prior_preservation:
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -562,7 +562,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -861,7 +861,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -1289,8 +1289,8 @@ def main(args):
models = [unet_] models = [unet_]
if args.train_text_encoder: if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_]) models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32 # only upcast trainable parameters (LoRA) into fp32
cast_training_params(models) cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1488,7 +1488,7 @@ def main(args):
if args.with_prior_preservation: if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps # batch prompts on all training steps
else: else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+49 -25
View File
@@ -51,6 +51,7 @@ from diffusers import (
StableDiffusion3Pipeline, StableDiffusion3Pipeline,
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
is_wandb_available, is_wandb_available,
@@ -63,7 +64,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -297,6 +298,12 @@ def parse_args(input_args=None):
default=None, default=None,
help="The prompt to specify images in the same class as provided instance images.", help="The prompt to specify images in the same class as provided instance images.",
) )
parser.add_argument(
"--max_sequence_length",
type=int,
default=77,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument( parser.add_argument(
"--validation_prompt", "--validation_prompt",
type=str, type=str,
@@ -465,11 +472,23 @@ def parse_args(input_args=None):
), ),
) )
parser.add_argument( parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] "--weighting_scheme",
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
) )
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
parser.add_argument("--mode_scale", type=float, default=1.29)
parser.add_argument( parser.add_argument(
"--optimizer", "--optimizer",
type=str, type=str,
@@ -493,7 +512,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -764,7 +783,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -828,6 +847,7 @@ def tokenize_prompt(tokenizer, prompt):
def _encode_prompt_with_t5( def _encode_prompt_with_t5(
text_encoder, text_encoder,
tokenizer, tokenizer,
max_sequence_length,
prompt=None, prompt=None,
num_images_per_prompt=1, num_images_per_prompt=1,
device=None, device=None,
@@ -838,7 +858,7 @@ def _encode_prompt_with_t5(
text_inputs = tokenizer( text_inputs = tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
max_length=77, max_length=max_sequence_length,
truncation=True, truncation=True,
add_special_tokens=True, add_special_tokens=True,
return_tensors="pt", return_tensors="pt",
@@ -895,6 +915,7 @@ def encode_prompt(
text_encoders, text_encoders,
tokenizers, tokenizers,
prompt: str, prompt: str,
max_sequence_length,
device=None, device=None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
): ):
@@ -922,6 +943,7 @@ def encode_prompt(
t5_prompt_embed = _encode_prompt_with_t5( t5_prompt_embed = _encode_prompt_with_t5(
text_encoders[-1], text_encoders[-1],
tokenizers[-1], tokenizers[-1],
max_sequence_length,
prompt=prompt, prompt=prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[-1].device, device=device if device is not None else text_encoders[-1].device,
@@ -1324,7 +1346,9 @@ def main(args):
def compute_text_embeddings(prompt, text_encoders, tokenizers): def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad(): with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device) prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds return prompt_embeds, pooled_prompt_embeds
@@ -1347,6 +1371,8 @@ def main(args):
# Clear the memory here # Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts: if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect() gc.collect()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -1362,7 +1388,7 @@ def main(args):
if args.with_prior_preservation: if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps # batch prompts on all training steps
else: else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
@@ -1526,7 +1552,15 @@ def main(args):
bsz = model_input.shape[0] bsz = model_input.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) # for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching. # Add noise according to flow matching.
@@ -1560,21 +1594,11 @@ def main(args):
# Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs. # Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) # flow matching loss
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
target = model_input target = model_input
if args.with_prior_preservation: if args.with_prior_preservation:
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -42,7 +42,6 @@ input_image_path = "/path/to/input_image"
input_image = Image.open(input_image_path) input_image = Image.open(input_image_path)
edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images
edited_images[0].show() edited_images[0].show()
``` ```
## Results ## Results
@@ -29,7 +29,6 @@ export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:aut
numactl --membind <node N> -C <cpu list> python python inference_bf16.py numactl --membind <node N> -C <cpu list> python python inference_bf16.py
# Launch with DPMSolverMultistepScheduler # Launch with DPMSolverMultistepScheduler
numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm
``` ```
## Accelerating the inference for Stable Diffusion using INT8 ## Accelerating the inference for Stable Diffusion using INT8
@@ -561,7 +561,7 @@ def parse_args(input_args=None):
"--prodigy_beta3", "--prodigy_beta3",
type=float, type=float,
default=None, default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW", "uses the value of square root of beta2. Ignored if optimizer is adamW",
) )
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
@@ -880,7 +880,7 @@ class DreamBoothDataset(Dataset):
else: else:
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt example["instance_prompt"] = self.instance_prompt
if self.class_data_root: if self.class_data_root:
@@ -1363,8 +1363,8 @@ def main(args):
models = [unet_] models = [unet_]
if args.train_text_encoder: if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_]) models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32 # only upcast trainable parameters (LoRA) into fp32
cast_training_params(models) cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1561,7 +1561,7 @@ def main(args):
if args.with_prior_preservation: if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps # batch prompts on all training steps
else: else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
+8
View File
@@ -170,11 +170,19 @@ For our small Narutos dataset, the effects of Min-SNR weighting strategy might n
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
#### Training with EMA weights
Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights.
EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This will keep the EMA weights in pinned CPU memory during the training step. Then, once every model parameter update, it will transfer the EMA weights back to the GPU which can then update the parameters on the GPU, before sending them back to the CPU. Both of these transfers are set up as non-blocking, so CUDA devices should be able to overlap this transfer with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization.
#### Training with DREAM #### Training with DREAM
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper. We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
## Training with LoRA ## Training with LoRA
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
-1
View File
@@ -239,7 +239,6 @@ accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lor
--seed=1234 \ --seed=1234 \
--output_dir="sd-naruto-model-lora-sdxl" \ --output_dir="sd-naruto-model-lora-sdxl" \
--validation_prompt="cute dragon creature" --validation_prompt="cute dragon creature"
``` ```
+24 -5
View File
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -387,6 +387,8 @@ def parse_args():
), ),
) )
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
parser.add_argument( parser.add_argument(
"--non_ema_revision", "--non_ema_revision",
type=str, type=str,
@@ -624,7 +626,12 @@ def main():
ema_unet = UNet2DConditionModel.from_pretrained( ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
) )
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) ema_unet = EMAModel(
ema_unet.parameters(),
model_cls=UNet2DConditionModel,
model_config=ema_unet.config,
foreach=args.foreach_ema,
)
if args.enable_xformers_memory_efficient_attention: if args.enable_xformers_memory_efficient_attention:
if is_xformers_available(): if is_xformers_available():
@@ -655,9 +662,14 @@ def main():
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) load_model = EMAModel.from_pretrained(
os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
)
ema_unet.load_state_dict(load_model.state_dict()) ema_unet.load_state_dict(load_model.state_dict())
ema_unet.to(accelerator.device) if args.offload_ema:
ema_unet.pin_memory()
else:
ema_unet.to(accelerator.device)
del load_model del load_model
for _ in range(len(models)): for _ in range(len(models)):
@@ -833,7 +845,10 @@ def main():
) )
if args.use_ema: if args.use_ema:
ema_unet.to(accelerator.device) if args.offload_ema:
ema_unet.pin_memory()
else:
ema_unet.to(accelerator.device)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required. # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1011,7 +1026,11 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
if args.use_ema: if args.use_ema:
if args.offload_ema:
ema_unet.to(device="cuda", non_blocking=True)
ema_unet.step(unet.parameters()) ema_unet.step(unet.parameters())
if args.offload_ema:
ema_unet.to(device="cpu", non_blocking=True)
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step) accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0") check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
+3 -3
View File
@@ -95,7 +95,7 @@ from setuptools import Command, find_packages, setup
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [ _deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away "Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.29.3", "accelerate>=0.31.0",
"compel==0.1.8", "compel==0.1.8",
"datasets", "datasets",
"filelock", "filelock",
@@ -132,7 +132,7 @@ _deps = [
"tensorboard", "tensorboard",
"torch>=1.4", "torch>=1.4",
"torchvision", "torchvision",
"transformers>=4.25.1", "transformers>=4.41.2",
"urllib3<=2.0.0", "urllib3<=2.0.0",
"black", "black",
] ]
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup( setup(
name="diffusers", name="diffusers",
version="0.29.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.", description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
+7 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.29.0.dev0" __version__ = "0.30.0.dev0"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -91,6 +91,8 @@ else:
"MultiAdapter", "MultiAdapter",
"PixArtTransformer2DModel", "PixArtTransformer2DModel",
"PriorTransformer", "PriorTransformer",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel", "SD3Transformer2DModel",
"StableCascadeUNet", "StableCascadeUNet",
"T2IAdapter", "T2IAdapter",
@@ -278,6 +280,7 @@ else:
"StableCascadeCombinedPipeline", "StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline", "StableCascadeDecoderPipeline",
"StableCascadePriorPipeline", "StableCascadePriorPipeline",
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline", "StableDiffusion3Img2ImgPipeline",
"StableDiffusion3Pipeline", "StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline", "StableDiffusionAdapterPipeline",
@@ -501,6 +504,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiAdapter, MultiAdapter,
PixArtTransformer2DModel, PixArtTransformer2DModel,
PriorTransformer, PriorTransformer,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel, SD3Transformer2DModel,
T2IAdapter, T2IAdapter,
T5FilmDecoder, T5FilmDecoder,
@@ -666,6 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableCascadeCombinedPipeline, StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline, StableCascadeDecoderPipeline,
StableCascadePriorPipeline, StableCascadePriorPipeline,
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline, StableDiffusion3Img2ImgPipeline,
StableDiffusion3Pipeline, StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline, StableDiffusionAdapterPipeline,
+1 -1
View File
@@ -716,7 +716,7 @@ class LegacyConfigMixin(ConfigMixin):
@classmethod @classmethod
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
# To prevent depedency import problem. # To prevent dependency import problem.
from .models.model_loading_utils import _fetch_remapped_cls_from_config from .models.model_loading_utils import _fetch_remapped_cls_from_config
# resolve remapping # resolve remapping
+2 -2
View File
@@ -3,7 +3,7 @@
# 2. run `make deps_table_update` # 2. run `make deps_table_update`
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.29.3", "accelerate": "accelerate>=0.31.0",
"compel": "compel==0.1.8", "compel": "compel==0.1.8",
"datasets": "datasets", "datasets": "datasets",
"filelock": "filelock", "filelock": "filelock",
@@ -40,7 +40,7 @@ deps = {
"tensorboard": "tensorboard", "tensorboard": "tensorboard",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"torchvision": "torchvision", "torchvision": "torchvision",
"transformers": "transformers>=4.25.1", "transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0", "urllib3": "urllib3<=2.0.0",
"black": "black", "black": "black",
} }
+1 -2
View File
@@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin):
channel = image.shape[1] channel = image.shape[1]
# don't need any preprocess if the image is latents # don't need any preprocess if the image is latents
if channel == 4: if channel == self.vae_latent_channels:
return image return image
height, width = self.get_default_height_width(image, height, width) height, width = self.get_default_height_width(image, height, width)
@@ -585,7 +585,6 @@ class VaeImageProcessor(ConfigMixin):
FutureWarning, FutureWarning,
) )
do_normalize = False do_normalize = False
if do_normalize: if do_normalize:
image = self.normalize(image) image = self.normalize(image)
+94 -12
View File
@@ -30,6 +30,7 @@ from ..utils import (
_get_model_file, _get_model_file,
convert_state_dict_to_diffusers, convert_state_dict_to_diffusers,
convert_state_dict_to_peft, convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers, delete_adapter_layers,
get_adapter_name, get_adapter_name,
get_peft_kwargs, get_peft_kwargs,
@@ -42,7 +43,7 @@ from ..utils import (
set_adapter_layers, set_adapter_layers,
set_weights_and_activate_adapters, set_weights_and_activate_adapters,
) )
from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
if is_transformers_available(): if is_transformers_available():
@@ -287,7 +288,7 @@ class LoraLoaderMixin:
if unet_config is not None: if unet_config is not None:
# use unet config to remap block numbers # use unet config to remap block numbers
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas return state_dict, network_alphas
@@ -462,17 +463,18 @@ class LoraLoaderMixin:
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder): for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight" for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) for name, _ in text_encoder_mlp_modules(text_encoder):
if patch_mlp: for module in ("fc1", "fc2"):
for name, _ in text_encoder_mlp_modules(text_encoder): rank_key = f"{name}.{module}.lora_B.weight"
rank_key_fc1 = f"{name}.fc1.lora_B.weight" if rank_key not in text_encoder_lora_state_dict:
rank_key_fc2 = f"{name}.fc2.lora_B.weight" continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
if network_alphas is not None: if network_alphas is not None:
alpha_keys = [ alpha_keys = [
@@ -1542,6 +1544,11 @@ class SD3LoraLoaderMixin:
} }
if len(state_dict.keys()) > 0: if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}): if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError( raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
@@ -1727,3 +1734,78 @@ class SD3LoraLoaderMixin:
remove_hook_from_module(component, recurse=is_sequential_cpu_offload) remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload) return (is_model_cpu_offload, is_sequential_cpu_offload)
def fuse_lora(
self,
fuse_transformer: bool = True,
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"nerijs/pixel-art-medium-128-v0.1",
weight_name="pixel-art-medium-128-v0.1.safetensors",
adapter_name="pixel",
)
pipeline.fuse_lora(lora_scale=0.7)
```
"""
if fuse_transformer:
self.num_fused_loras += 1
if fuse_transformer:
transformer = (
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
)
transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
def unfuse_lora(self, unfuse_transformer: bool = True):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
"""
from peft.tuners.tuners_utils import BaseTunerLayer
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if unfuse_transformer:
for module in transformer.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
self.num_fused_loras -= 1
+144 -109
View File
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
return new_state_dict return new_state_dict
def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
"""
Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
Args:
state_dict (`dict`): The state dict to convert.
unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
"text_encoder".
Returns:
`tuple`: A tuple containing the converted state dict and a dictionary of alphas.
"""
unet_state_dict = {} unet_state_dict = {}
te_state_dict = {} te_state_dict = {}
te2_state_dict = {} te2_state_dict = {}
network_alphas = {} network_alphas = {}
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora: # Check for DoRA-enabled LoRAs.
if any(
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
for k in state_dict
):
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):
raise ValueError( raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
) )
# every down weight has a corresponding up weight and potentially an alpha weight # Iterate over all LoRA weights.
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] all_lora_keys = list(state_dict.keys())
for key in lora_keys: for key in all_lora_keys:
if not key.endswith("lora_down.weight"):
continue
# Extract LoRA name.
lora_name = key.split(".")[0] lora_name = key.split(".")[0]
# Find corresponding up weight and alpha.
lora_name_up = lora_name + ".lora_up.weight" lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha" lora_name_alpha = lora_name + ".alpha"
# Handle U-Net LoRAs.
if lora_name.startswith("lora_unet_"): if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".") diffusers_name = _convert_unet_lora_key(key)
if "input.blocks" in diffusers_name: # Store down and up weights.
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") unet_state_dict[diffusers_name] = state_dict.pop(key)
else: unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name: # Store DoRA scale if present.
diffusers_name = diffusers_name.replace("middle.block", "mid_block") if "dora_scale" in state_dict:
else:
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if is_unet_dora_lora:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[ unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
# Store down and up weights for te or te2.
if lora_name.startswith(("lora_te_", "lora_te1_")): if lora_name.startswith(("lora_te_", "lora_te1_")):
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else: else:
key_to_replace = "lora_te2_"
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name:
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# OneTrainer specificity
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): # Store DoRA scale if present.
if "dora_scale" in state_dict:
dora_scale_key_to_replace_te = ( dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Rename the alphas so that they can be mapped appropriately. # Store alpha if present.
if lora_name_alpha in state_dict: if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item() alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"): network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha})
# Check if any keys remain.
if len(state_dict) > 0: if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.") logger.info("Kohya-style checkpoint detected.")
# Construct final state dict.
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
te2_state_dict = ( te2_state_dict = (
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
new_state_dict = {**unet_state_dict, **te_state_dict} new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alphas return new_state_dict, network_alphas
def _convert_unet_lora_key(key):
"""
Converts a U-Net LoRA key to a Diffusers compatible key.
"""
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
# Replace common U-Net naming patterns.
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specific conversions.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specific conversions.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General conversions.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
elif "ff" in diffusers_name:
pass
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
pass
else:
pass
return diffusers_name
def _convert_text_encoder_lora_key(key, lora_name):
"""
Converts a text encoder LoRA key to a Diffusers compatible key.
"""
if lora_name.startswith(("lora_te_", "lora_te1_")):
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
else:
key_to_replace = "lora_te2_"
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
pass
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
return diffusers_name
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
"""
Gets the correct alpha name for the Diffusers model.
"""
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
return {new_name: alpha}
+12
View File
@@ -28,9 +28,11 @@ from .single_file_utils import (
_legacy_load_safety_checker, _legacy_load_safety_checker,
_legacy_load_scheduler, _legacy_load_scheduler,
create_diffusers_clip_model_from_ldm, create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config, fetch_diffusers_config,
fetch_original_config, fetch_original_config,
is_clip_model_in_single_file, is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint, load_single_file_checkpoint,
) )
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
is_legacy_loading=is_legacy_loading, is_legacy_loading=is_legacy_loading,
) )
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading: elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer( loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
+10 -8
View File
@@ -276,16 +276,18 @@ class FromOriginalModelMixin:
if is_accelerate_available(): if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
model.load_state_dict(diffusers_format_checkpoint) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if torch_dtype is not None: if torch_dtype is not None:
model.to(torch_dtype) model.to(torch_dtype)
+34 -24
View File
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = [ LDM_CLIP_PREFIX_TO_REMOVE = [
"cond_stage_model.transformer.", "cond_stage_model.transformer.",
"conditioner.embedders.0.transformer.", "conditioner.embedders.0.transformer.",
"text_encoders.clip_l.transformer.",
] ]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
def is_open_clip_sd3_model(checkpoint): def is_open_clip_sd3_model(checkpoint):
is_open_clip_sdxl_refiner_model(checkpoint) if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
return True
return False
def is_open_clip_sdxl_refiner_model(checkpoint): def is_open_clip_sdxl_refiner_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint: if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
return True return True
return False return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint): def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
text_model_dict = {} text_model_dict = {}
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE remove_prefixes = []
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
if remove_prefix:
remove_prefixes.append(remove_prefix)
for key in keys: for key in keys:
for prefix in remove_prefixes: for prefix in remove_prefixes:
@@ -1263,8 +1268,6 @@ def convert_open_clip_checkpoint(
else: else:
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
@@ -1313,9 +1316,6 @@ def convert_open_clip_checkpoint(
else: else:
text_model_dict[diffusers_key] = checkpoint.get(key) text_model_dict[diffusers_key] = checkpoint.get(key)
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
text_model_dict.pop("text_model.embeddings.position_ids", None)
return text_model_dict return text_model_dict
@@ -1376,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
): ):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
elif (
is_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
elif is_open_clip_model(checkpoint): elif is_open_clip_model(checkpoint):
prefix = "cond_stage_model.model." prefix = "cond_stage_model.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,26 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
prefix = "conditioner.embedders.0.model." prefix = "conditioner.embedders.0.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif is_open_clip_sd3_model(checkpoint): elif (
prefix = "text_encoders.clip_g.transformer." is_open_clip_sd3_model(checkpoint)
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
else: else:
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available(): if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
model.load_state_dict(diffusers_format_checkpoint) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if torch_dtype is not None: if torch_dtype is not None:
model.to(torch_dtype) model.to(torch_dtype)
@@ -1755,7 +1764,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
text_model_dict = {} text_model_dict = {}
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."] remove_prefixes = ["text_encoders.t5xxl.transformer."]
for key in keys: for key in keys:
for prefix in remove_prefixes: for prefix in remove_prefixes:
@@ -1799,3 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
else: else:
model.load_state_dict(diffusers_format_checkpoint) model.load_state_dict(diffusers_format_checkpoint)
return model
+2
View File
@@ -33,6 +33,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"] _import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"] _import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["modeling_utils"] = ["ModelMixin"]
@@ -74,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel, VQModel,
) )
from .controlnet import ControlNetModel from .controlnet import ControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection from .embeddings import ImageProjection
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
+5 -387
View File
@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import math import math
from importlib import import_module
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
@@ -24,7 +23,6 @@ from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -259,10 +257,6 @@ class Attention(nn.Module):
The attention operation to use. Defaults to `None` which uses the default attention operation from The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`. `xformers`.
""" """
is_lora = hasattr(self, "processor") and isinstance(
self.processor,
LORA_ATTENTION_PROCESSORS,
)
is_custom_diffusion = hasattr(self, "processor") and isinstance( is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, self.processor,
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
@@ -274,14 +268,13 @@ class Attention(nn.Module):
AttnAddedKVProcessor2_0, AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor, XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
), ),
) )
if use_memory_efficient_attention_xformers: if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion): if is_added_kv_processor and is_custom_diffusion:
raise NotImplementedError( raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
) )
if not is_xformers_available(): if not is_xformers_available():
raise ModuleNotFoundError( raise ModuleNotFoundError(
@@ -307,18 +300,7 @@ class Attention(nn.Module):
except Exception as e: except Exception as e:
raise e raise e
if is_lora: if is_custom_diffusion:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionXFormersAttnProcessor( processor = CustomDiffusionXFormersAttnProcessor(
train_kv=self.processor.train_kv, train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out, train_q_out=self.processor.train_q_out,
@@ -341,18 +323,7 @@ class Attention(nn.Module):
else: else:
processor = XFormersAttnProcessor(attention_op=attention_op) processor = XFormersAttnProcessor(attention_op=attention_op)
else: else:
if is_lora: if is_custom_diffusion:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = attn_processor_class(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
attn_processor_class = ( attn_processor_class = (
CustomDiffusionAttnProcessor2_0 CustomDiffusionAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention") if hasattr(F, "scaled_dot_product_attention")
@@ -442,82 +413,6 @@ class Attention(nn.Module):
if not return_deprecated_lora: if not return_deprecated_lora:
return self.processor return self.processor
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
# serialization format for LoRA Attention Processors. It should be deleted once the integration
# with PEFT is completed.
is_lora_activated = {
name: module.lora_layer is not None
for name, module in self.named_modules()
if hasattr(module, "lora_layer")
}
# 1. if no layer has a LoRA activated we can return the processor as usual
if not any(is_lora_activated.values()):
return self.processor
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
is_lora_activated.pop("add_k_proj", None)
is_lora_activated.pop("add_v_proj", None)
# 2. else it is not possible that only some layers have LoRA activated
if not all(is_lora_activated.values()):
raise ValueError(
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
)
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
non_lora_processor_cls_name = self.processor.__class__.__name__
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
hidden_size = self.inner_dim
# now create a LoRA attention processor from the LoRA layers
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
kwargs = {
"cross_attention_dim": self.cross_attention_dim,
"rank": self.to_q.lora_layer.rank,
"network_alpha": self.to_q.lora_layer.network_alpha,
"q_rank": self.to_q.lora_layer.rank,
"q_hidden_size": self.to_q.lora_layer.out_features,
"k_rank": self.to_k.lora_layer.rank,
"k_hidden_size": self.to_k.lora_layer.out_features,
"v_rank": self.to_v.lora_layer.rank,
"v_hidden_size": self.to_v.lora_layer.out_features,
"out_rank": self.to_out[0].lora_layer.rank,
"out_hidden_size": self.to_out[0].lora_layer.out_features,
}
if hasattr(self.processor, "attention_op"):
kwargs["attention_op"] = self.processor.attention_op
lora_processor = lora_processor_cls(hidden_size, **kwargs)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
lora_processor = lora_processor_cls(
hidden_size,
cross_attention_dim=self.add_k_proj.weight.shape[0],
rank=self.to_q.lora_layer.rank,
network_alpha=self.to_q.lora_layer.network_alpha,
)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
# only save if used
if self.add_k_proj.lora_layer is not None:
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
else:
lora_processor.add_k_proj_lora = None
lora_processor.add_v_proj_lora = None
else:
raise ValueError(f"{lora_processor_cls} does not exist.")
return lora_processor
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -1132,9 +1027,7 @@ class JointAttnProcessor2_0:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = hidden_states = F.scaled_dot_product_attention( hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
@@ -1406,7 +1299,6 @@ class XFormersAttnProcessor:
class AttnProcessorNPU: class AttnProcessorNPU:
r""" r"""
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
@@ -2241,264 +2133,6 @@ class SpatialNorm(nn.Module):
return new_f return new_f
class LoRAAttnProcessor(nn.Module):
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnProcessor2_0(nn.Module):
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnProcessor2_0()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
Args:
hidden_size (`int`, *optional*):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
def __init__(
self,
hidden_size: int,
cross_attention_dim: int,
rank: int = 4,
attention_op: Optional[Callable] = None,
network_alpha: Optional[int] = None,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = XFormersAttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnAddedKVProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
encoder.
Args:
hidden_size (`int`, *optional*):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnAddedKVProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class IPAdapterAttnProcessor(nn.Module): class IPAdapterAttnProcessor(nn.Module):
r""" r"""
Attention processor for Multiple IP-Adapters. Attention processor for Multiple IP-Adapters.
@@ -2927,19 +2561,11 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states return hidden_states
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
)
ADDED_KV_ATTENTION_PROCESSORS = ( ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor, AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0, AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor, XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
) )
CROSS_ATTENTION_PROCESSORS = ( CROSS_ATTENTION_PROCESSORS = (
@@ -2947,9 +2573,6 @@ CROSS_ATTENTION_PROCESSORS = (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
SlicedAttnProcessor, SlicedAttnProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor, IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0, IPAdapterAttnProcessor2_0,
) )
@@ -2967,9 +2590,4 @@ AttentionProcessor = Union[
CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor, CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0, CustomDiffusionAttnProcessor2_0,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
] ]
@@ -175,7 +175,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -253,7 +253,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -111,6 +111,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
latent_shift: float = 0.5, latent_shift: float = 0.5,
force_upcast: bool = False, force_upcast: bool = False,
scaling_factor: float = 1.0, scaling_factor: float = 1.0,
shift_factor: float = 0.0,
): ):
super().__init__() super().__init__()
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+2 -2
View File
@@ -54,7 +54,7 @@ class ControlNetOutput(BaseOutput):
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations. used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`): mid_down_block_re_sample (`torch.Tensor`):
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation. Output can be used to condition the original UNet's middle block activation.
""" """
@@ -530,7 +530,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+418
View File
@@ -0,0 +1,418 @@
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
from ..models.attention import JointTransformerBlock
from ..models.attention_processor import Attention, AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from .transformers.transformer_2d import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class SD3ControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 18,
attention_head_dim: int = 64,
num_attention_heads: int = 18,
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.inner_dim,
context_pre_only=False,
)
for i in range(num_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
pos_embed_input = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
)
self.pos_embed_input = zero_module(pos_embed_input)
self.gradient_checkpointing = False
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
config = transformer.config
config["num_layers"] = num_layers or config.num_layers
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
return controlnet
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
# add
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
block_res_samples = ()
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
block_res_samples = block_res_samples + (hidden_states,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_res_samples,)
return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
class SD3MultiControlNetModel(ModelMixin):
r"""
`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
compatible with `SD3ControlNetModel`.
Args:
controlnets (`List[SD3ControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`SD3ControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
conditioning_scale: List[float],
pooled_projections: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[SD3ControlNetOutput, Tuple]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
block_samples = controlnet(
hidden_states=hidden_states,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
controlnet_cond=image,
conditioning_scale=scale,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
]
control_block_samples = (tuple(control_block_samples),)
return control_block_samples
+36 -9
View File
@@ -114,6 +114,7 @@ def get_down_block_adapter(
cross_attention_dim: Optional[int] = 1024, cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True, add_downsample: bool = True,
upcast_attention: Optional[bool] = False, upcast_attention: Optional[bool] = False,
use_linear_projection: Optional[bool] = True,
): ):
num_layers = 2 # only support sd + sdxl num_layers = 2 # only support sd + sdxl
@@ -152,7 +153,7 @@ def get_down_block_adapter(
in_channels=ctrl_out_channels, in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i], num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
) )
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
num_attention_heads: Optional[int] = 1, num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024, cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False, upcast_attention: bool = False,
use_linear_projection: bool = True,
): ):
# Before the midblock application, information is concatted from base to control. # Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels # Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
) )
@@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
transformer_layers_per_block: Union[int, Tuple[int]] = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1,
upcast_attention: bool = True, upcast_attention: bool = True,
max_norm_num_groups: int = 32, max_norm_num_groups: int = 32,
use_linear_projection: bool = True,
): ):
super().__init__() super().__init__()
@@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=cross_attention_dim[i], cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block, add_downsample=not is_final_block,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
) )
@@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
num_attention_heads=num_attention_heads[-1], num_attention_heads=num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1], cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
# up # up
@@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
transformer_layers_per_block=unet.config.transformer_layers_per_block, transformer_layers_per_block=unet.config.transformer_layers_per_block,
upcast_attention=unet.config.upcast_attention, upcast_attention=unet.config.upcast_attention,
max_norm_num_groups=unet.config.norm_num_groups, max_norm_num_groups=unet.config.norm_num_groups,
use_linear_projection=unet.config.use_linear_projection,
) )
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
addition_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None, addition_time_embed_dim: Optional[int] = None,
upcast_attention: bool = True, upcast_attention: bool = True,
use_linear_projection: bool = True,
time_cond_proj_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None,
# additional controlnet configs # additional controlnet configs
@@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
time_embed_dim, time_embed_dim,
cond_proj_dim=time_cond_proj_dim, cond_proj_dim=time_cond_proj_dim,
) )
self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) if ctrl_learn_time_embedding:
self.ctrl_time_embedding = TimestepEmbedding(
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
)
else:
self.ctrl_time_embedding = None
if addition_embed_type is None: if addition_embed_type is None:
self.base_add_time_proj = None self.base_add_time_proj = None
@@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
cross_attention_dim=cross_attention_dim[i], cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block, add_downsample=not is_final_block,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
) )
@@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
ctrl_num_attention_heads=ctrl_num_attention_heads[-1], ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1], cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
# # Create up blocks # # Create up blocks
@@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block, add_upsample=not is_final_block,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
) )
) )
@@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
"addition_embed_type", "addition_embed_type",
"addition_time_embed_dim", "addition_time_embed_dim",
"upcast_attention", "upcast_attention",
"use_linear_projection",
"time_cond_proj_dim", "time_cond_proj_dim",
"projection_class_embeddings_input_dim", "projection_class_embeddings_input_dim",
] ]
@@ -864,7 +880,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
cross_attention_dim: Optional[int] = 1024, cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True, add_downsample: bool = True,
upcast_attention: Optional[bool] = False, upcast_attention: Optional[bool] = False,
use_linear_projection: Optional[bool] = True,
): ):
super().__init__() super().__init__()
base_resnets = [] base_resnets = []
@@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
in_channels=base_out_channels, in_channels=base_out_channels,
num_layers=transformer_layers_per_block[i], num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
) )
@@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
in_channels=ctrl_out_channels, in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i], num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
) )
@@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
use_linear_projection = base_downblock.attentions[0].use_linear_projection
else: else:
has_crossattn = False has_crossattn = False
transformer_layers_per_block = None transformer_layers_per_block = None
@@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
ctrl_num_attention_heads = None ctrl_num_attention_heads = None
cross_attention_dim = None cross_attention_dim = None
upcast_attention = None upcast_attention = None
use_linear_projection = None
add_downsample = base_downblock.downsamplers is not None add_downsample = base_downblock.downsamplers is not None
# create model # create model
@@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
add_downsample=add_downsample, add_downsample=add_downsample,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
# # load weights # # load weights
@@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads: Optional[int] = 1, ctrl_num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024, cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False, upcast_attention: bool = False,
use_linear_projection: Optional[bool] = True,
): ):
super().__init__() super().__init__()
@@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
num_attention_heads=base_num_attention_heads, num_attention_heads=base_num_attention_heads,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
) )
@@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
), ),
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
num_attention_heads=ctrl_num_attention_heads, num_attention_heads=ctrl_num_attention_heads,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
) )
@@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
use_linear_projection = base_midblock.attentions[0].use_linear_projection
# create model # create model
model = cls( model = cls(
@@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads=ctrl_num_attention_heads, ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
# load weights # load weights
@@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
cross_attention_dim: int = 1024, cross_attention_dim: int = 1024,
add_upsample: bool = True, add_upsample: bool = True,
upcast_attention: bool = False, upcast_attention: bool = False,
use_linear_projection: Optional[bool] = True,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
@@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
in_channels=out_channels, in_channels=out_channels,
num_layers=transformer_layers_per_block[i], num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
use_linear_projection=True, use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
) )
@@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
num_attention_heads = get_first_cross_attention(base_upblock).heads num_attention_heads = get_first_cross_attention(base_upblock).heads
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
use_linear_projection = base_upblock.attentions[0].use_linear_projection
else: else:
has_crossattn = False has_crossattn = False
transformer_layers_per_block = None transformer_layers_per_block = None
num_attention_heads = None num_attention_heads = None
cross_attention_dim = None cross_attention_dim = None
upcast_attention = None upcast_attention = None
use_linear_projection = None
add_upsample = base_upblock.upsamplers is not None add_upsample = base_upblock.upsamplers is not None
# create model # create model
@@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
add_upsample=add_upsample, add_upsample=add_upsample,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
) )
# load weights # load weights
+2 -2
View File
@@ -980,7 +980,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
# positionet with text and image infomation # positionet with text and image information
else: else:
phrases_masks = phrases_masks.unsqueeze(-1) phrases_masks = phrases_masks.unsqueeze(-1)
image_masks = image_masks.unsqueeze(-1) image_masks = image_masks.unsqueeze(-1)
@@ -1252,7 +1252,7 @@ class MultiIPAdapterImageProjection(nn.Module):
if not isinstance(image_embeds, list): if not isinstance(image_embeds, list):
deprecation_message = ( deprecation_message = (
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning." " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
) )
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
image_embeds = [image_embeds.unsqueeze(1)] image_embeds = [image_embeds.unsqueeze(1)]
+11 -4
View File
@@ -462,7 +462,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. same device. Defaults to `None`, meaning that the model will be loaded on CPU.
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device more information about each option see [designing a device
@@ -774,7 +774,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: # else let accelerate handle loading and dispatching. else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map # Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU # by default the device_map is None and the weights are loaded on the CPU
force_hook = True
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
force_hook = False
try: try:
accelerate.load_checkpoint_and_dispatch( accelerate.load_checkpoint_and_dispatch(
model, model,
@@ -784,7 +789,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder, offload_folder=offload_folder,
offload_state_dict=offload_state_dict, offload_state_dict=offload_state_dict,
dtype=torch_dtype, dtype=torch_dtype,
force_hooks=True, force_hooks=force_hook,
strict=True, strict=True,
) )
except AttributeError as e: except AttributeError as e:
@@ -808,12 +813,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model._temp_convert_self_to_deprecated_attention_blocks() model._temp_convert_self_to_deprecated_attention_blocks()
accelerate.load_checkpoint_and_dispatch( accelerate.load_checkpoint_and_dispatch(
model, model,
model_file, model_file if not is_sharded else sharded_ckpt_cached_folder,
device_map, device_map,
max_memory=max_memory, max_memory=max_memory,
offload_folder=offload_folder, offload_folder=offload_folder,
offload_state_dict=offload_state_dict, offload_state_dict=offload_state_dict,
dtype=torch_dtype, dtype=torch_dtype,
force_hooks=force_hook,
strict=True,
) )
model._undo_temp_convert_self_to_deprecated_attention_blocks() model._undo_temp_convert_self_to_deprecated_attention_blocks()
else: else:
@@ -1162,7 +1169,7 @@ class LegacyModelMixin(ModelMixin):
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# To prevent depedency import problem. # To prevent dependency import problem.
from .model_loading_utils import _fetch_remapped_cls_from_config from .model_loading_utils import _fetch_remapped_cls_from_config
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls. # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
@@ -373,7 +373,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"): if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children(): for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

Some files were not shown because too many files have changed in this diff Show More