Compare commits

...

45 Commits

Author SHA1 Message Date
Dhruv Nair a12d8d90e2 Update src/diffusers/models/unets/unet_motion_model.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-07-02 10:20:42 +05:30
YiYi Xu 5a2909734d Merge branch 'main' into animatediff-warning 2024-07-01 07:35:02 -10:00
Dhruv Nair 0368483b61 Remove legacy single file model loading mixins (#8754)
update
2024-07-01 07:20:19 -10:00
YiYi Xu ddb9d8548c [doc] add a tip about using SDXL refiner with hunyuan-dit and pixart (#8735)
* up

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-07-01 06:30:09 -10:00
Dhruv Nair 5ce8e040aa update 2024-07-01 12:32:56 +00:00
Lucain 49979753e1 Always raise from previous error (#8751) 2024-07-01 14:22:30 +05:30
XCL a3904d7e34 [Tencent Hunyuan Team] Add HunyuanDiT-v1.2 Support (#8747)
* add v1.2 support

---------

Co-authored-by: xingchaoliu <xingchaoliu@tencent.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-06-30 21:33:38 -10:00
WenheLI 7bfc1ee1b2 fix the LR schedulers for dreambooth_lora (#8510)
* update training

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2024-07-01 08:14:57 +05:30
Bhavay Malhotra 71c046102b [train_controlnet_sdxl.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8476)
* Create diffusers.yml

* num_train_epochs

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-07-01 07:21:40 +05:30
Sayak Paul 83b112a145 shift cache in benchmarking. (#8740)
* shift cache.

* comment
2024-07-01 07:14:05 +05:30
Shauray Singh 8690e8b9d6 add PAG support for SD architecture (#8725)
* add pag to sd pipelines
2024-06-29 09:26:11 -10:00
Sayak Paul 7db8c3ec40 Benchmarking workflow fix (#8389)
* fix

* fixes

* add back the deadsnakes

* better messaging

* disable IP adapter tests for the moment.

* style

* up

* empty
2024-06-29 09:06:32 +05:30
Álvaro Somoza 9b7acc7cf2 [Community pipeline] SD3 Differential Diffusion Img2Img Pipeline (#8679)
* new pipeline
2024-06-28 17:12:39 -10:00
Luo Chaofan a216b0bb7f fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (#8454)
* fix: ValueError when using FromOriginalModelMixin in subclasses #8440

(cherry picked from commit 9285997843)

* Update src/diffusers/loaders/single_file_model.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update single_file_model.py

* Update single_file_model.py

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-28 17:15:46 +05:30
Dhruv Nair 150142c537 [Tests] Fix precision related issues in slow pipeline tests (#8720)
update
2024-06-28 08:13:46 +05:30
Linoy Tsaban 35f45ecd71 [Advanced dreambooth lora] adjustments to align with canonical script (#8406)
* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* fix

* fix

* aligning with blora script

* aligning with blora script

* aligning with blora script

* aligning with blora script

* aligning with blora script

* remove prints

* style

* default val

* license

* move save_model_card to outside push_to_hub

* Update train_dreambooth_lora_sdxl_advanced.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-27 13:27:37 +05:30
Sayak Paul d5dd8df3b4 [Chore] perform better deprecation for vqmodeloutput (#8719)
perform better deprecation for vqmodeloutput
2024-06-27 12:16:37 +05:30
Mathis Koroglu 3e0d128da7 Motion Model / Adapter versatility (#8301)
* Motion Model / Adapter versatility

- allow to use a different number of layers per block
- allow to use a different number of transformer per layers per block
- allow a different number of motion attention head per block
- use dropout argument in get_down/up_block in 3d blocks

* Motion Model added arguments renamed & refactoring

* Add test for asymmetric UNetMotionModel
2024-06-27 11:11:29 +05:30
vincedovy a536e775fb Fix json WindowsPath crash (#8662)
* Add check for WindowsPath in to_json_string

On Windows, os.path.join returns a WindowsPath. to_json_string does not convert this from a WindowsPath to a string. Added check for WindowsPath to to_json_saveable.

* Remove extraneous convert to string in test_check_path_types (tests/others/test_config.py)

* Fix style issues in tests/others/test_config.py

* Add unit test to test_config.py to verify that PosixPath and WindowsPath (depending on system) both work when converted to JSON

* Remove distinction between PosixPath and WindowsPath in ConfigMixIn.to_json_string(). Conditional now tests for Path, and uses Path.as_posix() to convert to string.

---------

Co-authored-by: Vincent Dovydaitis <vincedovy@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-27 10:30:55 +05:30
Álvaro Somoza 3b01d72a64 Modify FlowMatch Scale Noise (#8678)
* initial fix

* apply suggestion

* delete step_index line
2024-06-27 00:36:33 -04:00
Sayak Paul e2a4a46e99 [Release notification] add some info when there is an error. (#8718)
add some info when there is an error.
2024-06-27 09:49:15 +05:30
Sayak Paul eda560d34c modify PR and issue templates (#8687)
* modify PR and issue templates

* add single file poc.
2024-06-27 09:01:47 +05:30
Sayak Paul adbb04864d [LoRA] fix conversion utility so that lora dora loads correctly (#8688)
fix conversion utility so that lora dora loads correctly
2024-06-27 08:58:32 +05:30
Dhruv Nair effe4b9784 Update xformers SD3 test (#8712)
update
2024-06-26 10:24:27 -10:00
Sayak Paul 5b51ad0052 [LoRA] fix vanilla fine-tuned lora loading. (#8691)
fix vanilla fine-tuned lora loading.
2024-06-26 07:38:57 -10:00
Sayak Paul 10b4e354b6 [Chore] remove deprecation from transformer2d regarding the output class. (#8698)
* remove deprecation from transformer2d regarding the output class.

* up

* deprecate more
2024-06-26 07:35:36 -10:00
Donald.Lee ea6938aea5 Fix: unet save_attn_procs at UNet2DconditionLoadersMixin (#8699)
* fix: unet save_attn_procs at custom diffusion

* style: recover unchanaged parts(max line length 119) / mod: add condition

* style: recover unchanaged parts(max line length 119)

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-26 22:30:49 +05:30
Sayak Paul 8ef0d9deff [Observability] add reporting mechanism when mirroring community pipelines. (#8676)
* add reporting mechanism when mirroring community pipelines.

* remove unneeded argument

* get the actual PATH_IN_REPO

* don't need tag
2024-06-26 22:11:33 +05:30
XCL fa2abfdb03 [Tencent Hunyuan Team] Add Hunyuan-DiT ControlNet Inference (#8694)
* add controlnet support

---------

Co-authored-by: xingchaoliu <xingchaoliu@tencent.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-26 00:43:03 -10:00
YiYi Xu 1d3ef67b09 [doc] add more about from_pipe API for PAG doc (#8701)
* add more about from_pipe API

* Update docs/source/en/using-diffusers/pag.md

* Update docs/source/en/using-diffusers/pag.md

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-25 22:26:12 -10:00
Dhruv Nair 0f0b531827 Add decorator for compile tests (#8703)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-26 11:26:47 +05:30
Sayak Paul e8284281c1 add docs on model sharding (#8658)
* add docs on model sharding

* add entry to _toctree.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* simplify wording

* add a note on transformer library handling

* move device placement section

* Update docs/source/en/training/distributed_inference.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-06-26 07:35:11 +05:30
YiYi Xu 715a7da1b2 add sd3 conversion script (#8702)
add conversion script
2024-06-25 14:24:58 -10:00
Álvaro Somoza 14d224d4e6 [Docs] SD3 T5 Token limit doc (#8654)
* doc for max_sequence_length

* better position and changed note to tip

* apply suggestions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-25 14:41:27 -04:00
YiYi Xu 540399f540 add PAG support (#7944)
* first draft


---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Ahn Donghoon (안동훈 / suno) <suno.vivid@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-06-25 08:40:02 -10:00
Sayak Paul f088027e93 [Marigold tests] add is_flaky decorator to some Marigold tests (#8696)
okay
2024-06-25 06:27:28 -10:00
Linoy Tsaban c6e08ecd46 [Sd3 Dreambooth LoRA] Add text encoder training for the clip encoders (#8630)
* add clip text-encoder training

* no dora

* text encoder traing fixes

* text encoder traing fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* add text_encoder layers to save_lora

* style

* fix imports

* style

* fix text encoder

* review changes

* review changes

* review changes

* minor change

* add lora tag

* style

* add readme notes

* add tests for clip encoders

* style

* typo

* fixes

* style

* Update tests/lora/test_lora_layers_sd3.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/README_sd3.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* minor readme change

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-25 18:00:19 +05:30
Sayak Paul 4ad7a1f5fd [Chore] create a utility for calculating the expected number of shards. (#8692)
create a utility for calculating the expected number of shards.
2024-06-25 17:05:39 +05:30
Hammond Liu 1f81fbe274 Fix redundant pipe init in sd3 lora (#8680)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-25 07:31:20 +05:30
Tolga Cangöz 589931ca79 Errata - Update class method convention to use cls (#8574)
* Class methods are supposed to use `cls` conventionally

* `make style && make quality`

* An Empty commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 10:35:45 -07:00
Steven Liu 675be88f00 [docs] Add note for float8 (#8685)
add note
2024-06-24 10:13:34 -07:00
Steven Liu df4ad6f4ac [docs] Fix Pillow import (#8684)
fix import error
2024-06-24 10:13:15 -07:00
Sayak Paul bc90c28bc9 [Docs] add note on caching in fast diffusion (#8675)
* add note on caching in fast diffusion

* formatting

* Update docs/source/en/tutorials/fast_diffusion.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-06-24 10:10:45 -07:00
Tolga Cangöz f040c27d4c Errata - Fix typos and improve style (#8571)
* Fix typos

* Fix typos & up style

* chore: Update numbers

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 10:07:22 -07:00
Tolga Cangöz 138fac703a Discourage using deprecated revision parameter (#8573)
* Discourage using `revision`

* `make style && make quality`

* Refactor code to use 'variant' instead of 'revision'

* `revision="bf16"` -> `variant="bf16"`

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-24 10:06:49 -07:00
145 changed files with 15490 additions and 1544 deletions
+8 -4
View File
@@ -63,11 +63,12 @@ body:
Please tag a maximum of 2 people.
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
Questions on pipelines:
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul
- Stable Diffusion @yiyixuxu @asomoza
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
- Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
- Kandinsky @yiyixuxu
- ControlNet @sayakpaul @yiyixuxu @DN6
- T2I Adapter @sayakpaul @yiyixuxu @DN6
@@ -75,11 +76,14 @@ body:
- Text-to-Video / Video-to-Video @DN6 @sayakpaul
- Wuerstchen @DN6
- Other: @yiyixuxu @DN6
- Improving generation quality: @asomoza
Questions on models:
- UNet @DN6 @yiyixuxu @sayakpaul
- VAE @sayakpaul @DN6 @yiyixuxu
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul
Questions on single file checkpoints: @DN6
Questions on Schedulers: @yiyixuxu
@@ -99,7 +103,7 @@ body:
Questions on JAX- and MPS-related things: @pcuenca
Questions on audio pipelines: @DN6
Questions on audio pipelines: @sanchit-gandhi
+2 -2
View File
@@ -39,7 +39,7 @@ members/contributors who may be interested in your PR.
Core library:
- Schedulers: @yiyixuxu
- Pipelines: @sayakpaul @yiyixuxu @DN6
- Pipelines and pipeline callbacks: @yiyixuxu and @asomoza
- Training examples: @sayakpaul
- Docs: @stevhliu and @sayakpaul
- JAX and MPS: @pcuenca
@@ -48,7 +48,7 @@ Core library:
Integrations:
- deepspeed: HF Trainer/Accelerate: @pacman100
- deepspeed: HF Trainer/Accelerate: @SunMarc
HF projects:
+14 -2
View File
@@ -13,13 +13,15 @@ env:
jobs:
torch_pipelines_cuda_benchmark_tests:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
name: Torch Core Pipelines CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on: [single-gpu, nvidia-gpu, a10, ci]
container:
image: diffusers/diffusers-pytorch-cuda
image: diffusers/diffusers-pytorch-compile-cuda
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
steps:
- name: Checkout diffusers
@@ -50,4 +52,14 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: benchmark_test_reports
path: benchmarks/benchmark_outputs
path: benchmarks/benchmark_outputs
- name: Report success status
if: ${{ success() }}
run: |
pip install requests && python utils/notify_benchmarking_status.py --status=success
- name: Report failure status
if: ${{ failure() }}
run: |
pip install requests && python utils/notify_benchmarking_status.py --status=failure
@@ -22,6 +22,9 @@ on:
jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-latest
steps:
# Checkout to correct ref
@@ -86,4 +89,14 @@ jobs:
run: huggingface-cli upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
env:
PATH_IN_REPO: ${{ env.PATH_IN_REPO }}
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
- name: Report success status
if: ${{ success() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=success
- name: Report failure status
if: ${{ failure() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
+1
View File
@@ -330,6 +330,7 @@ jobs:
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
+6 -6
View File
@@ -63,14 +63,14 @@ Let's walk through more detailed design decisions for each class.
Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
The following design principles are followed:
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as its done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as its done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
- Pipelines all inherit from [`DiffusionPipeline`].
- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.
- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
- Pipelines should be used **only** for inference.
- Pipelines should be very readable, self-explanatory, and easy to tweak.
- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.
- Pipelines are **not** intended to be feature-complete user interfaces. For future complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
- Every pipeline should have one and only one way to run it via a `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
- Pipelines should be named after the task they are intended to solve.
- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.
@@ -81,7 +81,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
@@ -90,7 +90,7 @@ The following design principles are followed:
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
### Schedulers
@@ -100,11 +100,11 @@ The following design principles are followed:
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./docs/source/en/using-diffusers/schedulers.md).
- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
+2 -2
View File
@@ -67,7 +67,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
## Quickstart
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 25.000+ checkpoints):
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 27.000+ checkpoints):
```python
from diffusers import DiffusionPipeline
@@ -209,7 +209,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/deep-floyd/IF
- https://github.com/bentoml/BentoML
- https://github.com/bmaltais/kohya_ss
- +11.000 other amazing GitHub repositories 💪
- +12.000 other amazing GitHub repositories 💪
Thank you for using us ❤️.
+5 -1
View File
@@ -40,7 +40,7 @@ def main():
print(f"****** Running file: {file} ******")
# Run with canonical settings.
if file != "benchmark_text_to_image.py":
if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
command = f"python {file}"
run_command(command.split())
@@ -49,6 +49,10 @@ def main():
# Run variants.
for file in python_files:
# See: https://github.com/pytorch/pytorch/issues/129637
if file == "benchmark_ip_adapters.py":
continue
if file == "benchmark_text_to_image.py":
for ckpt in ALL_T2I_CKPTS:
command = f"python {file} --ckpt {ckpt}"
@@ -16,23 +16,24 @@ RUN apt install -y bash \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.10 \
python3.9 \
python3.9-dev \
python3-pip \
python3.10-venv && \
python3.9-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3.10 -m venv /opt/venv
RUN python3.9 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
RUN python3.9 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.9 -m uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
invisible_watermark && \
python3.10 -m pip install --no-cache-dir \
python3.9 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
+10
View File
@@ -21,6 +21,8 @@
title: Load LoRAs for inference
- local: tutorials/fast_diffusion
title: Accelerate inference of text-to-image diffusion models
- local: tutorials/inference_with_big_models
title: Working with big models
title: Tutorials
- sections:
- local: using-diffusers/loading
@@ -81,6 +83,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
title: ControlNet
- local: using-diffusers/t2i_adapter
@@ -253,6 +257,8 @@
title: PriorTransformer
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
title: Models
@@ -278,6 +284,8 @@
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3
title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl
@@ -322,6 +330,8 @@
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/pia
+1 -1
View File
@@ -21,7 +21,7 @@ The abstract from the paper is:
## Loading from the original format
By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
from diffusers import AutoencoderKL
+1 -1
View File
@@ -21,7 +21,7 @@ The abstract from the paper is:
## Loading from the original format
By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
@@ -0,0 +1,37 @@
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HunyuanDiT2DControlNetModel
HunyuanDiT2DControlNetModel is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
## Example For Loading HunyuanDiT2DControlNetModel
```py
from diffusers import HunyuanDiT2DControlNetModel
import torch
controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16)
```
## HunyuanDiT2DControlNetModel
[[autodoc]] HunyuanDiT2DControlNetModel
+1 -1
View File
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted
## Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,36 @@
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet with Hunyuan-DiT
HunyuanDiTControlNetPipeline is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## HunyuanDiTControlNetPipeline
[[autodoc]] HunyuanDiTControlNetPipeline
- all
- __call__
+7 -1
View File
@@ -1,4 +1,4 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -34,6 +34,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
<Tip>
You can further improve generation quality by passing the generated image from [`HungyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
</Tip>
## Optimization
You can optimize the pipeline's runtime and memory consumption with torch.compile and feed-forward chunking. To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.
+1 -1
View File
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License.
Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
The description from it's Github page:
The description from it's GitHub page:
*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*
+46
View File
@@ -0,0 +1,46 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Perturbed-Attention Guidance
[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules.
PAG was introduced in [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance](https://huggingface.co/papers/2403.17377) by Donghoon Ahn, Hyoungwon Cho, Jaewon Min, Wooseok Jang, Jungwoo Kim, SeonHwa Kim, Hyun Hee Park, Kyong Hwan Jin and Seungryong Kim.
The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
- __call__
## StableDiffusionXLPAGPipeline
[[autodoc]] StableDiffusionXLPAGPipeline
- all
- __call__
## StableDiffusionXLPAGImg2ImgPipeline
[[autodoc]] StableDiffusionXLPAGImg2ImgPipeline
- all
- __call__
## StableDiffusionXLPAGInpaintPipeline
[[autodoc]] StableDiffusionXLPAGInpaintPipeline
- all
- __call__
## StableDiffusionXLControlNetPAGPipeline
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
- all
- __call__
@@ -37,6 +37,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
<Tip>
You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
</Tip>
## Inference with under 8GB GPU VRAM
Run the [`PixArtSigmaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
@@ -48,7 +48,7 @@ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2-base"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
@@ -72,7 +72,7 @@ init_image = load_image(img_url).resize((512, 512))
mask_image = load_image(mask_url).resize((512, 512))
repo_id = "stabilityai/stable-diffusion-2-inpainting"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
@@ -35,7 +35,6 @@ The SD3 pipeline uses three text encoders to generate an image. Model offloading
</Tip>
```python
import torch
from diffusers import StableDiffusion3Pipeline
@@ -197,6 +196,47 @@ image.save("sd3_hello_world.png")
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
## Using Long Prompts with the T5 Text Encoder
By default, the T5 Text Encoder prompt uses a maximum sequence length of `256`. This can be adjusted by setting the `max_sequence_length` to accept fewer or more tokens. Keep in mind that longer sequences require additional resources and result in longer generation times, such as during batch inference.
```python
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creatures body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"
image = pipe(
prompt=prompt,
negative_prompt="",
num_inference_steps=28,
guidance_scale=4.5,
max_sequence_length=512,
).images[0]
```
### Sending a different prompt to the T5 Text Encoder
You can send a different prompt to the CLIP Text Encoders and the T5 Text Encoder to prevent the prompt from being truncated by the CLIP Text Encoders and to improve generation.
<Tip>
The prompt with the CLIP Text Encoders is still truncated to the 77 token limit.
</Tip>
```python
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. A river of warm, melted butter, pancake-like foliage in the background, a towering pepper mill standing in for a tree."
prompt_3 = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creatures body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"
image = pipe(
prompt=prompt,
prompt_3=prompt_3,
negative_prompt="",
num_inference_steps=28,
guidance_scale=4.5,
max_sequence_length=512,
).images[0]
```
## Tiny AutoEncoder for Stable Diffusion 3
Tiny AutoEncoder for Stable Diffusion (TAESD3) is a tiny distilled version of Stable Diffusion 3's VAE by [Ollin Boer Bohan](https://github.com/madebyollin/taesd) that can decode [`StableDiffusion3Pipeline`] latents almost instantly.
@@ -251,6 +291,9 @@ image.save('sd3-single-file.png')
### Loading the single file checkpoint with T5
> [!TIP]
> The following example loads a checkpoint stored in a 8-bit floating point format which requires PyTorch 2.3 or later.
```python
import torch
from diffusers import StableDiffusion3Pipeline
+6 -6
View File
@@ -63,7 +63,7 @@ Let's walk through more in-detail design decisions for each class.
Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
The following design principles are followed:
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as its done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as its done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
- Pipelines all inherit from [`DiffusionPipeline`].
- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.
- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
@@ -81,7 +81,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
@@ -90,7 +90,7 @@ The following design principles are followed:
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
### Schedulers
@@ -100,11 +100,11 @@ The following design principles are followed:
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers.md).
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers).
- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
@@ -52,76 +52,6 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
</Tip>
### Device placement
> [!WARNING]
> This feature is experimental and its APIs might change in the future.
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
* it only works on a single GPU
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
> [!WARNING]
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
```diff
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
)
image = pipeline("a dog").images[0]
image
```
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
```diff
from diffusers import DiffusionPipeline
import torch
max_memory = {0:"1GB", 1:"1GB"}
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
device_map="balanced",
+ max_memory=max_memory
)
image = pipeline("a dog").images[0]
image
```
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
```py
pipeline.reset_device_map()
```
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
```py
print(pipeline.hf_device_map)
```
An example device map would look like so:
```bash
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
```
## PyTorch Distributed
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
@@ -176,3 +106,6 @@ Once you've completed the inference script, use the `--nproc_per_node` argument
```bash
torchrun run_distributed.py --nproc_per_node=2
```
> [!TIP]
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
+6 -6
View File
@@ -34,13 +34,10 @@ Install [PyTorch nightly](https://pytorch.org/) to benefit from the latest and f
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
```
<Tip>
> [!TIP]
> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum.
> If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. <br>
If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
</Tip>
## Baseline
@@ -170,6 +167,9 @@ Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
</div>
> [!TIP]
> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
### Prevent graph breaks
Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
@@ -0,0 +1,139 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Working with big models
A modern diffusion model, like [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl), is not just a single model, but a collection of multiple models. SDXL has four different model-level components:
* A variational autoencoder (VAE)
* Two text encoders
* A UNet for denoising
Usually, the text encoders and the denoiser are much larger compared to the VAE.
As models get bigger and better, its possible your model is so big that even a single copy wont fit in memory. But that doesnt mean it cant be loaded. If you have more than one GPU, there is more memory available to store your model. In this case, its better to split your model checkpoint into several smaller *checkpoint shards*.
When a text encoder checkpoint has multiple shards, like [T5-xxl for SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3), it is automatically handled by the [Transformers](https://huggingface.co/docs/transformers/index) library as it is a required dependency of Diffusers when using the [`StableDiffusion3Pipeline`]. More specifically, Transformers will automatically handle the loading of multiple shards within the requested model class and get it ready so that inference can be performed.
The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library.
> [!TIP]
> Refer to the [Handling big models for inference](https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference) guide for general guidance when working with big models that are hard to fit into memory.
For example, let's save a sharded checkpoint for the [SDXL UNet](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet):
```python
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
```
The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
```python
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
import torch
unet = UNet2DConditionModel.from_pretrained(
"sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to("cuda")
image = pipeline("a cute dog running on the grass", num_inference_steps=30).images[0]
image.save("dog.png")
```
If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you:
```diff
- pipeline.to("cuda")
+ pipeline.enable_model_cpu_offload()
```
In general, we recommend sharding when a checkpoint is more than 5GB (in fp32).
## Device placement
On distributed setups, you can run inference across multiple GPUs with Accelerate.
> [!WARNING]
> This feature is experimental and its APIs might change in the future.
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
* it only works on a single GPU
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
> [!WARNING]
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
```diff
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
)
image = pipeline("a dog").images[0]
image
```
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
```diff
from diffusers import DiffusionPipeline
import torch
max_memory = {0:"1GB", 1:"1GB"}
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
device_map="balanced",
+ max_memory=max_memory
)
image = pipeline("a dog").images[0]
image
```
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
```py
pipeline.reset_device_map()
```
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
```py
print(pipeline.hf_device_map)
```
An example device map would look like so:
```bash
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
```
+351
View File
@@ -0,0 +1,351 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Perturbed-Attention Guidance
[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules. PAG is designed to progressively enhance the structure of synthesized samples throughout the denoising process by considering the self-attention mechanisms' ability to capture structural information. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, and guiding the denoising process away from these degraded samples.
This guide will show you how to use PAG for various tasks and use cases.
## General tasks
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
> [!TIP]
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
<hfoptions id="tasks">
<hfoption id="Text-to-image">
```py
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
enable_pag=True,
pag_applied_layers=["mid"],
torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
```
> [!TIP]
> The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. Additionally, you can use `set_pag_applied_layers` method to update these layers after the pipeline has been created. Check out the [pag_applied_layers](#pag_applied_layers) section to learn more about applying PAG to other layers.
If you already have a pipeline created and loaded, you can enable PAG on it using the `from_pipe` API with the `enable_pag` flag. Internally, a PAG pipeline is created based on the pipeline and task you specified. In the example below, since we used `AutoPipelineForText2Image` and passed a `StableDiffusionXLPipeline`, a `StableDiffusionXLPAGPipeline` is created accordingly. Note that this does not require additional memory, and you will have both `StableDiffusionXLPipeline` and `StableDiffusionXLPAGPipeline` loaded and ready to use. You can read more about the `from_pipe` API and how to reuse pipelines in diffuser[here](https://huggingface.co/docs/diffusers/using-diffusers/loading#reuse-a-pipeline)
```py
pipeline_sdxl = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0, torch_dtype=torch.float16")
pipeline = AutoPipelineForText2Image.from_pipe(pipeline_sdxl, enable_pag=True)
```
To generate an image, you will also need to pass a `pag_scale`. When `pag_scale` increases, images gain more semantically coherent structures and exhibit fewer artifacts. However overly large guidance scale can lead to smoother textures and slight saturation in the images, similarly to CFG. `pag_scale=3.0` is used in the official demo and works well in most of the use cases, but feel free to experiment and select the appropriate value according to your needs! PAG is disabled when `pag_scale=0`.
```py
prompt = "an insect robot preparing a delicious meal, anime style"
for pag_scale in [0.0, 3.0]:
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(
prompt=prompt,
num_inference_steps=25,
guidance_scale=7.0,
generator=generator,
pag_scale=pag_scale,
).images
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_cfg_7.0_mid.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_mid.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
</div>
</div>
</hfoption>
<hfoption id="Image-to-image">
You can use PAG with image-to-image pipelines.
```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
enable_pag=True,
pag_applied_layers=["mid"],
torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
```
If you already have a image-to-image pipeline and would like enable PAG on it, you can run this
```py
pipeline_t2i = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)
```
It is also very easy to directly switch from a text-to-image pipeline to PAG enabled image-to-image pipeline
```py
pipeline_pag = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)
```
If you have a PAG enabled text-to-image pipeline, you can directly switch to a image-to-image pipeline with PAG still enabled
```py
pipeline_pag = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", enable_pag=True, torch_dtype=torch.float16)
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i)
```
Now let's generate an image!
```py
pag_scales = 4.0
guidance_scales = 7.0
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
init_image = load_image(url)
prompt = "a dog catching a frisbee in the jungle"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipeline(
prompt,
image=init_image,
strength=0.8,
guidance_scale=guidance_scale,
pag_scale=pag_scale,
generator=generator).images[0]
```
</hfoption>
<hfoption id="Inpainting">
```py
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForInpainting.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
enable_pag=True,
torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
```
You can enable PAG on an exisiting inpainting pipeline like this
```py
pipeline_inpaint = AutoPipelineForInpaiting.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_inpaint, enable_pag=True)
```
This still works when your pipeline has a different task:
```py
pipeline_t2i = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_t2i, enable_pag=True)
```
Let's generate an image!
```py
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")
prompt = "A majestic tiger sitting on a bench"
pag_scales = 3.0
guidance_scales = 7.5
generator = torch.Generator(device="cpu").manual_seed(1)
images = pipeline(
prompt=prompt,
image=init_image,
mask_image=mask_image,
strength=0.8,
num_inference_steps=50,
guidance_scale=guidance_scale,
generator=generator,
pag_scale=pag_scale,
).images
images[0]
```
</hfoption>
</hfoptions>
## PAG with ControlNet
To use PAG with ControlNet, first create a `controlnet`. Then, pass the `controlnet` and other PAG arguments to the `from_pretrained` method of the AutoPipeline for the specified task.
```py
from diffusers import AutoPipelineForText2Image, ControlNetModel
import torch
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet,
enable_pag=True,
pag_applied_layers="mid",
torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
```
<Tip>
If you already have a controlnet pipeline and want to enable PAG, you can use the `from_pipe` API: `AutoPipelineForText2Image.from_pipe(pipeline_controlnet, enable_pag=True)`
</Tip>
You can use the pipeline in the same way you normally use ControlNet pipelines, with the added option to specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without a prompt.
```py
from diffusers.utils import load_image
canny_image = load_image(
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_control_input.png"
)
for pag_scale in [0.0, 3.0]:
generator = torch.Generator(device="cpu").manual_seed(1)
images = pipeline(
prompt="",
controlnet_conditioning_scale=controlnet_conditioning_scale,
image=canny_image,
num_inference_steps=50,
guidance_scale=0,
generator=generator,
pag_scale=pag_scale,
).images
images[0]
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_controlnet.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_controlnet.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
</div>
</div>
## PAG with IP-Adapter
[IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-Adapter loaded.
```py
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=torch.float16
)
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
enable_pag=True,
torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.bin")
pag_scales = 5.0
ip_adapter_scales = 0.8
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
pipeline.set_ip_adapter_scale(ip_adapter_scale)
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(
prompt="a polar bear sitting in a chair drinking a milkshake",
ip_adapter_image=image,
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
num_inference_steps=25,
guidance_scale=3.0,
generator=generator,
pag_scale=pag_scale,
).images
images[0]
```
PAG reduces artifacts and improves the overall compposition.
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_ipa_0.8.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_5.0_ipa_0.8.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
</div>
</div>
## Configure parameters
### pag_applied_layers
The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. By default, it applies only to the mid blocks. Changing this setting will significantly impact the output. You can use the `set_pag_applied_layers` method to adjust the PAG layers after the pipeline is created, helping you find the optimal layers for your model.
As an example, here is the images generated with `pag_layers = ["down.block_2"]` and `pag_layers = ["down.block_2", "up.block_1.attentions_0"]`
```py
prompt = "an insect robot preparing a delicious meal, anime style"
pipeline.set_pag_applied_layers(pag_layers)
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(
prompt=prompt,
num_inference_steps=25,
guidance_scale=guidance_scale,
generator=generator,
pag_scale=pag_scale,
).images
images[0]
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2_up1a0.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">down.block_2 + up.block1.attentions_0</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">down.block_2</figcaption>
</div>
</div>
+1 -1
View File
@@ -186,7 +186,7 @@ scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
scheduler=scheduler,
revision="bf16",
variant="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
+6
View File
@@ -285,6 +285,12 @@ refiner = DiffusionPipeline.from_pretrained(
).to("cuda")
```
<Tip warning={true}>
You can use SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../../api/pipelines/hunyuandit) or [PixArt-Sigma](../../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.
</Tip>
Generate an image from the base model, and set the model output to **latent** space:
```py
@@ -63,7 +63,7 @@ Flax is a functional framework, so models are stateless and parameters are store
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
variant="bf16",
dtype=dtype,
)
```
+2 -2
View File
@@ -57,7 +57,7 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
파이프라인은 사용하기 쉽도록 설계되었으며 (따라서 [*쉬움보다는 간단함을*](#쉬움보다는-간단함을)을 100% 따르지는 않음), feature-complete하지 않으며, 추론을 위한 [모델](#모델)과 [스케줄러](#스케줄러)를 사용하는 방법의 예시로 간주될 수 있습니다.
다음과 같은 설계 원칙을 따릅니다:
- 파이프라인은 단일 파일 정책을 따릅니다. 모든 파이프라인은 src/diffusers/pipelines의 개별 디렉토리에 있습니다. 하나의 파이프라인 폴더는 하나의 diffusion 논문/프로젝트/릴리스에 해당합니다. 여러 파이프라인 파일은 하나의 파이프라인 폴더에 모을 수 있습니다. 예를 들어 [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)에서 그렇게 하고 있습니다. 파이프라인이 유사한 기능을 공유하는 경우, [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)을 사용할 수 있습니다.
- 파이프라인은 단일 파일 정책을 따릅니다. 모든 파이프라인은 src/diffusers/pipelines의 개별 디렉토리에 있습니다. 하나의 파이프라인 폴더는 하나의 diffusion 논문/프로젝트/릴리스에 해당합니다. 여러 파이프라인 파일은 하나의 파이프라인 폴더에 모을 수 있습니다. 예를 들어 [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)에서 그렇게 하고 있습니다. 파이프라인이 유사한 기능을 공유하는 경우, [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)을 사용할 수 있습니다.
- 파이프라인은 모두 [`DiffusionPipeline`]을 상속합니다.
- 각 파이프라인은 서로 다른 모델 및 스케줄러 구성 요소로 구성되어 있으며, 이는 [`model_index.json` 파일](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json)에 문서화되어 있으며, 파이프라인의 속성 이름과 동일한 이름으로 액세스할 수 있으며, [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) 함수를 통해 파이프라인 간에 공유할 수 있습니다.
- 각 파이프라인은 [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) 함수를 통해 로드할 수 있어야 합니다.
@@ -93,7 +93,7 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
- 모든 스케줄러는 [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)에서 찾을 수 있습니다.
- 스케줄러는 큰 유틸리티 파일에서 가져오지 **않아야** 하며, 자체 포함성을 유지해야 합니다.
- 하나의 스케줄러 Python 파일은 하나의 스케줄러 알고리즘(논문에서 정의된 것과 같은)에 해당합니다.
- 스케줄러가 유사한 기능을 공유하는 경우, `#Copied from` 메커니즘을 사용할 수 있습니다.
- 스케줄러가 유사한 기능을 공유하는 경우, `# Copied from` 메커니즘을 사용할 수 있습니다.
- 모든 스케줄러는 `SchedulerMixin``ConfigMixin`을 상속합니다.
- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메서드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.
- 모든 스케줄러는 `set_num_inference_steps``step` 함수를 가져야 합니다. `set_num_inference_steps(...)`는 각 노이즈 제거 과정(즉, `step(...)`이 호출되기 전) 이전에 호출되어야 합니다.
+1 -1
View File
@@ -58,7 +58,7 @@ outputs = pipeline(
)
```
더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 Github 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.
더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 GitHub 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.
## 벤치마크
+1 -1
View File
@@ -296,7 +296,7 @@ scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="bf16",
variant="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
@@ -83,7 +83,7 @@ Flax는 함수형 프레임워크이므로 모델은 무상태(stateless)형이
```python
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
variant="bf16",
dtype=dtype,
)
```
@@ -1524,17 +1524,22 @@ def main(args):
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1551,8 +1556,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -31,8 +31,6 @@ from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
# imports of the TokenEmbeddingsHandler class
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
@@ -77,6 +75,9 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
@@ -101,12 +102,12 @@ def save_model_card(
repo_id: str,
use_dora: bool,
images=None,
base_model=str,
base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
token_abstraction_dict=None,
instance_prompt=str,
validation_prompt=str,
instance_prompt: str = None,
validation_prompt: str = None,
repo_folder=None,
vae_path=None,
):
@@ -135,6 +136,14 @@ def save_model_card(
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
webui_example_pivotal = ""
license = ""
if "playground" in base_model:
license = """\n
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
"""
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
@@ -223,11 +232,75 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
{license}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
phase_name: [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
return images
def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -390,6 +463,7 @@ def parse_args(input_args=None):
)
parser.add_argument(
"--do_edm_style_training",
default=False,
action="store_true",
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
)
@@ -571,7 +645,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--optimizer",
type=str,
default="adamW",
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
@@ -906,11 +980,6 @@ class DreamBoothDataset(Dataset):
instance_data_root,
instance_prompt,
class_prompt,
dataset_name,
dataset_config_name,
cache_dir,
image_column,
caption_column,
train_text_encoder_ti,
class_data_root=None,
class_num=None,
@@ -929,7 +998,7 @@ class DreamBoothDataset(Dataset):
self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if dataset_name is not None:
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
@@ -942,25 +1011,26 @@ class DreamBoothDataset(Dataset):
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
dataset_name,
dataset_config_name,
cache_dir=cache_dir,
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if image_column is None:
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if caption_column is None:
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
@@ -968,11 +1038,11 @@ class DreamBoothDataset(Dataset):
)
self.custom_instance_prompts = None
else:
if caption_column not in column_names:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][caption_column]
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
@@ -1178,13 +1248,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds[-1][-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
@@ -1200,9 +1269,16 @@ def main(args):
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1215,10 +1291,13 @@ def main(args):
kwargs_handlers=[kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -1246,7 +1325,8 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
@@ -1404,6 +1484,12 @@ def main(args):
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
@@ -1508,15 +1594,13 @@ def main(args):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder:
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(unwrap_model(text_encoder_two))):
if args.train_text_encoder:
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1564,6 +1648,7 @@ def main(args):
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
_set_state_dict_into_text_encoder(
@@ -1578,14 +1663,14 @@ def main(args):
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
@@ -1711,12 +1796,7 @@ def main(args):
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
cache_dir=args.cache_dir,
image_column=args.image_column,
train_text_encoder_ti=args.train_text_encoder_ti,
caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images,
@@ -1740,8 +1820,6 @@ def main(args):
def compute_time_ids(crops_coords_top_left, original_size=None):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
if original_size is None:
original_size = (args.resolution, args.resolution)
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
@@ -1778,7 +1856,8 @@ def main(args):
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1820,17 +1899,22 @@ def main(args):
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1847,8 +1931,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -1946,8 +2036,8 @@ def main(args):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
if args.train_text_encoder:
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
if pivoted:
@@ -2040,7 +2130,6 @@ def main(args):
if freeze_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids,
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -2220,10 +2309,6 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
if freeze_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
@@ -2250,70 +2335,29 @@ def main(args):
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline
torch.cuda.empty_cache()
images = log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
)
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
@@ -2332,84 +2376,38 @@ def main(args):
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
embedding_handler.save_embeddings(embeddings_path)
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
images = log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation=True,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# load new tokens
if args.train_text_encoder_ti:
state_dict = load_file(embeddings_path)
all_new_tokens = []
for key, value in token_abstraction_dict.items():
all_new_tokens.extend(value)
pipeline.load_textual_inversion(
state_dict["clip_l"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
state_dict["clip_g"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder_2,
tokenizer=pipeline.tokenizer_2,
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
# Convert to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
@@ -2430,6 +2428,7 @@ def main(args):
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
File diff suppressed because it is too large Load Diff
@@ -2,7 +2,7 @@
# A SDXL pipeline can take unlimited weighted prompt
#
# Author: Andrew Zhu
# Github: https://github.com/xhinker
# GitHub: https://github.com/xhinker
# Medium: https://medium.com/@xhinker
## -----------------------------------------------------------
@@ -2165,7 +2165,7 @@ class SDXLLongPromptWeightingPipeline(
@classmethod
def save_lora_weights(
self,
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
@@ -2188,7 +2188,7 @@ class SDXLLongPromptWeightingPipeline(
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
@@ -1339,7 +1339,7 @@ class DemoFusionSDXLPipeline(
@classmethod
def save_lora_weights(
self,
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
@@ -1368,7 +1368,7 @@ class DemoFusionSDXLPipeline(
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
@@ -0,0 +1,981 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AutoPipelineForImage2Image
>>> from diffusers.utils import load_image
>>> device = "cuda"
>>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
>>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((512, 512))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
>>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModelWithProjection`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
as its dimension.
text_encoder_2 ([`CLIPTextModelWithProjection`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
variant.
text_encoder_3 ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`CLIPTokenizer`):
Second Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
self,
transformer: SD3Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
)
self.tokenizer_max_length = self.tokenizer.model_max_length
self.default_sample_size = self.transformer.config.sample_size
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if self.text_encoder_3 is None:
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
dtype=dtype,
)
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
dtype = self.text_encoder_3.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
clip_model_index: int = 0,
):
device = device or self._execution_device
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = clip_tokenizers[clip_model_index]
text_encoder = clip_text_encoders[clip_model_index]
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompt_3 = prompt_3 or prompt
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
prompt=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_3 = negative_prompt_3 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
negative_prompt_3 = (
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
negative_prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=0,
)
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
negative_prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=1,
)
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_clip_prompt_embeds = torch.nn.functional.pad(
negative_clip_prompt_embeds,
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
)
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
negative_pooled_prompt_embeds = torch.cat(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def check_inputs(
self,
prompt,
prompt_2,
prompt_3,
strength,
negative_prompt=None,
negative_prompt_2=None,
negative_prompt_3=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_3 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start
def prepare_latents(
self, batch_size, num_channels_latents, height, width, image, timestep, dtype, device, generator=None
):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
latents = init_latents.to(device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
image: PipelineImageInput = None,
strength: float = 0.6,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
map: PipelineImageInput = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
will be used instead
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used instead
negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
`text_encoder_3`. If not defined, `negative_prompt` is used instead
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
strength,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Preprocess image
init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
map = self.mask_processor.preprocess(
map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor
).to(device)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# begin diff diff change
total_time_steps = num_inference_steps
# end diff diff change
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if latents is None:
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
init_image,
latent_timestep,
prompt_embeds.dtype,
device,
generator,
)
# 6. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# preparations for diff diff
original_with_noise = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
init_image,
timesteps,
prompt_embeds.dtype,
device,
generator,
)
thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
masks = map.squeeze() > thresholds
# end diff diff preparations
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# diff diff
if i == 0:
latents = original_with_noise[:1]
else:
mask = masks[i].unsqueeze(0).to(latents.dtype)
mask = mask.unsqueeze(1) # fit shape
latents = original_with_noise[i] * mask + latents * (1 - mask)
# end diff diff
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusion3PipelineOutput(images=image)
+1 -1
View File
@@ -282,7 +282,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
def main():
# Run a demo
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = Image.open("../../docs/source/imgs/diffusers_library.jpg")
+18 -7
View File
@@ -1088,17 +1088,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1110,8 +1115,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+34
View File
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="trained-sd3-lora"
accelerate launch train_dreambooth_lora_sd3.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--dataset_name="Norod78/Yarn-art-style" \
--instance_prompt="a photo of TOK yarn art dog" \
--resolution=1024 \
--train_batch_size=1 \
--train_text_encoder\
--gradient_accumulation_steps=1 \
--optimizer="prodigy"\
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1500 \
--rank=32 \
--seed="0" \
--push_to_hub
```
## Other notes
We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
+1 -1
View File
@@ -261,7 +261,7 @@ The authors found that by using DoRA, both the learning capacity and training st
**Usage**
1. To use DoRA you need to upgrade the installation of `peft`:
```bash
pip install-U peft
pip install -U peft
```
2. Enable DoRA training by adding this flag
```bash
+209 -59
View File
@@ -54,6 +54,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
@@ -80,6 +81,7 @@ def save_model_card(
repo_id: str,
images=None,
base_model: str = None,
train_text_encoder=False,
instance_prompt=None,
validation_prompt=None,
repo_folder=None,
@@ -103,6 +105,8 @@ These are {repo_id} DreamBooth weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
## Trigger words
You should use {instance_prompt} to trigger the image generation.
@@ -113,7 +117,7 @@ You should use {instance_prompt} to trigger the image generation.
## License
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -128,6 +132,7 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co
"text-to-image",
"diffusers-training",
"diffusers",
"lora",
"sd3",
"sd3-diffusers",
"template:sd-lora",
@@ -381,6 +386,12 @@ def parse_args(input_args=None):
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder (clip text encoders only). If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
@@ -856,19 +867,25 @@ def _encode_prompt_with_t5(
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
@@ -888,20 +905,26 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
@@ -923,6 +946,7 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -931,13 +955,14 @@ def encode_prompt(
clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = []
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i],
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -951,6 +976,7 @@ def encode_prompt(
max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[:-1],
device=device if device is not None else text_encoders[-1].device,
)
@@ -1145,6 +1171,9 @@ def main(args):
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
@@ -1155,6 +1184,16 @@ def main(args):
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
@@ -1164,10 +1203,16 @@ def main(args):
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1175,17 +1220,26 @@ def main(args):
weights.pop()
StableDiffusion3Pipeline.save_lora_weights(
output_dir, transformer_lora_layers=transformer_lora_layers_to_save
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
transformer_ = None
text_encoder_one_ = None
text_encoder_two_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1204,12 +1258,21 @@ def main(args):
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
_set_state_dict_into_text_encoder(
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [transformer_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
@@ -1229,14 +1292,37 @@ def main(args):
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]
if args.train_text_encoder:
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
transformer_parameters_with_lr,
text_lora_parameters_one_with_lr,
text_lora_parameters_two_with_lr,
]
else:
params_to_optimize = [transformer_parameters_with_lr]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
@@ -1317,31 +1403,33 @@ def main(args):
num_workers=args.dataloader_num_workers,
)
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
if not args.train_text_encoder:
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
if not train_dataset.custom_instance_prompts:
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
)
if not args.train_text_encoder:
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
)
# Clear the memory here
if not train_dataset.custom_instance_prompts:
if not args.train_text_encoder and train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1354,12 +1442,13 @@ def main(args):
# have to pass them to the dataloader.
if not train_dataset.custom_instance_prompts:
prompt_embeds = instance_prompt_hidden_states
pooled_prompt_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
pooled_prompt_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
@@ -1390,9 +1479,22 @@ def main(args):
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
transformer,
text_encoder_one,
text_encoder_two,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
)
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1470,6 +1572,13 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
@@ -1479,7 +1588,30 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
if not args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts)
tokens_two = tokenize_prompt(tokenizer_two, prompts)
tokens_three = tokenize_prompt(tokenizer_three, prompts)
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=[None, None, tokenizer_three],
prompt=prompts,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
)
else:
if args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=[None, None, tokenizer_three],
prompt=prompts,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1553,7 +1685,11 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = transformer_lora_parameters
params_to_clip = itertools.chain(
transformer_lora_parameters,
text_lora_parameters_one,
text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -1600,10 +1736,18 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
if not args.train_text_encoder:
# create pipeline
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
else:
text_encoder_three = text_encoder_cls_three.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_3",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1634,15 +1778,20 @@ def main(args):
transformer = transformer.to(torch.float32)
transformer_lora_layers = get_peft_model_state_dict(transformer)
StableDiffusion3Pipeline.save_lora_weights(
save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers
)
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
StableDiffusion3Pipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
# Final inference
@@ -1676,6 +1825,7 @@ def main(args):
base_model=args.pretrained_model_name_or_path,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
train_text_encoder=args.train_text_encoder,
repo_folder=args.output_dir,
)
upload_folder(
@@ -30,7 +30,7 @@ accelerate launch finetune_instruct_pix2pix.py \
## Inference
After training the model and the lora weight of the model is stored in the ```$OUTPUT_DIR```.
```bash
```py
# load the base model pipeline
pipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
@@ -6,7 +6,7 @@ This aims to provide diffusers examples with Intel optimizations such as Bfloat1
## Accelerating the fine-tuning for textual inversion
We accelereate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
## Accelerating the inference for Stable Diffusion using Bfloat16
@@ -323,7 +323,7 @@ accelerate launch train_dreambooth.py \
### Using DreamBooth for other pipelines than Stable Diffusion
Altdiffusion also support dreambooth now, the runing comman is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
Altdiffusion also supports dreambooth now, the running command is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
```
+3 -3
View File
@@ -45,7 +45,7 @@ accelerate launch train_vqgan.py \
```
An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocaburary size which at most can be around 16384. How to do this is shown below.
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size which at most can be around 16384. How to do this is shown below.
# Modifying the architecture
@@ -118,10 +118,10 @@ To lower the amount of layers in a VQGan, you can remove layers by modifying the
"vq_embed_dim": 4
}
```
For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that.
For increasing the size of the vocabularies you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that.
## Extra training tips/ideas
During logging take care to make sure data_time is low. data_time is the amount spent loading the data and where the GPU is not active. So essentially, it's the time wasted. The easiest way to lower data time is to increase the --dataloader_num_workers to a higher number like 4. Due to a bug in Pytorch, this only works on linux based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646)
Secondly, training should seem to be done when both the discriminator and the generator loss converges.
Thirdly, another low hanging fruit is just using ema using the --use_ema parameter. This tends to make the output images smoother. This has a con where you have to lower your batch size by 1 but it may be worth it.
Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, becareful with the feature map norms since this can easily overdominate the loss.
Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, be careful with the feature map norms since this can easily overdominate the loss.
+248
View File
@@ -0,0 +1,248 @@
import argparse
from contextlib import nullcontext
import safetensors.torch
import torch
from accelerate import init_empty_weights
from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="fp16")
args = parser.parse_args()
dtype = torch.float16 if args.dtype == "fp16" else torch.float32
def load_original_checkpoint(ckpt_path):
original_state_dict = safetensors.torch.load_file(ckpt_path)
keys = list(original_state_dict.keys())
for k in keys:
if "model.diffusion_model." in k:
original_state_dict[k.replace("model.diffusion_model.", "")] = original_state_dict.pop(k)
return original_state_dict
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
converted_state_dict = {}
# Positional and patch embeddings.
converted_state_dict["pos_embed.pos_embed"] = original_state_dict.pop("pos_embed")
converted_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
# Timestep embeddings.
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"t_embedder.mlp.0.bias"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"t_embedder.mlp.2.bias"
)
# Context projections.
converted_state_dict["context_embedder.weight"] = original_state_dict.pop("context_embedder.weight")
converted_state_dict["context_embedder.bias"] = original_state_dict.pop("context_embedder.bias")
# Pooled context projection.
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
"y_embedder.mlp.0.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
"y_embedder.mlp.0.bias"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
"y_embedder.mlp.2.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
"y_embedder.mlp.2.bias"
)
# Transformer blocks 🎸.
for i in range(num_layers):
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
)
context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.proj.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
)
else:
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
dim=caption_projection_dim,
)
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
dim=caption_projection_dim,
)
# ffs.
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.mlp.fc2.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.mlp.fc2.bias"
)
# Final blocks.
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
)
return converted_state_dict
def is_vae_in_checkpoint(original_state_dict):
return ("first_stage_model.decoder.conv_in.weight" in original_state_dict) and (
"first_stage_model.encoder.conv_in.weight" in original_state_dict
)
def main(args):
original_ckpt = load_original_checkpoint(args.checkpoint_path)
num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401
caption_projection_dim = 1536
converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, caption_projection_dim
)
with CTX():
transformer = SD3Transformer2DModel(
sample_size=64,
patch_size=2,
in_channels=16,
joint_attention_dim=4096,
num_layers=num_layers,
caption_projection_dim=caption_projection_dim,
num_attention_heads=24,
pos_embed_max_size=192,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
else:
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print("Saving SD3 Transformer in Diffusers format.")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
if is_vae_in_checkpoint(original_ckpt):
with CTX():
vae = AutoencoderKL.from_config(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="vae",
latent_channels=16,
use_post_quant_conv=False,
use_quant_conv=False,
scaling_factor=1.5305,
shift_factor=0.0609,
)
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
if is_accelerate_available():
load_model_dict_into_meta(vae, converted_vae_state_dict)
else:
vae.load_state_dict(converted_vae_state_dict, strict=True)
print("Saving SD3 Autoencoder in Diffusers format.")
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
if __name__ == "__main__":
main(args)
+16
View File
@@ -83,7 +83,9 @@ else:
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -234,6 +236,7 @@ else:
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -301,6 +304,7 @@ else:
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionPipeline",
@@ -311,11 +315,15 @@ else:
"StableDiffusionXLAdapterPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLControlNetPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
@@ -496,7 +504,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -625,6 +635,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
CLIPImageProjection,
CycleDiffusionPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
@@ -692,6 +703,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline,
@@ -702,11 +714,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
+3 -3
View File
@@ -23,7 +23,7 @@ import json
import os
import re
from collections import OrderedDict
from pathlib import PosixPath
from pathlib import Path
from typing import Any, Dict, Tuple, Union
import numpy as np
@@ -587,8 +587,8 @@ class ConfigMixin:
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
elif isinstance(value, PosixPath):
value = str(value)
elif isinstance(value, Path):
value = value.as_posix()
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
-146
View File
@@ -1,146 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.utils import validate_hf_hub_args
from .single_file_utils import (
create_diffusers_vae_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
class FromOriginalVAEMixin:
"""
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config_file (`str`, *optional*):
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
of Diffusers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip warning={true}>
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
a VAE from SDXL or a Stable Diffusion v2 model or higher.
</Tip>
Examples:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
"""
original_config_file = kwargs.pop("original_config_file", None)
config_file = kwargs.pop("config_file", None)
resume_download = kwargs.pop("resume_download", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
class_name = cls.__name__
if (config_file is not None) and (original_config_file is not None):
raise ValueError(
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
)
original_config_file = original_config_file or config_file
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None)
component = create_diffusers_vae_model_from_ldm(
class_name,
original_config,
checkpoint,
image_size=image_size,
scaling_factor=scaling_factor,
torch_dtype=torch_dtype,
)
vae = component["vae"]
if torch_dtype is not None:
vae = vae.to(torch_dtype)
return vae
-136
View File
@@ -1,136 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.utils import validate_hf_hub_args
from .single_file_utils import (
create_diffusers_controlnet_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
class FromOriginalControlNetMixin:
"""
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config_file (`str`, *optional*):
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
of Diffusers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
Examples:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
model = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
"""
original_config_file = kwargs.pop("original_config_file", None)
config_file = kwargs.pop("config_file", None)
resume_download = kwargs.pop("resume_download", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
class_name = cls.__name__
if (config_file is not None) and (original_config_file is not None):
raise ValueError(
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
)
original_config_file = config_file or original_config_file
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
upcast_attention = kwargs.pop("upcast_attention", False)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_controlnet_model_from_ldm(
class_name,
original_config,
checkpoint,
upcast_attention=upcast_attention,
image_size=image_size,
torch_dtype=torch_dtype,
)
controlnet = component["controlnet"]
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)
return controlnet
+13 -4
View File
@@ -396,8 +396,7 @@ class LoraLoaderMixin:
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
if not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
@@ -1601,6 +1600,8 @@ class SD3LoraLoaderMixin:
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1632,12 +1633,20 @@ class SD3LoraLoaderMixin:
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
network_alphas = {}
# Check for DoRA-enabled LoRAs.
if any(
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
for k in state_dict
):
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Store DoRA scale if present.
if "dora_scale" in state_dict:
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Store DoRA scale if present.
if "dora_scale" in state_dict:
if dora_present_in_te or dora_present_in_te2:
dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.")
logger.info("Non-diffusers checkpoint detected.")
# Construct final state dict.
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+18 -5
View File
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import re
from contextlib import nullcontext
@@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
}
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
@@ -149,8 +161,9 @@ class FromOriginalModelMixin:
```
"""
class_name = cls.__name__
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
@@ -195,7 +208,7 @@ class FromOriginalModelMixin:
revision=revision,
)
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config:
@@ -207,7 +220,7 @@ class FromOriginalModelMixin:
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {class_name} but no mapping function"
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
@@ -267,7 +280,7 @@ class FromOriginalModelMixin:
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
+10 -5
View File
@@ -457,6 +457,15 @@ class UNet2DConditionLoadersMixin:
)
if is_custom_diffusion:
state_dict = self._get_custom_diffusion_state_dict()
if save_function is None and safe_serialization:
# safetensors does not support saving dicts with non-tensor values
empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
if len(empty_state_dict) > 0:
logger.warning(
f"Safetensors does not support saving dicts with non-tensor values. "
f"The following keys will be ignored: {empty_state_dict.keys()}"
)
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
else:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
@@ -922,8 +931,6 @@ class UNet2DConditionLoadersMixin:
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
@@ -963,9 +970,7 @@ class UNet2DConditionLoadersMixin:
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
+2
View File
@@ -33,6 +33,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"]
@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
+216
View File
@@ -2561,6 +2561,220 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
PAG reference: https://arxiv.org/abs/2403.17377
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
# original path
batch_size, sequence_length, _ = hidden_states_org.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class PAGCFGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
PAG reference: https://arxiv.org/abs/2403.17377
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
# original path
batch_size, sequence_length, _ = hidden_states_org.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
value = attn.to_v(hidden_states_ptb)
hidden_states_ptb = value
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -2590,4 +2804,6 @@ AttentionProcessor = Union[
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
]
@@ -166,12 +166,12 @@ class VQModel(ModelMixin, ConfigMixin):
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple.
Returns:
[`~models.vq_model.VQEncoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned.
[`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a
plain `tuple` is returned.
"""
h = self.encode(sample).latents
+399
View File
@@ -0,0 +1,399 @@
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Union
import torch
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .attention_processor import AttentionProcessor
from .controlnet import BaseOutput, Tuple, zero_module
from .embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
from .modeling_utils import ModelMixin
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class HunyuanControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "gelu-approximate",
sample_size=32,
hidden_size=1152,
transformer_num_layers: int = 40,
mlp_ratio: float = 4.0,
cross_attention_dim: int = 1024,
cross_attention_dim_t5: int = 2048,
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
):
super().__init__()
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.text_embedder = PixArtAlphaTextProjection(
in_features=cross_attention_dim_t5,
hidden_size=cross_attention_dim_t5 * 4,
out_features=cross_attention_dim,
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
in_channels=in_channels,
embed_dim=hidden_size,
patch_size=patch_size,
pos_embed_type=None,
)
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
hidden_size,
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
# HunyuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=False, # always False as it is the first half of the model
)
for layer in range(transformer_num_layers // 2 - 1)
]
)
self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))
for _ in range(len(self.blocks)):
controlnet_block = nn.Linear(hidden_size, hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
corresponding cross attention processor. This is strongly recommended when setting trainable attention
processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
@classmethod
def from_transformer(
cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
):
config = transformer.config
activation_fn = config.activation_fn
attention_head_dim = config.attention_head_dim
cross_attention_dim = config.cross_attention_dim
cross_attention_dim_t5 = config.cross_attention_dim_t5
hidden_size = config.hidden_size
in_channels = config.in_channels
mlp_ratio = config.mlp_ratio
num_attention_heads = config.num_attention_heads
patch_size = config.patch_size
sample_size = config.sample_size
text_len = config.text_len
text_len_t5 = config.text_len_t5
conditioning_channels = conditioning_channels
transformer_num_layers = transformer_num_layers or config.transformer_num_layers
controlnet = cls(
conditioning_channels=conditioning_channels,
transformer_num_layers=transformer_num_layers,
activation_fn=activation_fn,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
cross_attention_dim_t5=cross_attention_dim_t5,
hidden_size=hidden_size,
in_channels=in_channels,
mlp_ratio=mlp_ratio,
num_attention_heads=num_attention_heads,
patch_size=patch_size,
sample_size=sample_size,
text_len=text_len,
text_len_t5=text_len_t5,
)
if load_weights_from_transformer:
key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
return controlnet
def forward(
self,
hidden_states,
timestep,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DControlNetModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
controlnet_cond ( `torch.Tensor` ):
The conditioning input to ControlNet.
conditioning_scale ( `float` ):
Indicate the conditioning scale.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states) # b,c,H,W -> b, N, C
# 2. pre-process
hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))
temb = self.time_extra_emb(
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
block_res_samples = ()
for layer, block in enumerate(self.blocks):
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) # (N, L, D)
block_res_samples = block_res_samples + (hidden_states,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:
return (controlnet_block_res_samples,)
return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
class HunyuanDiT2DMultiControlNetModel(ModelMixin):
r"""
`HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel
This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
designed to be compatible with `HunyuanDiT2DControlNetModel`.
Args:
controlnets (`List[HunyuanDiT2DControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`HunyuanDiT2DControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states,
timestep,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DControlNetModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
controlnet_cond ( `torch.Tensor` ):
The conditioning input to ControlNet.
conditioning_scale ( `float` ):
Indicate the conditioning scale.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
block_samples = controlnet(
hidden_states=hidden_states,
timestep=timestep,
controlnet_cond=image,
conditioning_scale=scale,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
]
control_block_samples = (control_block_samples,)
return control_block_samples
+1 -1
View File
@@ -23,11 +23,11 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
from ..models.attention import JointTransformerBlock
from ..models.attention_processor import Attention, AttentionProcessor
from ..models.modeling_outputs import Transformer2DModelOutput
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from .transformers.transformer_2d import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+28 -11
View File
@@ -717,7 +717,14 @@ class HunyuanDiTAttentionPool(nn.Module):
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
def __init__(
self,
embedding_dim,
pooled_projection_dim=1024,
seq_len=256,
cross_attention_dim=2048,
use_style_cond_and_image_meta_size=True,
):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
if use_style_cond_and_image_meta_size:
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
else:
extra_in_dim = pooled_projection_dim
self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
# extra condition1: text
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
# extra condition2: image meta size embdding
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
if self.use_style_cond_and_image_meta_size:
# extra condition2: image meta size embdding
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
# extra condition3: style embedding
style_embedding = self.style_embedder(style) # (N, embedding_dim)
# extra condition3: style embedding
style_embedding = self.style_embedder(style) # (N, embedding_dim)
# Concatenate all extra vectors
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
else:
extra_cond = torch.cat([pooled_projections], dim=1)
# Concatenate all extra vectors
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
return conditioning
@@ -1,4 +1,4 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
The length of the clip text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
use_style_cond_and_image_meta_size (`bool`, *optional*):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
@register_to_config
@@ -270,6 +272,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
use_style_cond_and_image_meta_size: bool = True,
):
super().__init__()
self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
)
# HunyuanDiT Blocks
@@ -437,6 +441,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
image_meta_size=None,
style=None,
image_rotary_emb=None,
controlnet_block_samples=None,
return_dict=True,
):
"""
@@ -491,7 +496,10 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2:
skip = skips.pop()
if controlnet_block_samples is not None:
skip = skips.pop() + controlnet_block_samples.pop()
else:
skip = skips.pop()
hidden_states = block(
hidden_states,
temb=temb,
@@ -510,6 +518,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
if layer < (self.config.num_layers // 2 - 1):
skips.append(hidden_states)
if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
raise ValueError("The number of controls is not equal to the number of skip connections.")
# final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states)
@@ -30,8 +30,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Transformer2DModelOutput(Transformer2DModelOutput):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
@@ -28,7 +28,7 @@ from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from .transformer_2d import Transformer2DModelOutput
from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+112 -15
View File
@@ -58,7 +58,9 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"DownBlock3D",
"CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
if down_block_type == "DownBlockMotion":
return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlockMotion":
if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "DownBlockSpatioTemporal":
# added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
if up_block_type == "UpBlockMotion":
return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlockMotion":
if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "UpBlockSpatioTemporal":
# added for SDV
@@ -948,14 +963,31 @@ class DownBlockMotion(nn.Module):
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: int = 1,
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
)
# support for variable number of attention head per temporal layers
if isinstance(temporal_num_attention_heads, int):
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
elif len(temporal_num_attention_heads) != num_layers:
raise ValueError(
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -974,15 +1006,16 @@ class DownBlockMotion(nn.Module):
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
num_attention_heads=temporal_num_attention_heads[i],
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
)
)
@@ -1065,7 +1098,7 @@ class CrossAttnDownBlockMotion(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ class CrossAttnDownBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
@@ -1093,6 +1127,22 @@ class CrossAttnDownBlockMotion(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -1116,7 +1166,7 @@ class CrossAttnDownBlockMotion(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ class CrossAttnDownBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1257,7 +1308,7 @@ class CrossAttnUpBlockMotion(nn.Module):
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ class CrossAttnUpBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
@@ -1284,6 +1336,22 @@ class CrossAttnUpBlockMotion(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ class CrossAttnUpBlockMotion(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ class CrossAttnUpBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1467,11 +1536,20 @@ class UpBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ class UpBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=temporal_norm_num_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1596,7 +1675,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1605,13 +1684,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: float = False,
use_linear_projection: float = False,
upcast_attention: float = False,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
@@ -1619,6 +1699,22 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
@@ -1637,14 +1733,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
attentions = []
motion_modules = []
for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
+149 -24
View File
@@ -57,7 +57,8 @@ class MotionModules(nn.Module):
self,
in_channels: int,
layers_per_block: int = 2,
num_attention_heads: int = 8,
transformer_layers_per_block: Union[int, Tuple[int]] = 8,
num_attention_heads: Union[int, Tuple[int]] = 8,
attention_bias: bool = False,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
@@ -67,10 +68,19 @@ class MotionModules(nn.Module):
super().__init__()
self.motion_modules = nn.ModuleList([])
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
elif len(transformer_layers_per_block) != layers_per_block:
raise ValueError(
f"The number of transformer layers per block must match the number of layers per block, "
f"got {layers_per_block} and {len(transformer_layers_per_block)}"
)
for i in range(layers_per_block):
self.motion_modules.append(
TransformerTemporalModel(
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
@@ -88,9 +98,11 @@ class MotionAdapter(ModelMixin, ConfigMixin):
def __init__(
self,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
motion_layers_per_block: int = 2,
motion_layers_per_block: Union[int, Tuple[int]] = 2,
motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
motion_mid_block_layers_per_block: int = 1,
motion_num_attention_heads: int = 8,
motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
motion_num_attention_heads: Union[int, Tuple[int]] = 8,
motion_norm_num_groups: int = 32,
motion_max_seq_length: int = 32,
use_motion_mid_block: bool = True,
@@ -101,11 +113,15 @@ class MotionAdapter(ModelMixin, ConfigMixin):
Args:
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each UNet block.
motion_layers_per_block (`int`, *optional*, defaults to 2):
motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
The number of motion layers per UNet block.
motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
The number of transformer layers to use in each motion layer in each block.
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
The number of motion layers in the middle UNet block.
motion_num_attention_heads (`int`, *optional*, defaults to 8):
motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer layers to use in each motion layer in the middle block.
motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
The number of heads to use in each attention layer of the motion module.
motion_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use in each group normalization layer of the motion module.
@@ -119,6 +135,35 @@ class MotionAdapter(ModelMixin, ConfigMixin):
down_blocks = []
up_blocks = []
if isinstance(motion_layers_per_block, int):
motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels)
elif len(motion_layers_per_block) != len(block_out_channels):
raise ValueError(
f"The number of motion layers per block must match the number of blocks, "
f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
)
if isinstance(motion_transformer_layers_per_block, int):
motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels)
if isinstance(motion_transformer_layers_per_mid_block, int):
motion_transformer_layers_per_mid_block = (
motion_transformer_layers_per_mid_block,
) * motion_mid_block_layers_per_block
elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
raise ValueError(
f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
)
if isinstance(motion_num_attention_heads, int):
motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels)
elif len(motion_num_attention_heads) != len(block_out_channels):
raise ValueError(
f"The length of the attention head number tuple in the motion module must match the "
f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
)
if conv_in_channels:
# input
self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
@@ -134,9 +179,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
num_attention_heads=motion_num_attention_heads[i],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block,
layers_per_block=motion_layers_per_block[i],
transformer_layers_per_block=motion_transformer_layers_per_block[i],
)
)
@@ -147,15 +193,20 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
layers_per_block=motion_mid_block_layers_per_block,
num_attention_heads=motion_num_attention_heads[-1],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_mid_block_layers_per_block,
transformer_layers_per_block=motion_transformer_layers_per_mid_block,
)
else:
self.mid_block = None
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
for i, channel in enumerate(reversed_block_out_channels):
output_channel = reversed_block_out_channels[i]
up_blocks.append(
@@ -165,9 +216,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
num_attention_heads=reversed_motion_num_attention_heads[i],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block + 1,
layers_per_block=reversed_motion_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i],
)
)
@@ -208,7 +260,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
"CrossAttnUpBlockMotion",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
@@ -216,12 +268,18 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: int = True,
motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
use_motion_mid_block: bool = True,
mid_block_layers: int = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
@@ -264,6 +322,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
if (
isinstance(temporal_transformer_layers_per_block, list)
and reverse_temporal_transformer_layers_per_block is None
):
for layer_number_per_block in temporal_transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError(
"Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet."
)
# input
conv_in_kernel = 3
conv_out_kernel = 3
@@ -304,6 +372,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if isinstance(reverse_transformer_layers_per_block, int):
reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types)
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
if isinstance(reverse_temporal_transformer_layers_per_block, int):
reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len(
down_block_types
)
if isinstance(motion_num_attention_heads, int):
motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -326,13 +408,19 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)
# mid
if transformer_layers_per_mid_block is None:
transformer_layers_per_mid_block = (
transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
)
if use_motion_mid_block:
self.mid_block = UNetMidBlockCrossAttnMotion(
in_channels=block_out_channels[-1],
@@ -345,9 +433,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
num_layers=mid_block_layers,
temporal_num_attention_heads=motion_num_attention_heads[-1],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
transformer_layers_per_block=transformer_layers_per_mid_block,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block,
)
else:
@@ -362,7 +452,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
num_layers=mid_block_layers,
transformer_layers_per_block=transformer_layers_per_mid_block,
)
# count how many layers upsample the images
@@ -373,7 +464,13 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
if reverse_transformer_layers_per_block is None:
reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
if reverse_temporal_transformer_layers_per_block is None:
reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -406,9 +503,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
transformer_layers_per_block=reverse_transformer_layers_per_block[i],
temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -440,6 +538,24 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if has_motion_adapter:
motion_adapter.to(device=unet.device)
# check compatibility of number of blocks
if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
raise ValueError("Incompatible Motion Adapter, got different number of blocks")
# check layers compatibility for each block
if isinstance(unet.config["layers_per_block"], int):
expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
else:
expanded_layers_per_block = list(unet.config["layers_per_block"])
if isinstance(motion_adapter.config["motion_layers_per_block"], int):
expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
motion_adapter.config["block_out_channels"]
)
else:
expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
if expanded_layers_per_block != expanded_adapter_layers_per_block:
raise ValueError("Incompatible Motion Adapter, got different number of layers per block")
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
config = dict(unet.config)
config["_class_name"] = cls.__name__
@@ -458,13 +574,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
up_blocks.append("CrossAttnUpBlockMotion")
else:
up_blocks.append("UpBlockMotion")
config["up_block_types"] = up_blocks
if has_motion_adapter:
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
"motion_transformer_layers_per_mid_block"
]
config["temporal_transformer_layers_per_block"] = motion_adapter.config[
"motion_transformer_layers_per_block"
]
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
# For PIA UNets we need to set the number input channels to 9
if motion_adapter.config["conv_in_channels"]:
@@ -474,7 +597,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if not config.get("num_attention_heads"):
config["num_attention_heads"] = config["attention_head_dim"]
config = FrozenDict(config)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs})
config["_class_name"] = cls.__name__
model = cls.from_config(config)
if not load_weights:
+8 -4
View File
@@ -16,10 +16,14 @@ from .autoencoders.vq_model import VQEncoderOutput, VQModel
class VQEncoderOutput(VQEncoderOutput):
deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
deprecate("VQEncoderOutput", "0.31", deprecation_message)
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
deprecate("VQEncoderOutput", "0.31", deprecation_message)
super().__init__(*args, **kwargs)
class VQModel(VQModel):
deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
deprecate("VQModel", "0.31", deprecation_message)
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
deprecate("VQModel", "0.31", deprecation_message)
super().__init__(*args, **kwargs)
+26
View File
@@ -20,12 +20,14 @@ from ..utils import (
_dummy_objects = {}
_import_structure = {
"controlnet": [],
"controlnet_hunyuandit": [],
"controlnet_sd3": [],
"controlnet_xs": [],
"deprecated": [],
"latent_diffusion": [],
"ledits_pp": [],
"marigold": [],
"pag": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
}
@@ -137,12 +139,26 @@ else:
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["pag"].extend(
[
"StableDiffusionPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
]
)
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
"StableDiffusionXLControlNetXSPipeline",
]
)
_import_structure["controlnet_hunyuandit"].extend(
[
"HunyuanDiTControlNetPipeline",
]
)
_import_structure["controlnet_sd3"].extend(
[
"StableDiffusion3ControlNetPipeline",
@@ -400,6 +416,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .controlnet_hunyuandit import (
HunyuanDiTControlNetPipeline,
)
from .controlnet_sd3 import (
StableDiffusion3ControlNetPipeline,
)
@@ -472,6 +491,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldNormalsPipeline,
)
from .musicldm import MusicLDMPipeline
from .pag import (
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
@@ -352,6 +352,9 @@ class AnimateDiffPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -361,7 +364,6 @@ class AnimateDiffPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -369,36 +371,28 @@ class AnimateDiffPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
@@ -564,6 +564,9 @@ class AnimateDiffSDXLPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -573,7 +576,6 @@ class AnimateDiffSDXLPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -581,36 +583,28 @@ class AnimateDiffSDXLPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
@@ -456,6 +456,9 @@ class AnimateDiffVideoToVideoPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -465,7 +468,6 @@ class AnimateDiffVideoToVideoPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -473,36 +475,28 @@ class AnimateDiffVideoToVideoPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
+71 -4
View File
@@ -46,6 +46,13 @@ from .kandinsky2_2 import (
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
@@ -82,6 +89,9 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("lcm", LatentConsistencyModelPipeline),
("pixart-alpha", PixArtAlphaPipeline),
("pixart-sigma", PixArtSigmaPipeline),
("stable-diffusion-pag", StableDiffusionPAGPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
]
)
@@ -96,6 +106,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
]
)
@@ -109,6 +120,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
]
)
@@ -340,6 +352,10 @@ class AutoPipelineForText2Image(ConfigMixin):
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -383,14 +399,28 @@ class AutoPipelineForText2Image(ConfigMixin):
if "controlnet" in kwargs:
if kwargs["controlnet"] is not None:
to_replace = "PAGPipeline" if "PAG" in text_2_image_cls.__name__ else "Pipeline"
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("ControlNet", "").replace("Pipeline", "ControlNetPipeline"),
text_2_image_cls.__name__.replace("ControlNet", "").replace(to_replace, "ControlNet" + to_replace),
)
else:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("ControlNetPipeline", "Pipeline"),
text_2_image_cls.__name__.replace("ControlNet", ""),
)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"),
)
else:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("PAG", ""),
)
# define expected module and optional kwargs given the pipeline signature
@@ -613,6 +643,10 @@ class AutoPipelineForImage2Image(ConfigMixin):
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -658,16 +692,32 @@ class AutoPipelineForImage2Image(ConfigMixin):
if "controlnet" in kwargs:
if kwargs["controlnet"] is not None:
to_replace = "Img2ImgPipeline"
if "PAG" in image_2_image_cls.__name__:
to_replace = "PAG" + to_replace
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("ControlNet", "").replace(
"Img2ImgPipeline", "ControlNetImg2ImgPipeline"
to_replace, "ControlNet" + to_replace
),
)
else:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("ControlNetImg2ImgPipeline", "Img2ImgPipeline"),
image_2_image_cls.__name__.replace("ControlNet", ""),
)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("PAG", "").replace("Img2ImgPipeline", "PAGImg2ImgPipeline"),
)
else:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("PAG", ""),
)
# define expected module and optional kwargs given the pipeline signature
@@ -889,6 +939,10 @@ class AutoPipelineForInpainting(ConfigMixin):
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
@@ -945,6 +999,19 @@ class AutoPipelineForInpainting(ConfigMixin):
inpainting_cls.__name__.replace("ControlNetInpaintPipeline", "InpaintPipeline"),
)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING,
inpainting_cls.__name__.replace("PAG", "").replace("InpaintPipeline", "PAGInpaintPipeline"),
)
else:
inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING,
inpainting_cls.__name__.replace("PAGInpaintPipeline", "InpaintPipeline"),
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
@@ -52,7 +52,7 @@ EXAMPLE_DOC_STRING = """
>>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
>>> # Multistep sampling, class-conditional image generation
>>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
>>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
>>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
>>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
>>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
@@ -499,6 +499,9 @@ class StableDiffusionControlNetPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -508,7 +511,6 @@ class StableDiffusionControlNetPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -516,36 +518,28 @@ class StableDiffusionControlNetPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -477,6 +477,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -486,7 +489,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -494,36 +496,28 @@ class StableDiffusionControlNetImg2ImgPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -479,6 +479,9 @@ class StableDiffusionControlNetInpaintPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -488,7 +491,6 @@ class StableDiffusionControlNetInpaintPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -496,36 +498,28 @@ class StableDiffusionControlNetInpaintPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -84,6 +84,7 @@ EXAMPLE_DOC_STRING = """
>>> # !pip install transformers accelerate
>>> from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
>>> from diffusers.utils import load_image
>>> from PIL import Image
>>> import numpy as np
>>> import torch
@@ -532,6 +533,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -541,7 +545,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -549,36 +552,28 @@ class StableDiffusionXLControlNetInpaintPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -554,6 +554,9 @@ class StableDiffusionXLControlNetPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -563,7 +566,6 @@ class StableDiffusionXLControlNetPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -571,36 +573,28 @@ class StableDiffusionXLControlNetPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -548,6 +548,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -557,7 +560,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -565,36 +567,28 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
@@ -441,6 +441,9 @@ class LatentConsistencyModelImg2ImgPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -450,7 +453,6 @@ class LatentConsistencyModelImg2ImgPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -458,36 +460,28 @@ class LatentConsistencyModelImg2ImgPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -425,6 +425,9 @@ class LatentConsistencyModelPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -434,7 +437,6 @@ class LatentConsistencyModelPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -442,36 +444,28 @@ class LatentConsistencyModelPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
+55
View File
@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
+260
View File
@@ -0,0 +1,260 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...models.attention_processor import (
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
)
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PAGMixin:
r"""Mixin class for PAG."""
@staticmethod
def _check_input_pag_applied_layer(layer):
r"""
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
in the format of "attentions_{j}".
"""
layer_splits = layer.split(".")
if len(layer_splits) > 3:
raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.")
if len(layer_splits) >= 1:
if layer_splits[0] not in ["down", "mid", "up"]:
raise ValueError(
f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}"
)
if len(layer_splits) >= 2:
if not layer_splits[1].startswith("block_"):
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
if len(layer_splits) == 3:
if not layer_splits[2].startswith("attentions_"):
raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
if do_classifier_free_guidance:
pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
else:
pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
def is_self_attn(module_name):
r"""
Check if the module is self-attention module based on its name.
"""
return "attn1" in module_name and "to" not in name
def get_block_type(module_name):
r"""
Get the block type from the module name. can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]
def get_block_index(module_name):
r"""
Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is ommited from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
if "attentions" in module_name.split(".")[1]:
return "block_0"
else:
return f"block_{module_name.split('.')[1]}"
def get_attn_index(module_name):
r"""
Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
if "attentions" in module_name.split(".")[2]:
return f"attentions_{module_name.split('.')[3]}"
elif "attentions" in module_name.split(".")[1]:
return f"attentions_{module_name.split('.')[2]}"
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
target_modules = []
pag_layer_input_splits = pag_layer_input.split(".")
if len(pag_layer_input_splits) == 1:
# when the layer input only contains block_type. e.g. "mid", "down", "up"
block_type = pag_layer_input_splits[0]
for name, module in self.unet.named_modules():
if is_self_attn(name) and get_block_type(name) == block_type:
target_modules.append(module)
elif len(pag_layer_input_splits) == 2:
# when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
):
target_modules.append(module)
elif len(pag_layer_input_splits) == 3:
# when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
and get_attn_index(name) == attn_index
):
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
for module in target_modules:
module.processor = pag_attn_proc
def _get_pag_scale(self, t):
r"""
Get the scale factor for the perturbed attention guidance at timestep `t`.
"""
if self.do_pag_adaptive_scaling:
signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
if signal_scale < 0:
signal_scale = 0
return signal_scale
else:
return self.pag_scale
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
r"""
Apply perturbed attention guidance to the noise prediction.
Args:
noise_pred (torch.Tensor): The noise prediction tensor.
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
+ pag_scale * (noise_pred_text - noise_pred_perturb)
)
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
return noise_pred
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
"""
Prepares the perturbed attention guidance for the PAG model.
Args:
cond (torch.Tensor): The conditional input tensor.
uncond (torch.Tensor): The unconditional input tensor.
do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
Returns:
torch.Tensor: The prepared perturbed attention guidance tensor.
"""
cond = torch.cat([cond] * 2, dim=0)
if do_classifier_free_guidance:
cond = torch.cat([uncond, cond], dim=0)
return cond
def set_pag_applied_layers(self, pag_applied_layers):
r"""
set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
for pag_layer in pag_applied_layers:
self._check_input_pag_applied_layer(pag_layer)
self.pag_applied_layers = pag_applied_layers
@property
def pag_scale(self):
"""
Get the scale factor for the perturbed attention guidance.
"""
return self._pag_scale
@property
def pag_adaptive_scale(self):
"""
Get the adaptive scale factor for the perturbed attention guidance.
"""
return self._pag_adaptive_scale
@property
def do_pag_adaptive_scaling(self):
"""
Check if the adaptive scaling is enabled for the perturbed attention guidance.
"""
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def do_perturbed_attention_guidance(self):
"""
Check if the perturbed attention guidance is enabled.
"""
return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def pag_attn_processors(self):
r"""
Returns:
`dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
with the key as the name of the layer.
"""
processors = {}
for name, proc in self.unet.attn_processors.items():
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+17 -23
View File
@@ -505,6 +505,9 @@ class PIAPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -514,7 +517,6 @@ class PIAPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -522,36 +524,28 @@ class PIAPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
@@ -296,7 +296,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5",
... revision="bf16",
... variant="bf16",
... dtype=jnp.bfloat16,
... )
@@ -310,7 +310,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
... )
>>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
... model_id, variant="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
... )
>>> dpm_params["scheduler"] = dpmpp_state
```
@@ -564,7 +564,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
... )
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```
@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FlaxStableDiffusionPipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -508,6 +508,9 @@ class StableDiffusionPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -517,7 +520,6 @@ class StableDiffusionPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -525,36 +527,28 @@ class StableDiffusionPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -553,6 +553,9 @@ class StableDiffusionImg2ImgPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -562,7 +565,6 @@ class StableDiffusionImg2ImgPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -570,36 +572,28 @@ class StableDiffusionImg2ImgPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -502,6 +502,9 @@ class StableDiffusionInpaintPipeline(
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -511,7 +514,6 @@ class StableDiffusionInpaintPipeline(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
@@ -519,36 +521,28 @@ class StableDiffusionInpaintPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
@@ -612,7 +612,7 @@ class StableDiffusionUpscalePipeline(
>>> # load model and scheduler
>>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
>>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
... model_id, revision="fp16", torch_dtype=torch.float16
... model_id, variant="fp16", torch_dtype=torch.float16
... )
>>> pipeline = pipeline.to("cuda")
@@ -62,7 +62,7 @@ EXAMPLE_DOC_STRING = """
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((512, 512))
>>> init_image = load_image(url).resize((1024, 1024))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
@@ -852,7 +852,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
if latents is None:

Some files were not shown because too many files have changed in this diff Show More