Compare commits

...

31 Commits

Author SHA1 Message Date
patil-suraj ea238e821b up 2024-03-18 11:47:47 +01:00
patil-suraj b6d1d670fc up 2024-03-18 11:34:17 +01:00
Dhruv Nair 4330a747d4 [Tests] Fix ControlNet Single File tests (#7315)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-18 11:28:59 +05:30
Sayak Paul 76de6a09fb post-release v0.27.0 (#7329)
* post-release

* quality
2024-03-18 10:52:20 +05:30
Sayak Paul 25caf24ef9 Fix release workflow deps (#7339)
* pop scale from the top-level unet instead of getting it.

* improve readability.

* fix: pypi workflow deps

* revert
2024-03-16 07:18:11 +05:30
Abubakar Abid 8db3c9bc9f Adds docs for gradio.Interface.from_pipeline() (#7346)
* gradio docs

* Update docs/source/en/api/pipelines/stable_diffusion/overview.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* changes

* changes

* changes

* Update docs/source/en/api/pipelines/stable_diffusion/overview.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-16 07:11:28 +05:30
Sayak Paul e0e9f81971 add: torch to the pypi step. (#7328) 2024-03-15 12:28:12 +05:30
M. Tolga Cangöz 5d848ec07c [Tests] Update a deprecated parameter in test files and fix several typos (#7277)
* Add properties and `IPAdapterTesterMixin` tests for `StableDiffusionPanoramaPipeline`

* Fix variable name typo and update comments

* Update deprecated `output_type="numpy"` to "np" in test files

* Discard changes to src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py

* Update test_stable_diffusion_panorama.py

* Update numbers in README.md

* Update get_guidance_scale_embedding method to use timesteps instead of w

* Update number of checkpoints in README.md

* Add type hints and fix var name

* Fix PyTorch's convention for inplace functions

* Fix a typo

* Revert "Fix PyTorch's convention for inplace functions"

This reverts commit 74350cf65b.

* Fix typos

* Indent

* Refactor get_guidance_scale_embedding method in LEditsPPPipelineStableDiffusionXL class
2024-03-14 12:17:35 -07:00
Dhruv Nair 4974b84564 Update Cascade Tests (#7324)
* update

* update

* update
2024-03-14 20:51:22 +05:30
Linoy Tsaban 83062fb872 [Advanced DreamBooth LoRA SDXL] Support EDM-style training (follow up of #7126) (#7182)
* add edm style training

* style

* finish adding edm training feature

* import fix

* fix latents mean

* minor adjustments

* add edm to readme

* style

* fix autocast and scheduler config issues when using edm

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-14 18:40:14 +05:30
Suraj Patil b6d7e31d10 add edm schedulers in doc (#7319)
* add edm schedulers in doc

* add in toctree

* address reviewe comments
2024-03-14 11:52:25 +01:00
Anatoly Belikov 53e9aacc10 log loss per image (#7278)
* log loss per image

* add commandline param for per image loss logging

* style

* debug-loss -> debug_loss

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-14 11:41:43 +05:30
Dhruv Nair 41424466e3 [Tests] Fix incorrect constant in VAE scaling test. (#7301)
update
2024-03-14 10:24:01 +05:30
Sayak Paul 95de1981c9 add: pytest log installation (#7313) 2024-03-14 10:01:16 +05:30
Kenneth Gerald Hamilton 0b45b58867 update get_order_list if statement (#7309)
* update get_order_list if statement

* revery
2024-03-13 18:29:42 -10:00
Beinsezii d3986f18be Change step_offset scheduler docstrings (#7128)
* Change step_offset scheduler docstrings

* Mention it may be needed by some models

* More docstrings

These ones failed literal S&R because I performed it case-sensitive
which is fun.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-13 15:12:00 -10:00
Alexander Bonnet ee6a3a993d Fix typos in UNet2DConditionModel documentation (#7291)
* fix typo in UNet2DConditionModel documentation

* Fix indentation that may fix doc rendering

* Fix squished doc lines
2024-03-13 09:31:29 -07:00
Michael b300517305 Add Intro page of TCD (#7259)
* add tcd intro

* resolve repos

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* revise NFEs related

* change inpainting location

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-03-13 09:21:51 -07:00
jnhuang ac07b6dc6a Fix Wrong Text-encoder Grad Setting in Custom_Diffusion Training (#7302)
fix index in set textencoder grad

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-13 20:22:44 +05:30
Sayak Paul 46ab56a468 add: support for notifying maintainers about the nightly test status (#7117)
* add: support for notifying maintainers about the nightly test status

* add: a tempoerary workflow for validation.

* cancel in progress.

* runs-on

* clean up

* add: peft dep

* change device.

* multiple edits.

* remove temp workflow.
2024-03-13 16:48:11 +05:30
Sayak Paul 038ff70023 [PyPI publishing] feat: automate the process of pypi publication to some extent. (#7270)
* feat: automate the process of pypi publication to some extent.

* utility to fetch the latest release branch

* correct package name.
2024-03-13 16:27:59 +05:30
Manuel Brack 00eca4b887 [Pipeline] Add LEDITS++ pipelines (#6074)
* Setup LEdits++ file structure

* Fix import

* LEditsPP Stable Diffusion pipeline

* Include variable image aspect ratios

* Implement LEDITS++ for SDXL

* clean up LEditsPPPipelineStableDiffusion

* Adjust inversion output

* Added docu, more cleanup for LEditsPPPipelineStableDiffusion

* clean up LEditsPPPipelineStableDiffusionXL

* Update documentation

* Fix documentation import

* Add skeleton IF implementation

* Fix documentation typo

* Add LEDTIS docu to toctree

* Add missing title

* Finalize SD documentation

* Finalize SD-XL documentation

* Fix code style and quality

* Fix typo

* Fix return types

* added LEditsPPPipelineIF; minor changes for LEditsPPPipelineStableDiffusion and LEditsPPPipelineStableDiffusionXL

* Fix copy reference

* add documentation for IF

* Add first tests

* Fix batching for SD-XL

* Fix text encoding and perfect reconstruction for SD-XL

* Add tests for SD-XL, minor changes

* move user_mask to correct device, use cross_attention_kwargs also for inversion

* Example docstring

* Fix attention resolution for non-square images

* Refactoring for PR review

* Safely remove ledits_utils.py

* Style fixes

* Replace assertions with ValueError

* Remove LEditsPPPipelineIF

* Remove unecessary input checks

* Refactoring of CrossAttnProcessor

* Revert unecessary changes to scheduler

* Remove first progress-bar in inversion

* Refactor scheduler usage and reset

* Use imageprocessor instead of custom logic

* Fix scheduler init warning

* Fix error when running the pipeline in fp16

* Update documentation wrt perfect inversion

* Update tests

* Fix code quality and copy consistency

* Update LEditsPP import

* Remove enable/disable methods that are now in StableDiffusionMixin

* Change import in docs

* Revert import structure change

* Fix ledits imports

---------

Co-authored-by: Katharina Kornmeier <katharina.kornmeier@stud.tu-darmstadt.de>
2024-03-13 12:43:47 +02:00
Dhruv Nair 30132aba30 Update Stable Cascade Conversion Scripts (#7271)
* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-13 12:35:44 +05:30
Dhruv Nair a17d6d6858 Update Cascade documentation (#7257)
* updates

* update

* update

* Update docs/source/en/api/pipelines/stable_cascade.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-03-13 11:29:59 +05:30
Sayak Paul 8efd9ce787 [Chore] clean residue from copy-pasting in the UNet single file loader (#7295)
clean residue from copy-pasting
2024-03-13 11:20:13 +05:30
Dhruv Nair 299c16d0f5 Fix loading Img2Img refiner components in from_single_file (#7282)
* update

* update

* update

* update
2024-03-13 09:25:53 +05:30
Dhruv Nair 69f49195ac Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline (#7287)
* update

* update

* update

* update
2024-03-13 09:21:41 +05:30
Dhruv Nair ed224f94ba Add single file support for Stable Cascade (#7274)
* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-13 08:37:31 +05:30
Sayak Paul 531e719163 [LoRA] use the PyTorch classes wherever needed and start depcrecation cycles (#7204)
* fix PyTorch classes and start deprecsation cycles.

* remove args crafting for accommodating scale.

* remove scale check in feedforward.

* assert against nn.Linear and not CompatibleLinear.

* remove conv_cls and lineaR_cls.

* remove scale

* 👋 scale.

* fix: unet2dcondition

* fix attention.py

* fix: attention.py again

* fix: unet_2d_blocks.

* fix-copies.

* more fixes.

* fix: resnet.py

* more fixes

* fix i2vgenxl unet.

* depcrecate scale gently.

* fix-copies

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* quality

* throw warning when scale is passed to the the BasicTransformerBlock class.

* remove scale from signature.

* cross_attention_kwargs, very nice catch by Yiyi

* fix: logger.warn

* make deprecation message clearer.

* address final comments.

* maintain same depcrecation message and also add it to activations.

* address yiyi

* fix copies

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* more depcrecation

* fix-copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-03-13 07:56:19 +05:30
Sayak Paul 4fbd310fd2 [Chore] switch to logger.warning (#7289)
switch to logger.warning
2024-03-13 06:56:43 +05:30
Dhruv Nair 2ea28d69dc Change export_to_video default (#6990)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-12 17:13:12 +05:30
210 changed files with 7281 additions and 1095 deletions
+22 -2
View File
@@ -12,6 +12,7 @@ env:
PYTEST_TIMEOUT: 600
RUN_SLOW: yes
RUN_NIGHTLY: yes
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
jobs:
run_nightly_tests:
@@ -64,6 +65,7 @@ jobs:
python -m uv pip install -e [quality,test]
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
python -m uv pip install pytest-reportlog
- name: Environment
run: |
@@ -78,7 +80,8 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
--report-log=${{ matrix.config.report }}.log \
tests/
- name: Run nightly Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
@@ -89,6 +92,7 @@ jobs:
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
--report-log=${{ matrix.config.report }}.log \
tests/
- name: Run nightly ONNXRuntime CUDA tests
@@ -100,6 +104,7 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
--report-log=${{ matrix.config.report }}.log \
tests/
- name: Failure short reports
@@ -112,6 +117,12 @@ jobs:
with:
name: ${{ matrix.config.report }}_test_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_tests_apple_m1:
name: Nightly PyTorch MPS tests on MacOS
@@ -140,6 +151,7 @@ jobs:
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m uv pip install pytest-reportlog
- name: Environment
shell: arch -arch arm64 bash {0}
@@ -152,7 +164,9 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
--report-log=tests_torch_mps.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
@@ -164,3 +178,9 @@ jobs:
with:
name: torch_mps_test_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
@@ -0,0 +1,23 @@
name: Notify Slack about a release
on:
workflow_dispatch:
release:
types: [published]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Notify Slack about the release
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
run: pip install requests && python utils/notify_slack_about_release.py
+81
View File
@@ -0,0 +1,81 @@
# Adapted from https://blog.deepjyoti30.dev/pypi-release-github-action
name: PyPI release
on:
workflow_dispatch:
push:
tags:
- "*"
jobs:
find-and-checkout-latest-branch:
runs-on: ubuntu-latest
outputs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout Repo
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Fetch latest branch
id: fetch_latest_branch
run: |
pip install -U requests packaging
LATEST_BRANCH=$(python utils/fetch_latest_release_branch.py)
echo "Latest branch: $LATEST_BRANCH"
echo "latest_branch=$LATEST_BRANCH" >> $GITHUB_ENV
- name: Set latest branch output
id: set_latest_branch
run: echo "::set-output name=latest_branch::${{ env.latest_branch }}"
release:
needs: find-and-checkout-latest-branch
runs-on: ubuntu-latest
steps:
- name: Checkout Repo
uses: actions/checkout@v3
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
pip install -U transformers
- name: Build the dist files
run: python setup.py bdist_wheel && python setup.py sdist
- name: Publish to the test PyPI
env:
TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }}
run: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
- name: Test installing diffusers and importing
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://testpypi.python.org/pypi diffusers
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
python -c "from diffusers import *"
- name: Publish to PyPI
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: twine upload dist/* -r pypi
+2 -2
View File
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
## Quickstart
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 22000+ checkpoints):
```python
from diffusers import DiffusionPipeline
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/deep-floyd/IF
- https://github.com/bentoml/BentoML
- https://github.com/bmaltais/kohya_ss
- +8000 other amazing GitHub repositories 💪
- +9000 other amazing GitHub repositories 💪
Thank you for using us ❤️.
+8
View File
@@ -104,6 +104,8 @@
title: Latent Consistency Model-LoRA
- local: using-diffusers/inference_with_lcm
title: Latent Consistency Model
- local: using-diffusers/inference_with_tcd_lora
title: Trajectory Consistency Distillation-LoRA
- local: using-diffusers/svd
title: Stable Video Diffusion
title: Specific pipeline examples
@@ -304,6 +306,8 @@
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/musicldm
@@ -396,6 +400,10 @@
title: DPMSolverSDEScheduler
- local: api/schedulers/singlestep_dpm_solver
title: DPMSolverSinglestepScheduler
- local: api/schedulers/edm_multistep_dpm_solver
title: EDMDPMSolverMultistepScheduler
- local: api/schedulers/edm_euler
title: EDMEulerScheduler
- local: api/schedulers/euler_ancestral
title: EulerAncestralDiscreteScheduler
- local: api/schedulers/euler
+54
View File
@@ -0,0 +1,54 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LEDITS++
LEDITS++ was proposed in [LEDITS++: Limitless Image Editing using Text-to-Image Models](https://huggingface.co/papers/2311.16711) by Manuel Brack, Felix Friedrich, Katharina Kornmeier, Linoy Tsaban, Patrick Schramowski, Kristian Kersting, Apolinário Passos.
The abstract from the paper is:
*Text-to-image diffusion models have recently received increasing interest for their astonishing ability to produce high-fidelity images from solely text inputs. Subsequent research efforts aim to exploit and apply their capabilities to real image editing. However, existing image-to-image methods are often inefficient, imprecise, and of limited versatility. They either require time-consuming fine-tuning, deviate unnecessarily strongly from the input image, and/or lack support for multiple, simultaneous edits. To address these issues, we introduce LEDITS++, an efficient yet versatile and precise textual image manipulation technique. LEDITS++'s novel inversion approach requires no tuning nor optimization and produces high-fidelity results with a few diffusion steps. Second, our methodology supports multiple simultaneous edits and is architecture-agnostic. Third, we use a novel implicit masking technique that limits changes to relevant image regions. We propose the novel TEdBench++ benchmark as part of our exhaustive evaluation. Our results demonstrate the capabilities of LEDITS++ and its improvements over previous methods. The project page is available at https://leditsplusplus-project.static.hf.space .*
<Tip>
You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).
</Tip>
<Tip warning={true}>
Due to some backward compatability issues with the current diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`] this implementation of LEdits++ can no longer guarantee perfect inversion.
This issue is unlikely to have any noticeable effects on applied use-cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).
</Tip>
We provide two distinct pipelines based on different pre-trained models.
## LEditsPPPipelineStableDiffusion
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusion
- all
- __call__
- invert
## LEditsPPPipelineStableDiffusionXL
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL
- all
- __call__
- invert
## LEditsPPDiffusionPipelineOutput
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPDiffusionPipelineOutput
- all
## LEditsPPInversionPipelineOutput
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPInversionPipelineOutput
- all
+1
View File
@@ -57,6 +57,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
| [LEDITS++](ledits_pp) | image editing |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [Paint by Example](paint_by_example) | inpainting |
@@ -30,6 +30,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## StableDiffusionSafePipelineOutput
## SemanticStableDiffusionPipelineOutput
[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput
- all
+154 -13
View File
@@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License.
# Stable Cascade
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
Diffusion 1.5.
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
@@ -30,13 +30,154 @@ The original codebase can be found at [Stability-AI/StableCascade](https://githu
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
hence the name "Stable Cascade".
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
for generating the small 24 x 24 latents given a text prompt.
The Stage C model operates on the small 24 x 24 latents and denoises the latents conditioned on text prompts. The model is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`
The Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.
<Tip warning={true}>
There are some restrictions on data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type. Please use `torch.bfloat16` instead.
In order to use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline` you need to have PyTorch 2.2.0 or higher installed. This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.
If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type. You can download the full precision or `bf16` variant weights for the pipeline and cast the weights to `torch.float16`.
</Tip>
## Usage example
```python
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings.to(torch.float16),
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade.png")
```
## Using the Lite Versions of the Stage B and Stage C models
```python
import torch
from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableCascadeUNet,
)
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings,
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade.png")
```
## Loading original checkpoints with `from_single_file`
Loading the original format checkpoints is supported via `from_single_file` method in the StableCascadeUNet.
```python
import torch
from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableCascadeUNet,
)
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior_unet = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors",
torch_dtype=torch.bfloat16
)
decoder_unet = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
torch_dtype=torch.bfloat16
)
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch.bfloat16)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings,
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade-single-file.png")
```
## Uses
### Direct Use
@@ -53,7 +194,7 @@ Excluded uses are described below.
### Out-of-Scope Use
The model was not trained to be factual or true representations of people or events,
The model was not trained to be factual or true representations of people or events,
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
@@ -172,3 +172,41 @@ inpaint = StableDiffusionInpaintPipeline(**text2img.components)
# now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
```
### Create web demos using `gradio`
The Stable Diffusion pipelines are automatically supported in [Gradio](https://github.com/gradio-app/gradio/), a library that makes creating beautiful and user-friendly machine learning apps on the web a breeze. First, make sure you have Gradio installed:
```
pip install -U gradio
```
Then, create a web demo around any Stable Diffusion-based pipeline. For example, you can create an image generation pipeline in a single line of code with Gradio's [`Interface.from_pipeline`](https://www.gradio.app/docs/interface#interface-from-pipeline) function:
```py
from diffusers import StableDiffusionPipeline
import gradio as gr
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
gr.Interface.from_pipeline(pipe).launch()
```
which opens an intuitive drag-and-drop interface in your browser:
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gradio-panda.png)
Similarly, you could create a demo for an image-to-image pipeline with:
```py
from diffusers import StableDiffusionImg2ImgPipeline
import gradio as gr
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
gr.Interface.from_pipeline(pipe).launch()
```
By default, the web demo runs on a local server. If you'd like to share it with others, you can generate a temporary public
link by setting `share=True` in `launch()`. Or, you can host your demo on [Hugging Face Spaces](https://huggingface.co/spaces)https://huggingface.co/spaces for a permanent link.
@@ -0,0 +1,22 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# EDMEulerScheduler
The Karras formulation of the Euler scheduler (Algorithm 2) from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
## EDMEulerScheduler
[[autodoc]] EDMEulerScheduler
## EDMEulerSchedulerOutput
[[autodoc]] schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput
@@ -0,0 +1,24 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# EDMDPMSolverMultistepScheduler
`EDMDPMSolverMultistepScheduler` is a [Karras formulation](https://huggingface.co/papers/2206.00364) of `DPMSolverMultistep`, a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
samples, and it can generate quite good samples even in 10 steps.
## EDMDPMSolverMultistepScheduler
[[autodoc]] EDMDPMSolverMultistepScheduler
## SchedulerOutput
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
@@ -0,0 +1,438 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
[[open-in-colab]]
# Trajectory Consistency Distillation-LoRA
Trajectory Consistency Distillation (TCD) enables a model to generate higher quality and more detailed images with fewer steps. Moreover, owing to the effective error mitigation during the distillation process, TCD demonstrates superior performance even under conditions of large inference steps.
The major advantages of TCD are:
- Better than Teacher: TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of [DPM-Solver++(2S)](../../api/schedulers/multistep_dpm_solver) with Stable Diffusion XL (SDXL). There is no additional discriminator or LPIPS supervision included during TCD training.
- Flexible Inference Steps: The inference steps for TCD sampling can be freely adjusted without adversely affecting the image quality.
- Freely change detail level: During inference, the level of detail in the image can be adjusted with a single hyperparameter, *gamma*.
> [!TIP]
> For more technical details of TCD, please refer to the [paper](https://arxiv.org/abs/2402.19159) or official [project page](https://mhh0318.github.io/tcd/)).
For large models like SDXL, TCD is trained with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) to reduce memory usage. This is also useful because you can reuse LoRAs between different finetuned models, as long as they share the same base model, without further training.
This guide will show you how to perform inference with TCD-LoRAs for a variety of tasks like text-to-image and inpainting, as well as how you can easily combine TCD-LoRAs with other adapters. Choose one of the supported base model and it's corresponding TCD-LoRA checkpoint from the table below to get started.
| Base model | TCD-LoRA checkpoint |
|-------------------------------------------------------------------------------------------------|----------------------------------------------------------------|
| [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) | [TCD-SD15](https://huggingface.co/h1t/TCD-SD15-LoRA) |
| [stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) | [TCD-SD21-base](https://huggingface.co/h1t/TCD-SD21-base-LoRA) |
| [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | [TCD-SDXL](https://huggingface.co/h1t/TCD-SDXL-LoRA) |
Make sure you have [PEFT](https://github.com/huggingface/peft) installed for better LoRA support.
```bash
pip install -U peft
```
## General tasks
In this guide, let's use the [`StableDiffusionXLPipeline`] and the [`TCDScheduler`]. Use the [`~StableDiffusionPipeline.load_lora_weights`] method to load the SDXL-compatible TCD-LoRA weights.
A few tips to keep in mind for TCD-LoRA inference are to:
- Keep the `num_inference_steps` between 4 and 50
- Set `eta` (used to control stochasticity at each step) between 0 and 1. You should use a higher `eta` when increasing the number of inference steps, but the downside is that a larger `eta` in [`TCDScheduler`] leads to blurrier images. A value of 0.3 is recommended to produce good results.
<hfoptions id="tasks">
<hfoption id="text-to-image">
```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler
device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."
image = pipe(
prompt=prompt,
num_inference_steps=4,
guidance_scale=0,
eta=0.3,
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/demo_image.png)
</hfoption>
<hfoption id="inpainting">
```python
import torch
from diffusers import AutoPipelineForInpainting, TCDScheduler
from diffusers.utils import load_image, make_image_grid
device = "cuda"
base_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
pipe = AutoPipelineForInpainting.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))
prompt = "a tiger sitting on a park bench"
image = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
num_inference_steps=8,
guidance_scale=0,
eta=0.3,
strength=0.99, # make sure to use `strength` below 1.0
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
grid_image = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/inpainting_tcd.png)
</hfoption>
</hfoptions>
## Community models
TCD-LoRA also works with many community finetuned models and plugins. For example, load the [animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0) checkpoint which is a community finetuned version of SDXL for generating anime images.
```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler
device = "cuda"
base_model_id = "cagliostrolab/animagine-xl-3.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
prompt = "A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. Strands of vibrant, windswept hair peek out from beneath the brim of his cap."
image = pipe(
prompt=prompt,
num_inference_steps=8,
guidance_scale=0,
eta=0.3,
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/animagine_xl.png)
TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and fuse it with the TCD-LoRA with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.
> [!TIP]
> Check out the [Merge LoRAs](merge_loras) guide to learn more about efficient merging methods.
```python
import torch
from diffusers import StableDiffusionXLPipeline
from scheduling_tcd import TCDScheduler
device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
styled_lora_id = "TheLastBen/Papercut_SDXL"
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
pipe.load_lora_weights(styled_lora_id, adapter_name="style")
pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 1.0])
prompt = "papercut of a winter mountain, snow"
image = pipe(
prompt=prompt,
num_inference_steps=4,
guidance_scale=0,
eta=0.3,
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/styled_lora.png)
## Adapters
TCD-LoRA is very versatile, and it can be combined with other adapter types like ControlNets, IP-Adapter, and AnimateDiff.
<hfoptions id="adapters">
<hfoption id="ControlNet">
### Depth ControlNet
```python
import torch
import numpy as np
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image, make_image_grid
from scheduling_tcd import TCDScheduler
device = "cuda"
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
def get_depth_map(image):
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
with torch.no_grad(), torch.autocast(device):
depth_map = depth_estimator(image).predicted_depth
depth_map = torch.nn.functional.interpolate(
depth_map.unsqueeze(1),
size=(1024, 1024),
mode="bicubic",
align_corners=False,
)
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
image = torch.cat([depth_map] * 3, dim=1)
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
return image
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_id = "diffusers/controlnet-depth-sdxl-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
controlnet = ControlNetModel.from_pretrained(
controlnet_id,
torch_dtype=torch.float16,
variant="fp16",
).to(device)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_id,
controlnet=controlnet,
torch_dtype=torch.float16,
variant="fp16",
).to(device)
pipe.enable_model_cpu_offload()
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
prompt = "stormtrooper lecture, photorealistic"
image = load_image("https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png")
depth_image = get_depth_map(image)
controlnet_conditioning_scale = 0.5 # recommended for good generalization
image = pipe(
prompt,
image=depth_image,
num_inference_steps=4,
guidance_scale=0,
eta=0.3,
controlnet_conditioning_scale=controlnet_conditioning_scale,
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
grid_image = make_image_grid([depth_image, image], rows=1, cols=2)
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/controlnet_depth_tcd.png)
### Canny ControlNet
```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image, make_image_grid
from scheduling_tcd import TCDScheduler
device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_id = "diffusers/controlnet-canny-sdxl-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
controlnet = ControlNetModel.from_pretrained(
controlnet_id,
torch_dtype=torch.float16,
variant="fp16",
).to(device)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_id,
controlnet=controlnet,
torch_dtype=torch.float16,
variant="fp16",
).to(device)
pipe.enable_model_cpu_offload()
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
prompt = "ultrarealistic shot of a furry blue bird"
canny_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png")
controlnet_conditioning_scale = 0.5 # recommended for good generalization
image = pipe(
prompt,
image=canny_image,
num_inference_steps=4,
guidance_scale=0,
eta=0.3,
controlnet_conditioning_scale=controlnet_conditioning_scale,
generator=torch.Generator(device=device).manual_seed(0),
).images[0]
grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/controlnet_canny_tcd.png)
<Tip>
The inference parameters in this example might not work for all examples, so we recommend you to try different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale` and `cross_attention_kwargs` parameters and choose the best one.
</Tip>
</hfoption>
<hfoption id="IP-Adapter">
This example shows how to use the TCD-LoRA with the [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/tree/main) and SDXL.
```python
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import load_image, make_image_grid
from ip_adapter import IPAdapterXL
from scheduling_tcd import TCDScheduler
device = "cuda"
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
pipe = StableDiffusionXLPipeline.from_pretrained(
base_model_path,
torch_dtype=torch.float16,
variant="fp16"
)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
ref_image = load_image("https://raw.githubusercontent.com/tencent-ailab/IP-Adapter/main/assets/images/woman.png").resize((512, 512))
prompt = "best quality, high quality, wearing sunglasses"
image = ip_model.generate(
pil_image=ref_image,
prompt=prompt,
scale=0.5,
num_samples=1,
num_inference_steps=4,
guidance_scale=0,
eta=0.3,
seed=0,
)[0]
grid_image = make_image_grid([ref_image, image], rows=1, cols=2)
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/ip_adapter.png)
</hfoption>
<hfoption id="AnimateDiff">
[`AnimateDiff`] allows animating images using Stable Diffusion models. TCD-LoRA can substantially accelerate the process without degrading image quality. The quality of animation with TCD-LoRA and AnimateDiff has a more lucid outcome.
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from scheduling_tcd import TCDScheduler
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5")
pipe = AnimateDiffPipeline.from_pretrained(
"frankjoshua/toonyou_beta6",
motion_adapter=adapter,
).to("cuda")
# set TCDScheduler
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# load TCD LoRA
pipe.load_lora_weights("h1t/TCD-SD15-LoRA", adapter_name="tcd")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-in", weight_name="diffusion_pytorch_model.safetensors", adapter_name="motion-lora")
pipe.set_adapters(["tcd", "motion-lora"], adapter_weights=[1.0, 1.2])
prompt = "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
generator = torch.manual_seed(0)
frames = pipe(
prompt=prompt,
num_inference_steps=5,
guidance_scale=0,
cross_attention_kwargs={"scale": 1},
num_frames=24,
eta=0.3,
generator=generator
).frames[0]
export_to_gif(frames, "animation.gif")
```
![](https://github.com/jabir-zheng/TCD/raw/main/assets/animation_example.gif)
</hfoption>
</hfoptions>
@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
**Inference**
The inference is the same as if you train a regular LoRA 🤗
## Conducting EDM-style training
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
simply set:
```diff
+ --do_edm_style_training \
```
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
```bash
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
--dataset_name="linoyts/3d_icon" \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir="3d-icon-SDXL-LoRA" \
--do_edm_style_training \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=3 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -1215,7 +1215,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
@@ -1366,14 +1366,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warn(
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warn(
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1407,11 +1407,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
logger.warning(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
import argparse
import contextlib
import gc
import hashlib
import itertools
import json
import logging
import math
import os
@@ -37,7 +39,7 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
@@ -55,6 +57,8 @@ from diffusers import (
AutoencoderKL,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMEulerScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
@@ -74,11 +78,25 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
def determine_scheduler_type(pretrained_model_name_or_path, revision):
model_index_filename = "model_index.json"
if os.path.isdir(pretrained_model_name_or_path):
model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
else:
model_index = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
)
with open(model_index, "r") as f:
scheduler_type = json.load(f)["scheduler"][1]
return scheduler_type
def save_model_card(
repo_id: str,
use_dora: bool,
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--do_edm_style_training",
action="store_true",
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1117,6 +1140,8 @@ def main(args):
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1234,7 +1259,19 @@ def main(args):
)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
if "EDM" in scheduler_type:
args.do_edm_style_training = True
noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
logger.info("Performing EDM-style training!")
elif args.do_edm_style_training:
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
logger.info("Performing EDM-style training!")
else:
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
@@ -1252,7 +1289,12 @@ def main(args):
revision=args.revision,
variant=args.variant,
)
vae_scaling_factor = vae.config.scaling_factor
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
@@ -1317,7 +1359,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
@@ -1522,14 +1564,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warn(
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warn(
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1563,11 +1605,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
logger.warning(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -1790,6 +1832,19 @@ def main(args):
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# TODO: revisit other sampling algorithms
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
if args.train_text_encoder:
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
@@ -1841,9 +1896,15 @@ def main(args):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae_scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
if latents_mean is None and latents_std is None:
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1854,15 +1915,32 @@ def main(args):
)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
if not args.do_edm_style_training:
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
else:
# in EDM formulation, the model is conditioned on the pre-conditioned noise levels
# instead of discrete timesteps, so here we sample indices to get the noise levels
# from `scheduler.timesteps`
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
if args.do_edm_style_training:
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
if "EDM" in scheduler_type:
inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
else:
inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
# time ids
add_time_ids = torch.cat(
@@ -1888,7 +1966,7 @@ def main(args):
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input,
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
@@ -1906,14 +1984,42 @@ def main(args):
)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
weighting = None
if args.do_edm_style_training:
# Similar to the input preconditioning, the model predictions are also preconditioned
# on noised model inputs (before preconditioning) and the sigmas.
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
if "EDM" in scheduler_type:
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
else:
if noise_scheduler.config.prediction_type == "epsilon":
model_pred = model_pred * (-sigmas) + noisy_model_input
elif noise_scheduler.config.prediction_type == "v_prediction":
model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
noisy_model_input / (sigmas**2 + 1)
)
# We are not doing weighting here because it tends result in numerical problems.
# See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
# There might be other alternatives for weighting as well:
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
if "EDM" not in scheduler_type:
weighting = (sigmas**-2.0).float()
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
target = model_input if args.do_edm_style_training else noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
target = (
model_input
if args.do_edm_style_training
else noise_scheduler.get_velocity(model_input, noise, timesteps)
)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1923,10 +2029,28 @@ def main(args):
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if weighting is not None:
prior_loss = torch.mean(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,
)
prior_loss = prior_loss.mean()
else:
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
if weighting is not None:
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
target.shape[0], -1
),
1,
)
loss = loss.mean()
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -2049,17 +2173,18 @@ def main(args):
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -2067,8 +2192,13 @@ def main(args):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
inference_ctx = (
contextlib.nullcontext()
if "playground" in args.pretrained_model_name_or_path
else torch.cuda.amp.autocast()
)
with torch.cuda.amp.autocast():
with inference_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
@@ -2144,15 +2274,18 @@ def main(args):
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
@@ -513,9 +513,7 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
An offset added to the inference steps, as required by some model families.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -418,9 +418,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
An offset added to the inference steps, as required by some model families.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -40,7 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
class MarigoldDepthOutput(BaseOutput):
@@ -452,7 +452,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
self.enable_xformers_memory_efficient_attention()
+1 -3
View File
@@ -171,9 +171,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -308,7 +308,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1068,7 +1068,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -180,7 +180,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
logger_name = "test" if is_final_validation else "validation"
tracker.log({logger_name: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -928,7 +928,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -325,7 +325,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1083,7 +1083,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -285,7 +285,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1023,7 +1023,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -77,7 +77,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -303,7 +303,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1083,7 +1083,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+3 -3
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -178,7 +178,7 @@ def log_validation(
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -861,7 +861,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+2 -2
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = logging.getLogger(__name__)
@@ -128,7 +128,7 @@ def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args
wandb.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {args.report_to}")
logger.warning(f"image logging not implemented for {args.report_to}")
return image_logs
+3 -3
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -178,7 +178,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -929,7 +929,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -904,7 +904,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
attention_class = CustomDiffusionXFormersAttnProcessor
@@ -1178,7 +1178,7 @@ def main(args):
grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
for i in range(len(modifier_token_id[1:])):
for i in range(1, len(modifier_token_id)):
index_grads_to_zero = index_grads_to_zero & (
torch.arange(len(tokenizer)) != modifier_token_id[i]
)
+2 -2
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -987,7 +987,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+2 -2
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -895,7 +895,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -75,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -1141,7 +1141,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
@@ -1317,14 +1317,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warn(
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warn(
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1358,11 +1358,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
logger.warning(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -488,7 +488,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -580,7 +580,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -177,7 +177,7 @@ def log_validation(vae, image_encoder, image_processor, unet, args, accelerator,
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -534,7 +534,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -180,7 +180,7 @@ def log_validation(
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -219,7 +219,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
if args.num_classes is not None:
class_labels = list(range(args.num_classes))
else:
logger.warn(
logger.warning(
"The model is class-conditional but the number of classes is not set. The generated images will be"
" unconditional rather than class-conditional."
)
@@ -266,7 +266,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -863,14 +863,14 @@ def main(args):
elif args.model_config_name_or_path is None:
# TODO: use default architectures from iCT paper
if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
logger.warn(
logger.warning(
f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
)
args.num_classes = None
args.class_embed_type = None
elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
logger.warn(
logger.warning(
"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
"`class_conditional` will be overridden to `False`."
)
@@ -996,7 +996,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -407,7 +407,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
tracker.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1057,7 +1057,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -574,7 +574,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -672,7 +672,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -516,7 +516,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -608,7 +608,7 @@ def main():
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
if args.push_to_hub and args.only_save_embeds:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not args.only_save_embeds
@@ -541,7 +541,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -645,7 +645,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -901,7 +901,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and args.only_save_embeds:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not args.only_save_embeds
@@ -108,7 +108,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -523,7 +523,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -687,7 +687,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -916,7 +916,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -410,7 +410,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model.enable_xformers_memory_efficient_attention()
@@ -637,7 +637,7 @@ def main(args):
generator=generator,
batch_size=args.eval_batch_size,
num_inference_steps=args.ddpm_num_inference_steps,
output_type="numpy",
output_type="np",
).images
if args.use_ema:
@@ -629,7 +629,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -167,7 +167,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -932,7 +932,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -56,7 +56,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -183,7 +183,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -608,7 +608,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -497,7 +497,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -64,7 +64,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--debug_loss",
action="store_true",
help="debug loss for each image, if filenames are awailable in the dataset",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -603,6 +608,7 @@ def main(args):
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
unet.to(accelerator.device, dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is None:
vae.to(accelerator.device, dtype=torch.float32)
else:
@@ -616,7 +622,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -890,13 +896,17 @@ def main(args):
tokens_one, tokens_two = tokenize_captions(examples)
examples["input_ids_one"] = tokens_one
examples["input_ids_two"] = tokens_two
if args.debug_loss:
fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
if fnames:
examples["filenames"] = fnames
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -905,7 +915,7 @@ def main(args):
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
return {
result = {
"pixel_values": pixel_values,
"input_ids_one": input_ids_one,
"input_ids_two": input_ids_two,
@@ -913,6 +923,11 @@ def main(args):
"crop_top_lefts": crop_top_lefts,
}
filenames = [example["filenames"] for example in examples if "filenames" in example]
if filenames:
result["filenames"] = filenames
return result
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1105,7 +1120,9 @@ def main(args):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.debug_loss and "filenames" in batch:
for fname in batch["filenames"]:
accelerator.log({"loss_for_" + fname: loss}, step=global_step)
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
@@ -54,7 +54,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -712,7 +712,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -80,7 +80,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -708,7 +708,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -966,7 +966,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@@ -711,7 +711,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -1022,7 +1022,7 @@ def main():
)
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -408,7 +408,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model.enable_xformers_memory_efficient_attention()
@@ -648,7 +648,7 @@ def main(args):
generator=generator,
batch_size=args.eval_batch_size,
num_inference_steps=args.ddpm_num_inference_steps,
output_type="numpy",
output_type="np",
).images
if args.use_ema:
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -184,7 +184,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.28.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -182,7 +182,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
+163 -160
View File
@@ -1,7 +1,7 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
@@ -18,23 +18,56 @@ from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument(
"--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
@@ -52,164 +85,134 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b1
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
load_model_dict_into_meta(decoder, state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
if not args.skip_stage_c:
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if not args.skip_stage_b:
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
+226
View File
@@ -0,0 +1,226 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
from diffusers import (
DDPMWuerstchenScheduler,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument(
"--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
"--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--prior_output_path",
default="stable-cascade-prior-lite",
type=str,
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if not args.skip_stage_c:
# Prior
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=1536,
block_out_channels=[1536, 1536],
num_attention_heads=[24, 24],
down_num_layers_per_block=[4, 12],
up_num_layers_per_block=[12, 4],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if not args.skip_stage_b:
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 576, 1152, 1152],
down_num_layers_per_block=[2, 4, 14, 4],
up_num_layers_per_block=[4, 14, 4, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[2, 2, 2, 2],
num_attention_heads=[0, 9, 18, 18],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
+139
View File
@@ -0,0 +1,139 @@
import argparse
import json
import os
from datetime import date
from pathlib import Path
from slack_sdk import WebClient
from tabulate import tabulate
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
parser = argparse.ArgumentParser()
parser.add_argument("--slack_channel_name", default="diffusers-ci-nightly")
def main(slack_channel_name=None):
failed = []
passed = []
group_info = []
total_num_failed = 0
empty_file = False or len(list(Path().glob("*.log"))) == 0
total_empty_files = []
for log in Path().glob("*.log"):
section_num_failed = 0
i = 0
with open(log) as f:
for line in f:
line = json.loads(line)
i += 1
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
duration = f'{line["duration"]:.4f}'
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
total_num_failed += 1
else:
passed.append([test, duration, log.name.split("_")[0]])
empty_file = i == 0
group_info.append([str(log), section_num_failed, failed])
total_empty_files.append(empty_file)
os.remove(log)
failed = []
text = (
"🌞 There were no failures!"
if not any(total_empty_files)
else "Something went wrong there is at least one empty file - please check GH action results."
)
no_error_payload = {
"type": "section",
"text": {
"type": "plain_text",
"text": text,
"emoji": True,
},
}
message = ""
payload = [
{
"type": "header",
"text": {
"type": "plain_text",
"text": "🤗 Results of the Diffusers scheduled nightly tests.",
},
},
]
if total_num_failed > 0:
for i, (name, num_failed, failed_tests) in enumerate(group_info):
if num_failed > 0:
if num_failed == 1:
message += f"*{name}: {num_failed} failed test*\n"
else:
message += f"*{name}: {num_failed} failed tests*\n"
failed_table = []
for test in failed_tests:
failed_table.append(test[0].split("::"))
failed_table = tabulate(
failed_table,
headers=["Test Location", "Test Case", "Test Name"],
showindex="always",
tablefmt="grid",
maxcolwidths=[12, 12, 12],
)
message += "\n```\n" + failed_table + "\n```"
if total_empty_files[i]:
message += f"\n*{name}: Warning! Empty file - please check the GitHub action job *\n"
print(f"### {message}")
else:
payload.append(no_error_payload)
if len(message) > MAX_LEN_MESSAGE:
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
message = message[:MAX_LEN_MESSAGE] + "..."
if len(message) != 0:
md_report = {
"type": "section",
"text": {"type": "mrkdwn", "text": message},
}
payload.append(md_report)
action_button = {
"type": "section",
"text": {"type": "mrkdwn", "text": "*For more details:*"},
"accessory": {
"type": "button",
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
"url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
},
}
payload.append(action_button)
date_report = {
"type": "context",
"elements": [
{
"type": "plain_text",
"text": f"Nightly test results for {date.today()}",
},
],
}
payload.append(date_report)
print(payload)
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
if __name__ == "__main__":
args = parser.parse_args()
main(args.slack_channel_name)
+1 -1
View File
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.27.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.28.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+5 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.27.0.dev0"
__version__ = "0.28.0.dev0"
from typing import TYPE_CHECKING
@@ -253,6 +253,8 @@ else:
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
@@ -623,6 +625,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
MusicLDMPipeline,
PaintByExamplePipeline,
PIAPipeline,
+2 -2
View File
@@ -430,7 +430,7 @@ class LoraLoaderMixin:
# contain the module names of the `unet` as its keys WITHOUT any prefix.
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
logger.warning(warn_message)
if len(state_dict.keys()) > 0:
if adapter_name in getattr(unet, "peft_config", {}):
@@ -882,7 +882,7 @@ class LoraLoaderMixin:
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
if self.num_fused_loras > 1:
logger.warn(
logger.warning(
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
)
+6 -1
View File
@@ -56,6 +56,8 @@ def build_sub_model_components(
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
@@ -64,6 +66,7 @@ def build_sub_model_components(
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
upcast_attention=upcast_attention,
)
return unet_components
@@ -300,7 +303,9 @@ class FromSingleFileMixin:
continue
init_kwargs.update(components)
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
additional_components = set_additional_components(
class_name, original_config, checkpoint=checkpoint, model_type=model_type
)
if additional_components:
init_kwargs.update(additional_components)
+116 -12
View File
@@ -81,6 +81,87 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading",
}
STABLE_CASCADE_DEFAULT_CONFIGS = {
"stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
"stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
"stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
"stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
}
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
is_stage_c = "clip_txt_mapper.weight" in original_state_dict
if is_stage_c:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
else:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
return state_dict
def infer_stable_cascade_single_file_config(checkpoint):
is_stage_c = "clip_txt_mapper.weight" in checkpoint
is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
config_type = "stage_c_lite"
elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
config_type = "stage_c"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
config_type = "stage_b_lite"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
config_type = "stage_b"
return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"layers": {
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
):
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
return original_config, checkpoint
def load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=False,
force_download=False,
proxies=None,
token=None,
cache_dir=None,
local_files_only=None,
revision=None,
):
if os.path.isfile(pretrained_model_link_or_path):
checkpoint = load_state_dict(pretrained_model_link_or_path)
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
checkpoint_path = _get_model_file(
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
return original_config, checkpoint
return checkpoint
def infer_original_config_file(class_name, checkpoint):
@@ -307,7 +410,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def infer_model_type(original_config, checkpoint=None, model_type=None):
def infer_model_type(original_config, checkpoint, model_type=None):
if model_type is not None:
return model_type
@@ -884,7 +987,7 @@ def create_diffusers_controlnet_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1060,7 +1163,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1155,7 +1258,7 @@ def create_text_encoder_from_open_clip_checkpoint(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
@@ -1176,7 +1279,7 @@ def create_diffusers_unet_model_from_ldm(
original_config,
checkpoint,
num_in_channels=None,
upcast_attention=False,
upcast_attention=None,
extract_ema=False,
image_size=None,
torch_dtype=None,
@@ -1204,7 +1307,8 @@ def create_diffusers_unet_model_from_ldm(
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
if upcast_attention is not None:
unet_config["upcast_attention"] = upcast_attention
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
@@ -1221,7 +1325,7 @@ def create_diffusers_unet_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1283,7 +1387,7 @@ def create_diffusers_vae_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
+107 -2
View File
@@ -42,6 +42,11 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .utils import AttnProcsLayers
@@ -345,7 +350,7 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
@@ -384,7 +389,7 @@ class UNet2DConditionLoadersMixin:
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warn(warn_message)
logger.warning(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
@@ -896,3 +901,103 @@ class UNet2DConditionLoadersMixin:
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config: (`dict`, *optional*):
Dictionary containing the configuration of the model:
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables of the model.
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model
+8 -8
View File
@@ -17,8 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleLinear
from ..utils import deprecate
ACTIVATION_FUNCTIONS = {
@@ -87,9 +86,7 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -97,9 +94,12 @@ class GEGLU(nn.Module):
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, scale: float = 1.0):
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
+23 -33
View File
@@ -17,18 +17,18 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
logger = logging.get_logger(__name__)
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
@@ -36,18 +36,10 @@ def _chunked_feed_forward(
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@@ -299,6 +291,10 @@ class BasicTransformerBlock(nn.Module):
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
@@ -326,10 +322,7 @@ class BasicTransformerBlock(nn.Module):
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -348,7 +341,7 @@ class BasicTransformerBlock(nn.Module):
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
@@ -394,11 +387,9 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -643,7 +634,7 @@ class FeedForward(nn.Module):
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
linear_cls = nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,10 @@ class FeedForward(nn.Module):
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
hidden_states = module(hidden_states)
return hidden_states
+97 -58
View File
@@ -20,10 +20,10 @@ import torch.nn.functional as F
from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRACompatibleLinear, LoRALinearLayer
from .lora import LoRALinearLayer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -181,10 +181,7 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
linear_cls = nn.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
@@ -741,11 +738,14 @@ class AttnProcessor:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
args = () if USE_PEFT_BACKEND else (scale,)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -764,15 +764,26 @@ class AttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# encoder_hidden_states = hidden_states
batch, seq, dim = hidden_states.shape
height = width = seq**0.5
# reshape to (batch, height, width, dim)
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
# reshape to (batch, dim, height, width)
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
# reshape to (batch, dim, seq)
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
# reshape to (batch, seq, dim)
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
@@ -783,7 +794,7 @@ class AttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -914,11 +925,14 @@ class AttnAddedKVProcessor:
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
args = () if USE_PEFT_BACKEND else (scale,)
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -932,17 +946,17 @@ class AttnAddedKVProcessor:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -956,7 +970,7 @@ class AttnAddedKVProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -984,11 +998,14 @@ class AttnAddedKVProcessor2_0:
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
args = () if USE_PEFT_BACKEND else (scale,)
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -1002,7 +1019,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -1011,8 +1028,8 @@ class AttnAddedKVProcessor2_0:
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -1029,7 +1046,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1132,11 +1149,14 @@ class XFormersAttnProcessor:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
args = () if USE_PEFT_BACKEND else (scale,)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1165,15 +1185,15 @@ class XFormersAttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
@@ -1186,7 +1206,7 @@ class XFormersAttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1217,8 +1237,13 @@ class AttnProcessor2_0:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1242,16 +1267,26 @@ class AttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# encoder_hidden_states = hidden_states
batch, seq, dim = hidden_states.shape
height = width = seq**0.5
# reshape to (batch, height, width, dim)
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
# reshape to (batch, dim, height, width)
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
# reshape to (batch, dim, seq)
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
# reshape to (batch, seq, dim)
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1271,7 +1306,7 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1312,8 +1347,13 @@ class FusedAttnProcessor2_0:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1337,17 +1377,16 @@ class FusedAttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
kv = attn.to_kv(encoder_hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
@@ -1368,7 +1407,7 @@ class FusedAttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1859,7 +1898,7 @@ class LoRAAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1877,7 +1916,7 @@ class LoRAAttnProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnProcessor()
return attn.processor(attn, hidden_states, *args, **kwargs)
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnProcessor2_0(nn.Module):
@@ -1920,7 +1959,7 @@ class LoRAAttnProcessor2_0(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1938,7 +1977,7 @@ class LoRAAttnProcessor2_0(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnProcessor2_0()
return attn.processor(attn, hidden_states, *args, **kwargs)
return attn.processor(attn, hidden_states, **kwargs)
class LoRAXFormersAttnProcessor(nn.Module):
@@ -1999,7 +2038,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2017,7 +2056,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = XFormersAttnProcessor()
return attn.processor(attn, hidden_states, *args, **kwargs)
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnAddedKVProcessor(nn.Module):
@@ -2058,7 +2097,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2076,7 +2115,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnAddedKVProcessor()
return attn.processor(attn, hidden_states, *args, **kwargs)
return attn.processor(attn, hidden_states, **kwargs)
class IPAdapterAttnProcessor(nn.Module):
+7 -11
View File
@@ -18,8 +18,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from ..utils import deprecate
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native
@@ -103,7 +102,7 @@ class Downsample2D(nn.Module):
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv_cls = nn.Conv2d
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,10 @@ class Downsample2D(nn.Module):
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
@@ -143,13 +145,7 @@ class Downsample2D(nn.Module):
assert hidden_states.shape[1] == self.channels
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states)
return hidden_states
+2 -3
View File
@@ -18,10 +18,9 @@ import numpy as np
import torch
from torch import nn
from ..utils import USE_PEFT_BACKEND, deprecate
from ..utils import deprecate
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear
def get_timestep_embedding(
@@ -200,7 +199,7 @@ class TimestepEmbedding(nn.Module):
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
linear_cls = nn.Linear
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+6
View File
@@ -204,6 +204,9 @@ class LoRALinearLayer(nn.Module):
):
super().__init__()
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
@@ -264,6 +267,9 @@ class LoRAConv2dLayer(nn.Module):
):
super().__init__()
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+2 -2
View File
@@ -677,7 +677,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
@@ -705,7 +705,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# the weights so we don't have to do this again.
if "'Attention' object has no attribute" in str(e):
logger.warn(
logger.warning(
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+29 -56
View File
@@ -20,7 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from ..utils import deprecate
from .activations import get_activation
from .attention_processor import SpatialNorm
from .downsampling import ( # noqa
@@ -30,7 +30,6 @@ from .downsampling import ( # noqa
KDownsample2D,
downsample_2d,
)
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .normalization import AdaGroupNorm
from .upsampling import ( # noqa
FirUpsample2D,
@@ -102,7 +101,7 @@ class ResnetBlockCondNorm2D(nn.Module):
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv_cls = nn.Conv2d
if groups_out is None:
groups_out = groups
@@ -149,12 +148,11 @@ class ResnetBlockCondNorm2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states, temb)
@@ -166,26 +164,24 @@ class ResnetBlockCondNorm2D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor, scale=scale)
hidden_states = self.upsample(hidden_states, scale=scale)
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor, scale=scale)
hidden_states = self.downsample(hidden_states, scale=scale)
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, temb)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -267,8 +263,8 @@ class ResnetBlock2D(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear
conv_cls = nn.Conv2d
if groups_out is None:
groups_out = groups
@@ -326,12 +322,11 @@ class ResnetBlock2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
@@ -342,38 +337,18 @@ class ResnetBlock2D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = (
self.upsample(input_tensor, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(input_tensor)
)
hidden_states = (
self.upsample(hidden_states, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(hidden_states)
)
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
input_tensor = (
self.downsample(input_tensor, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(input_tensor)
)
hidden_states = (
self.downsample(hidden_states, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(hidden_states)
)
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
if not USE_PEFT_BACKEND
else self.time_emb_proj(temb)[:, :, None, None]
)
temb = self.time_emb_proj(temb)[:, :, None, None]
if self.time_embedding_norm == "default":
if temb is not None:
@@ -393,12 +368,10 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -19,14 +19,16 @@ import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
@@ -115,8 +117,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
conv_cls = nn.Conv2d
linear_cls = nn.Linear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
@@ -304,6 +306,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -327,9 +332,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
@@ -337,21 +339,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
@@ -414,17 +408,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
+162 -97
View File
@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...utils import is_torch_version, logging
from ...utils import deprecate, is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -69,7 +69,7 @@ def get_down_block(
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
@@ -354,7 +354,7 @@ def get_up_block(
) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
@@ -673,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
@@ -844,8 +844,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
@@ -882,7 +885,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -982,7 +985,8 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -995,7 +999,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
@@ -1006,7 +1010,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
)
# resnet
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -1035,7 +1039,7 @@ class AttnDownBlock2D(nn.Module):
self.downsample_type = downsample_type
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1111,23 +1115,22 @@ class AttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
cross_attention_kwargs.update({"scale": lora_scale})
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
hidden_states = downsampler(hidden_states, temb=temb)
else:
hidden_states = downsampler(hidden_states, scale=lora_scale)
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
@@ -1236,9 +1239,11 @@ class CrossAttnDownBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
output_states = ()
blocks = list(zip(self.resnets, self.attentions))
@@ -1270,7 +1275,7 @@ class CrossAttnDownBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1288,7 +1293,7 @@ class CrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
@@ -1348,8 +1353,12 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -1370,13 +1379,13 @@ class DownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
@@ -1447,13 +1456,17 @@ class DownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, scale=scale)
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale)
hidden_states = downsampler(hidden_states)
return hidden_states
@@ -1480,7 +1493,7 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1545,15 +1558,18 @@ class AttnDownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale)
hidden_states = downsampler(hidden_states)
return hidden_states
@@ -1579,7 +1595,7 @@ class AttnSkipDownBlock2D(nn.Module):
self.resnets = nn.ModuleList([])
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1644,18 +1660,22 @@ class AttnSkipDownBlock2D(nn.Module):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
@@ -1731,16 +1751,21 @@ class SkipDownBlock2D(nn.Module):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb, scale)
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb, scale)
hidden_states = self.resnet_down(hidden_states, temb)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
@@ -1816,8 +1841,12 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -1838,13 +1867,13 @@ class ResnetDownsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale)
hidden_states = resnet(hidden_states, temb)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb, scale)
hidden_states = downsampler(hidden_states, temb)
output_states = output_states + (hidden_states,)
@@ -1955,10 +1984,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
output_states = ()
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1991,7 +2021,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
@@ -2004,7 +2034,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
hidden_states = downsampler(hidden_states, temb)
output_states = output_states + (hidden_states,)
@@ -2058,8 +2088,12 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -2080,7 +2114,7 @@ class KDownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale)
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
@@ -2165,8 +2199,11 @@ class KCrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
@@ -2196,7 +2233,7 @@ class KCrossAttnDownBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2244,7 +2281,7 @@ class AttnUpBlock2D(nn.Module):
self.upsample_type = upsample_type
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -2316,24 +2353,28 @@ class AttnUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
if self.upsample_type == "resnet":
hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
hidden_states = upsampler(hidden_states, temb=temb)
else:
hidden_states = upsampler(hidden_states, scale=scale)
hidden_states = upsampler(hidden_states)
return hidden_states
@@ -2440,7 +2481,10 @@ class CrossAttnUpBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2494,7 +2538,7 @@ class CrossAttnUpBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2506,7 +2550,7 @@ class CrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -2567,8 +2611,13 @@ class UpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2612,11 +2661,11 @@ class UpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -2683,11 +2732,9 @@ class UpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -2719,7 +2766,7 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -2783,17 +2830,14 @@ class AttnUpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = attn(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, scale=scale)
hidden_states = upsampler(hidden_states)
return hidden_states
@@ -2841,7 +2885,7 @@ class AttnSkipUpBlock2D(nn.Module):
)
if attention_head_dim is None:
logger.warn(
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -2898,18 +2942,22 @@ class AttnSkipUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
scale: float = 1.0,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
cross_attention_kwargs = {"scale": scale}
hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
hidden_states = self.attentions[0](hidden_states)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
@@ -2923,7 +2971,7 @@ class AttnSkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
@@ -3006,15 +3054,20 @@ class SkipUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
scale: float = 1.0,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
@@ -3028,7 +3081,7 @@ class SkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
@@ -3108,8 +3161,13 @@ class ResnetUpsampleBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -3133,11 +3191,11 @@ class ResnetUpsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb, scale=scale)
hidden_states = upsampler(hidden_states, temb)
return hidden_states
@@ -3253,8 +3311,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
@@ -3292,7 +3351,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
@@ -3303,7 +3362,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
hidden_states = upsampler(hidden_states, temb)
return hidden_states
@@ -3364,8 +3423,13 @@ class KUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3388,7 +3452,7 @@ class KUpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -3498,7 +3562,6 @@ class KCrossAttnUpBlock2D(nn.Module):
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
@@ -3527,7 +3590,7 @@ class KCrossAttnUpBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -3630,6 +3693,8 @@ class KAttentionBlock(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 1. Self-Attention
if self.add_self_attention:
@@ -80,7 +80,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -109,7 +109,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
@@ -147,9 +147,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
@@ -1226,7 +1226,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
@@ -1297,7 +1297,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
)
# 6. post-process
+39 -18
View File
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...utils import is_torch_version
from ...utils import deprecate, is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..resnet import (
@@ -35,6 +35,9 @@ from ..transformers.transformer_temporal import (
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_down_block(
down_block_type: str,
num_layers: int,
@@ -1005,9 +1008,14 @@ class DownBlockMotion(nn.Module):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = 1,
*args,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
@@ -1029,18 +1037,18 @@ class DownBlockMotion(nn.Module):
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
@@ -1173,9 +1181,11 @@ class CrossAttnDownBlockMotion(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
):
output_states = ()
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
output_states = ()
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
@@ -1206,7 +1216,7 @@ class CrossAttnDownBlockMotion(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1228,7 +1238,7 @@ class CrossAttnDownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
@@ -1355,7 +1365,10 @@ class CrossAttnUpBlockMotion(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1410,7 +1423,7 @@ class CrossAttnUpBlockMotion(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1426,7 +1439,7 @@ class CrossAttnUpBlockMotion(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -1507,9 +1520,14 @@ class UpBlockMotion(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size=None,
scale: float = 1.0,
num_frames: int = 1,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1559,12 +1577,12 @@ class UpBlockMotion(nn.Module):
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -1687,8 +1705,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
@@ -1737,7 +1758,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = resnet(hidden_states, temb)
return hidden_states
+1 -1
View File
@@ -89,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
ff_output = self.ff(hidden_states, scale=1.0)
ff_output = self.ff(hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
@@ -21,6 +21,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
@@ -134,7 +135,7 @@ class StableCascadeUNetOutput(BaseOutput):
sample: torch.FloatTensor = None
class StableCascadeUNet(ModelMixin, ConfigMixin):
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
_supports_gradient_checkpointing = True
@register_to_config
+9 -15
View File
@@ -18,8 +18,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from ..utils import deprecate
from .normalization import RMSNorm
@@ -111,7 +110,7 @@ class Upsample2D(nn.Module):
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv_cls = nn.Conv2d
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -141,11 +140,12 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
@@ -180,15 +180,9 @@ class Upsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
+13
View File
@@ -23,6 +23,7 @@ _import_structure = {
"controlnet_xs": [],
"deprecated": [],
"latent_diffusion": [],
"ledits_pp": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
}
@@ -171,6 +172,12 @@ else:
"LatentConsistencyModelPipeline",
]
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["ledits_pp"].extend(
[
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
]
)
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
@@ -424,6 +431,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LatentConsistencyModelPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline
from .ledits_pp import (
LEditsPPDiffusionPipelineOutput,
LEditsPPInversionPipelineOutput,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
@@ -127,7 +127,7 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 16):
num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
@@ -191,7 +191,7 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
"pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
@@ -824,20 +824,22 @@ class StableDiffusionControlNetPipeline(
return latents
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
def get_guidance_scale_embedding(
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.FloatTensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -869,20 +869,22 @@ class StableDiffusionXLControlNetPipeline(
self.vae.decoder.mid_block.to(dtype)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
def get_guidance_scale_embedding(
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.FloatTensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -156,7 +156,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
self.dtype = dtype
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -416,13 +416,13 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -460,13 +460,13 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -175,7 +175,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warn(
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -209,13 +209,13 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -500,13 +500,13 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -177,7 +177,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warn(
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -211,13 +211,13 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -133,7 +133,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warn(
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -167,13 +167,13 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -133,7 +133,7 @@ class SpectrogramDiffusionPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
num_inference_steps: int = 100,
return_dict: bool = True,
output_type: str = "numpy",
output_type: str = "np",
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
) -> Union[AudioPipelineOutput, Tuple]:
@@ -157,7 +157,7 @@ class SpectrogramDiffusionPipeline(DiffusionPipeline):
expense of slower inference.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
output_type (`str`, *optional*, defaults to `"numpy"`):
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated audio.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
@@ -249,16 +249,16 @@ class SpectrogramDiffusionPipeline(DiffusionPipeline):
logger.info("Generated segment", i)
if output_type == "numpy" and not is_onnx_available():
if output_type == "np" and not is_onnx_available():
raise ValueError(
"Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'."
)
elif output_type == "numpy" and self.melgan is None:
elif output_type == "np" and self.melgan is None:
raise ValueError(
"Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'."
)
if output_type == "numpy":
if output_type == "np":
output = self.melgan(input_features=full_pred_mel.astype(np.float32))
else:
output = full_pred_mel

Some files were not shown because too many files have changed in this diff Show More