Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ea238e821b | |||
| b6d1d670fc | |||
| 4330a747d4 | |||
| 76de6a09fb | |||
| 25caf24ef9 | |||
| 8db3c9bc9f | |||
| e0e9f81971 | |||
| 5d848ec07c | |||
| 4974b84564 | |||
| 83062fb872 | |||
| b6d7e31d10 | |||
| 53e9aacc10 | |||
| 41424466e3 | |||
| 95de1981c9 | |||
| 0b45b58867 | |||
| d3986f18be | |||
| ee6a3a993d | |||
| b300517305 | |||
| ac07b6dc6a | |||
| 46ab56a468 | |||
| 038ff70023 | |||
| 00eca4b887 | |||
| 30132aba30 | |||
| a17d6d6858 | |||
| 8efd9ce787 | |||
| 299c16d0f5 | |||
| 69f49195ac | |||
| ed224f94ba | |||
| 531e719163 | |||
| 4fbd310fd2 | |||
| 2ea28d69dc | |||
| a1cb106459 | |||
| 5dd8e04d4b | |||
| 165af7edd3 | |||
| 6c5f0de713 | |||
| e64fdcf2ce | |||
| ec64f371b1 | |||
| cd6e1f1171 | |||
| 6f2b310a17 | |||
| e3cd6cae50 | |||
| e5ee05da76 | |||
| e6ff752840 | |||
| 3f9c746fb2 | |||
| 1f22c98820 | |||
| b4226bd6a7 | |||
| 46fac824be | |||
| b33b64f595 | |||
| 9d9744075e | |||
| d9a3b69806 | |||
| f7e5954d5e | |||
| 8e19c073e5 | |||
| f6df16cbb8 | |||
| b24f78349c | |||
| 3ce905c9d0 | |||
| f539497ab4 | |||
| 39dfb7abbd |
@@ -1,22 +1,58 @@
|
||||
name: Build Docker images (nightly)
|
||||
name: Test, build, and push Docker images
|
||||
|
||||
on:
|
||||
pull_request: # During PRs, we just check if the changes Dockerfiles can be successfully built
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docker/**"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # every day at midnight
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
REGISTRY: diffusers
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
build-docker-images:
|
||||
test-build-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request'
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Find Changed Dockerfiles
|
||||
id: file_changes
|
||||
uses: jitterbit/get-changed-files@v1
|
||||
with:
|
||||
format: 'space-delimited'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Changed Docker Images
|
||||
run: |
|
||||
CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
|
||||
for FILE in $CHANGED_FILES; do
|
||||
if [[ "$FILE" == docker/*Dockerfile ]]; then
|
||||
DOCKER_PATH="${FILE%/Dockerfile}"
|
||||
DOCKER_TAG=$(basename "$DOCKER_PATH")
|
||||
echo "Building Docker image for $DOCKER_TAG"
|
||||
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
|
||||
fi
|
||||
done
|
||||
if: steps.file_changes.outputs.all != ''
|
||||
|
||||
build-and-push-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
@@ -12,6 +12,7 @@ env:
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: yes
|
||||
RUN_NIGHTLY: yes
|
||||
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
run_nightly_tests:
|
||||
@@ -64,6 +65,7 @@ jobs:
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -78,7 +80,8 @@ jobs:
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
@@ -89,6 +92,7 @@ jobs:
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Run nightly ONNXRuntime CUDA tests
|
||||
@@ -100,6 +104,7 @@ jobs:
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
@@ -112,6 +117,12 @@ jobs:
|
||||
with:
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_tests_apple_m1:
|
||||
name: Nightly PyTorch MPS tests on MacOS
|
||||
@@ -140,6 +151,7 @@ jobs:
|
||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
${CONDA_RUN} python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
@@ -152,7 +164,9 @@ jobs:
|
||||
HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
--report-log=tests_torch_mps.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
@@ -164,3 +178,9 @@ jobs:
|
||||
with:
|
||||
name: torch_mps_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Notify Slack about a release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
- name: Notify Slack about the release
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
run: pip install requests && python utils/notify_slack_about_release.py
|
||||
@@ -0,0 +1,81 @@
|
||||
# Adapted from https://blog.deepjyoti30.dev/pypi-release-github-action
|
||||
|
||||
name: PyPI release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
find-and-checkout-latest-branch:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
- name: Fetch latest branch
|
||||
id: fetch_latest_branch
|
||||
run: |
|
||||
pip install -U requests packaging
|
||||
LATEST_BRANCH=$(python utils/fetch_latest_release_branch.py)
|
||||
echo "Latest branch: $LATEST_BRANCH"
|
||||
echo "latest_branch=$LATEST_BRANCH" >> $GITHUB_ENV
|
||||
|
||||
- name: Set latest branch output
|
||||
id: set_latest_branch
|
||||
run: echo "::set-output name=latest_branch::${{ env.latest_branch }}"
|
||||
|
||||
release:
|
||||
needs: find-and-checkout-latest-branch
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -U setuptools wheel twine
|
||||
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -U transformers
|
||||
|
||||
- name: Build the dist files
|
||||
run: python setup.py bdist_wheel && python setup.py sdist
|
||||
|
||||
- name: Publish to the test PyPI
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }}
|
||||
run: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
||||
|
||||
- name: Test installing diffusers and importing
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
python -c "from diffusers import *"
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
run: twine upload dist/* -r pypi
|
||||
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 22000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +8000 other amazing GitHub repositories 💪
|
||||
- +9000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -40,6 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
transformers matplotlib
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: Inference with PEFT
|
||||
title: Load LoRAs for inference
|
||||
- local: tutorials/fast_diffusion
|
||||
title: Accelerate inference of text-to-image diffusion models
|
||||
title: Tutorials
|
||||
@@ -62,6 +62,8 @@
|
||||
title: Textual inversion
|
||||
- local: using-diffusers/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: using-diffusers/merge_loras
|
||||
title: Merge LoRAs
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference with multiple GPUs
|
||||
- local: using-diffusers/reusing_seeds
|
||||
@@ -102,6 +104,8 @@
|
||||
title: Latent Consistency Model-LoRA
|
||||
- local: using-diffusers/inference_with_lcm
|
||||
title: Latent Consistency Model
|
||||
- local: using-diffusers/inference_with_tcd_lora
|
||||
title: Trajectory Consistency Distillation-LoRA
|
||||
- local: using-diffusers/svd
|
||||
title: Stable Video Diffusion
|
||||
title: Specific pipeline examples
|
||||
@@ -302,6 +306,8 @@
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
@@ -394,6 +400,10 @@
|
||||
title: DPMSolverSDEScheduler
|
||||
- local: api/schedulers/singlestep_dpm_solver
|
||||
title: DPMSolverSinglestepScheduler
|
||||
- local: api/schedulers/edm_multistep_dpm_solver
|
||||
title: EDMDPMSolverMultistepScheduler
|
||||
- local: api/schedulers/edm_euler
|
||||
title: EDMEulerScheduler
|
||||
- local: api/schedulers/euler_ancestral
|
||||
title: EulerAncestralDiscreteScheduler
|
||||
- local: api/schedulers/euler
|
||||
|
||||
@@ -23,3 +23,7 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
|
||||
## IPAdapterMaskProcessor
|
||||
|
||||
[[autodoc]] image_processor.IPAdapterMaskProcessor
|
||||
@@ -0,0 +1,54 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LEDITS++
|
||||
|
||||
LEDITS++ was proposed in [LEDITS++: Limitless Image Editing using Text-to-Image Models](https://huggingface.co/papers/2311.16711) by Manuel Brack, Felix Friedrich, Katharina Kornmeier, Linoy Tsaban, Patrick Schramowski, Kristian Kersting, Apolinário Passos.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Text-to-image diffusion models have recently received increasing interest for their astonishing ability to produce high-fidelity images from solely text inputs. Subsequent research efforts aim to exploit and apply their capabilities to real image editing. However, existing image-to-image methods are often inefficient, imprecise, and of limited versatility. They either require time-consuming fine-tuning, deviate unnecessarily strongly from the input image, and/or lack support for multiple, simultaneous edits. To address these issues, we introduce LEDITS++, an efficient yet versatile and precise textual image manipulation technique. LEDITS++'s novel inversion approach requires no tuning nor optimization and produces high-fidelity results with a few diffusion steps. Second, our methodology supports multiple simultaneous edits and is architecture-agnostic. Third, we use a novel implicit masking technique that limits changes to relevant image regions. We propose the novel TEdBench++ benchmark as part of our exhaustive evaluation. Our results demonstrate the capabilities of LEDITS++ and its improvements over previous methods. The project page is available at https://leditsplusplus-project.static.hf.space .*
|
||||
|
||||
<Tip>
|
||||
|
||||
You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
Due to some backward compatability issues with the current diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`] this implementation of LEdits++ can no longer guarantee perfect inversion.
|
||||
This issue is unlikely to have any noticeable effects on applied use-cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).
|
||||
</Tip>
|
||||
|
||||
We provide two distinct pipelines based on different pre-trained models.
|
||||
|
||||
## LEditsPPPipelineStableDiffusion
|
||||
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusion
|
||||
- all
|
||||
- __call__
|
||||
- invert
|
||||
|
||||
## LEditsPPPipelineStableDiffusionXL
|
||||
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL
|
||||
- all
|
||||
- __call__
|
||||
- invert
|
||||
|
||||
|
||||
|
||||
## LEditsPPDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPDiffusionPipelineOutput
|
||||
- all
|
||||
|
||||
## LEditsPPInversionPipelineOutput
|
||||
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPInversionPipelineOutput
|
||||
- all
|
||||
@@ -57,6 +57,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Latent Consistency Models](latent_consistency_models) | text2image |
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
| [MusicLDM](musicldm) | text2audio |
|
||||
| [Paint by Example](paint_by_example) | inpainting |
|
||||
|
||||
@@ -30,6 +30,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionSafePipelineOutput
|
||||
## SemanticStableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput
|
||||
- all
|
||||
|
||||
@@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Stable Cascade
|
||||
|
||||
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
|
||||
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
|
||||
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
|
||||
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
|
||||
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
|
||||
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
|
||||
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
|
||||
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
|
||||
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
|
||||
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
|
||||
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
|
||||
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
|
||||
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
|
||||
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
|
||||
Diffusion 1.5.
|
||||
|
||||
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
|
||||
@@ -30,13 +30,154 @@ The original codebase can be found at [Stability-AI/StableCascade](https://githu
|
||||
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
|
||||
hence the name "Stable Cascade".
|
||||
|
||||
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
|
||||
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
|
||||
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
|
||||
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
|
||||
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
|
||||
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
|
||||
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
|
||||
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
|
||||
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
|
||||
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
|
||||
for generating the small 24 x 24 latents given a text prompt.
|
||||
|
||||
The Stage C model operates on the small 24 x 24 latents and denoises the latents conditioned on text prompts. The model is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`
|
||||
|
||||
The Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
There are some restrictions on data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type. Please use `torch.bfloat16` instead.
|
||||
|
||||
In order to use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline` you need to have PyTorch 2.2.0 or higher installed. This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.
|
||||
|
||||
If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type. You can download the full precision or `bf16` variant weights for the pipeline and cast the weights to `torch.float16`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Usage example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings.to(torch.float16),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade.png")
|
||||
```
|
||||
|
||||
## Using the Lite Versions of the Stage B and Stage C models
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableCascadeUNet,
|
||||
)
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
|
||||
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade.png")
|
||||
```
|
||||
|
||||
## Loading original checkpoints with `from_single_file`
|
||||
|
||||
Loading the original format checkpoints is supported via `from_single_file` method in the StableCascadeUNet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableCascadeUNet,
|
||||
)
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_unet = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
decoder_unet = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch.bfloat16)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch.bfloat16)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade-single-file.png")
|
||||
```
|
||||
|
||||
## Uses
|
||||
|
||||
### Direct Use
|
||||
@@ -53,7 +194,7 @@ Excluded uses are described below.
|
||||
|
||||
### Out-of-Scope Use
|
||||
|
||||
The model was not trained to be factual or true representations of people or events,
|
||||
The model was not trained to be factual or true representations of people or events,
|
||||
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
|
||||
|
||||
|
||||
@@ -172,3 +172,41 @@ inpaint = StableDiffusionInpaintPipeline(**text2img.components)
|
||||
|
||||
# now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
|
||||
```
|
||||
|
||||
### Create web demos using `gradio`
|
||||
|
||||
The Stable Diffusion pipelines are automatically supported in [Gradio](https://github.com/gradio-app/gradio/), a library that makes creating beautiful and user-friendly machine learning apps on the web a breeze. First, make sure you have Gradio installed:
|
||||
|
||||
```
|
||||
pip install -U gradio
|
||||
```
|
||||
|
||||
Then, create a web demo around any Stable Diffusion-based pipeline. For example, you can create an image generation pipeline in a single line of code with Gradio's [`Interface.from_pipeline`](https://www.gradio.app/docs/interface#interface-from-pipeline) function:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import gradio as gr
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
|
||||
gr.Interface.from_pipeline(pipe).launch()
|
||||
```
|
||||
|
||||
which opens an intuitive drag-and-drop interface in your browser:
|
||||
|
||||

|
||||
|
||||
Similarly, you could create a demo for an image-to-image pipeline with:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
import gradio as gr
|
||||
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
gr.Interface.from_pipeline(pipe).launch()
|
||||
```
|
||||
|
||||
By default, the web demo runs on a local server. If you'd like to share it with others, you can generate a temporary public
|
||||
link by setting `share=True` in `launch()`. Or, you can host your demo on [Hugging Face Spaces](https://huggingface.co/spaces)https://huggingface.co/spaces for a permanent link.
|
||||
@@ -0,0 +1,22 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# EDMEulerScheduler
|
||||
|
||||
The Karras formulation of the Euler scheduler (Algorithm 2) from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
|
||||
|
||||
|
||||
## EDMEulerScheduler
|
||||
[[autodoc]] EDMEulerScheduler
|
||||
|
||||
## EDMEulerSchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput
|
||||
@@ -0,0 +1,24 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# EDMDPMSolverMultistepScheduler
|
||||
|
||||
`EDMDPMSolverMultistepScheduler` is a [Karras formulation](https://huggingface.co/papers/2206.00364) of `DPMSolverMultistep`, a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
|
||||
|
||||
DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
|
||||
samples, and it can generate quite good samples even in 10 steps.
|
||||
|
||||
## EDMDPMSolverMultistepScheduler
|
||||
[[autodoc]] EDMDPMSolverMultistepScheduler
|
||||
|
||||
## SchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
|
||||
@@ -14,19 +14,17 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Load LoRAs for inference
|
||||
|
||||
There are many adapters (with LoRAs being the most common type) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) for inference.
|
||||
There are many adapter types (with [LoRAs](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) being the most popular) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images.
|
||||
|
||||
Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora).
|
||||
In this tutorial, you'll learn how to easily load and manage adapters for inference with the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers. You'll use LoRA as the main adapter technique, so you'll see the terms LoRA and adapter used interchangeably.
|
||||
|
||||
Let's first install all the required libraries.
|
||||
|
||||
```bash
|
||||
!pip install -q transformers accelerate
|
||||
!pip install peft
|
||||
!pip install diffusers
|
||||
!pip install -q transformers accelerate peft diffusers
|
||||
```
|
||||
|
||||
Now, let's load a pipeline with a SDXL checkpoint:
|
||||
Now, load a pipeline with a [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) checkpoint:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -36,16 +34,13 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
|
||||
Next, load a LoRA checkpoint with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method.
|
||||
|
||||
With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
```
|
||||
|
||||
And then perform inference:
|
||||
Make sure to include the token `toy_face` in the prompt and then you can perform inference:
|
||||
|
||||
```python
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
@@ -59,17 +54,16 @@ image
|
||||
|
||||

|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images, and let's call it `"pixel"`.
|
||||
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter. But you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method as shown below:
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.set_adapters("pixel")
|
||||
```
|
||||
|
||||
Let's now generate an image with the second adapter and check the result:
|
||||
Make sure you include the token `pixel art` in your prompt to generate a pixel art image:
|
||||
|
||||
```python
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
@@ -81,29 +75,25 @@ image
|
||||
|
||||

|
||||
|
||||
## Combine multiple adapters
|
||||
## Merge adapters
|
||||
|
||||
You can also perform multi-adapter inference where you combine different adapter checkpoints for inference.
|
||||
You can also merge different adapter checkpoints for inference to blend their styles together.
|
||||
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate two LoRA checkpoints and specify the weight for how the checkpoints should be combined.
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
|
||||
```python
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
```
|
||||
|
||||
Now that we have set these two adapters, let's generate an image from the combined adapters!
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts.
|
||||
|
||||
</Tip>
|
||||
|
||||
The trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) are found in their repositories.
|
||||
|
||||
Remember to use the trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) (these are found in their repositories) in the prompt to generate an image.
|
||||
|
||||
```python
|
||||
# Notice how the prompt is constructed.
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)
|
||||
@@ -113,15 +103,16 @@ image
|
||||
|
||||

|
||||
|
||||
Impressive! As you can see, the model was able to generate an image that mixes the characteristics of both adapters.
|
||||
Impressive! As you can see, the model generated an image that mixed the characteristics of both adapters.
|
||||
|
||||
If you want to go back to using only one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
> [!TIP]
|
||||
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
|
||||
|
||||
To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
|
||||
```python
|
||||
# First, set the adapter.
|
||||
pipe.set_adapters("toy")
|
||||
|
||||
# Then, run inference.
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
lora_scale= 0.9
|
||||
image = pipe(
|
||||
@@ -130,11 +121,7 @@ image = pipe(
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
If you want to switch to only the base model, disable all LoRAs with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method.
|
||||
|
||||
Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
|
||||
|
||||
```python
|
||||
pipe.disable_lora()
|
||||
@@ -145,11 +132,9 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
## Manage active adapters
|
||||
|
||||
## Monitoring active adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, you can easily check the list of active adapters using the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method:
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
|
||||
|
||||
```py
|
||||
active_adapters = pipe.get_active_adapters()
|
||||
@@ -164,78 +149,3 @@ list_adapters_component_wise = pipe.get_list_adapters()
|
||||
list_adapters_component_wise
|
||||
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
|
||||
```
|
||||
|
||||
## Compatibility with `torch.compile`
|
||||
|
||||
If you want to compile your model with `torch.compile` make sure to first fuse the LoRA weights into the base model and unload them.
|
||||
|
||||
```diff
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora()
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
+ pipe.unet.to(memory_format=torch.channels_last)
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can refer to the `torch.compile()` section [here](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0#torchcompile) and [here](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) for more elaborate examples.
|
||||
|
||||
## Fusing adapters into the model
|
||||
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
```
|
||||
|
||||
You can also fuse some adapters using `adapter_names` for faster generation:
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora(adapter_names=["pixel"])
|
||||
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
|
||||
# Fuse all adapters
|
||||
pipe.fuse_lora(adapter_names=["pixel", "toy"])
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
## Saving a pipeline after fusing the adapters
|
||||
|
||||
To properly save a pipeline after it's been loaded with the adapters, it should be serialized like so:
|
||||
|
||||
```python
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.save_pretrained("path-to-pipeline")
|
||||
```
|
||||
|
||||
@@ -12,13 +12,18 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Pipeline callbacks
|
||||
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. This can be really useful for *dynamically* adjusting certain pipeline attributes, or modifying tensor variables. The flexibility of callbacks opens up some interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale.
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
|
||||
This guide will show you how to use the `callback_on_step_end` parameter to disable classifier-free guidance (CFG) after 40% of the inference steps to save compute with minimal cost to performance.
|
||||
> [!TIP]
|
||||
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
|
||||
The callback function should have the following arguments:
|
||||
This guide will demonstrate how callbacks work by a few features you can implement with them.
|
||||
|
||||
* `pipe` (or the pipeline instance) provides access to useful properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipe._guidance_scale=0.0`.
|
||||
## Dynamic classifier-free guidance
|
||||
|
||||
Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:
|
||||
|
||||
* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
|
||||
* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
|
||||
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
|
||||
|
||||
@@ -27,12 +32,12 @@ Your callback function should look something like this:
|
||||
```python
|
||||
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
||||
# adjust the batch_size of prompt_embeds according to guidance_scale
|
||||
if step_index == int(pipe.num_timesteps * 0.4):
|
||||
if step_index == int(pipeline.num_timesteps * 0.4):
|
||||
prompt_embeds = callback_kwargs["prompt_embeds"]
|
||||
prompt_embeds = prompt_embeds.chunk(2)[-1]
|
||||
|
||||
# update guidance_scale and prompt_embeds
|
||||
pipe._guidance_scale = 0.0
|
||||
pipeline._guidance_scale = 0.0
|
||||
callback_kwargs["prompt_embeds"] = prompt_embeds
|
||||
return callback_kwargs
|
||||
```
|
||||
@@ -43,58 +48,134 @@ Now, you can pass the callback function to the `callback_on_step_end` parameter
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(1)
|
||||
out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
|
||||
out = pipeline(
|
||||
prompt,
|
||||
generator=generator,
|
||||
callback_on_step_end=callback_dynamic_cfg,
|
||||
callback_on_step_end_tensor_inputs=['prompt_embeds']
|
||||
)
|
||||
|
||||
out.images[0].save("out_custom_cfg.png")
|
||||
```
|
||||
|
||||
The callback function is executed at the end of each denoising step, and modifies the pipeline attributes and tensor variables for the next denoising step.
|
||||
|
||||
With callbacks, you can implement features such as dynamic CFG without having to modify the underlying code at all!
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
|
||||
</Tip>
|
||||
|
||||
## Interrupt the diffusion process
|
||||
|
||||
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
> [!TIP]
|
||||
> The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
<Tip>
|
||||
Stopping the diffusion process early is useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
|
||||
The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
</Tip>
|
||||
|
||||
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
This callback function should take the following arguments: `pipeline`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
|
||||
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
num_inference_steps = 50
|
||||
|
||||
def interrupt_callback(pipe, i, t, callback_kwargs):
|
||||
def interrupt_callback(pipeline, i, t, callback_kwargs):
|
||||
stop_idx = 10
|
||||
if i == stop_idx:
|
||||
pipe._interrupt = True
|
||||
pipeline._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
pipe(
|
||||
pipeline(
|
||||
"A photo of a cat",
|
||||
num_inference_steps=num_inference_steps,
|
||||
callback_on_step_end=interrupt_callback,
|
||||
)
|
||||
```
|
||||
|
||||
## Display image after each generation step
|
||||
|
||||
> [!TIP]
|
||||
> This tip was contributed by [asomoza](https://github.com/asomoza).
|
||||
|
||||
Display an image after each generation step by accessing and converting the latents after each step into an image. The latent space is compressed to 128x128, so the images are also 128x128 which is useful for a quick preview.
|
||||
|
||||
1. Use the function below to convert the SDXL latents (4 channels) to RGB tensors (3 channels) as explained in the [Explaining the SDXL latent space](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) blog post.
|
||||
|
||||
```py
|
||||
def latents_to_rgb(latents):
|
||||
weights = (
|
||||
(60, -60, 25, -70),
|
||||
(60, -5, 15, -50),
|
||||
(60, 10, -5, -35)
|
||||
)
|
||||
|
||||
weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
|
||||
biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
|
||||
rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
|
||||
image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
return Image.fromarray(image_array)
|
||||
```
|
||||
|
||||
2. Create a function to decode and save the latents into an image.
|
||||
|
||||
```py
|
||||
def decode_tensors(pipe, step, timestep, callback_kwargs):
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
image = latents_to_rgb(latents)
|
||||
image.save(f"{step}.png")
|
||||
|
||||
return callback_kwargs
|
||||
```
|
||||
|
||||
3. Pass the `decode_tensors` function to the `callback_on_step_end` parameter to decode the tensors after each step. You also need to specify what you want to modify in the `callback_on_step_end_tensor_inputs` parameter, which in this case are the latents.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
image = pipe(
|
||||
prompt = "A croissant shaped like a cute bear."
|
||||
negative_prompt = "Deformed, ugly, bad anatomy"
|
||||
callback_on_step_end=decode_tensors,
|
||||
callback_on_step_end_tensor_inputs=["latents"],
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4 justify-center">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 0</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_19.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 19
|
||||
</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_29.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 29</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_39.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 39</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_49.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 49</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -429,6 +429,27 @@ image = pipe(
|
||||
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
|
||||
See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
|
||||
Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
|
||||
|
||||
```py
|
||||
base = StableDiffusionXLControlNetPipeline(...)
|
||||
image = base(
|
||||
prompt=prompt,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
image=canny_image,
|
||||
num_inference_steps=40,
|
||||
denoising_end=0.8,
|
||||
output_type="latent",
|
||||
).images
|
||||
# rest exactly as with StableDiffusionXLPipeline
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
## MultiControlNet
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Trajectory Consistency Distillation-LoRA
|
||||
|
||||
Trajectory Consistency Distillation (TCD) enables a model to generate higher quality and more detailed images with fewer steps. Moreover, owing to the effective error mitigation during the distillation process, TCD demonstrates superior performance even under conditions of large inference steps.
|
||||
|
||||
The major advantages of TCD are:
|
||||
|
||||
- Better than Teacher: TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of [DPM-Solver++(2S)](../../api/schedulers/multistep_dpm_solver) with Stable Diffusion XL (SDXL). There is no additional discriminator or LPIPS supervision included during TCD training.
|
||||
|
||||
- Flexible Inference Steps: The inference steps for TCD sampling can be freely adjusted without adversely affecting the image quality.
|
||||
|
||||
- Freely change detail level: During inference, the level of detail in the image can be adjusted with a single hyperparameter, *gamma*.
|
||||
|
||||
> [!TIP]
|
||||
> For more technical details of TCD, please refer to the [paper](https://arxiv.org/abs/2402.19159) or official [project page](https://mhh0318.github.io/tcd/)).
|
||||
|
||||
For large models like SDXL, TCD is trained with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) to reduce memory usage. This is also useful because you can reuse LoRAs between different finetuned models, as long as they share the same base model, without further training.
|
||||
|
||||
|
||||
|
||||
This guide will show you how to perform inference with TCD-LoRAs for a variety of tasks like text-to-image and inpainting, as well as how you can easily combine TCD-LoRAs with other adapters. Choose one of the supported base model and it's corresponding TCD-LoRA checkpoint from the table below to get started.
|
||||
|
||||
| Base model | TCD-LoRA checkpoint |
|
||||
|-------------------------------------------------------------------------------------------------|----------------------------------------------------------------|
|
||||
| [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) | [TCD-SD15](https://huggingface.co/h1t/TCD-SD15-LoRA) |
|
||||
| [stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) | [TCD-SD21-base](https://huggingface.co/h1t/TCD-SD21-base-LoRA) |
|
||||
| [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | [TCD-SDXL](https://huggingface.co/h1t/TCD-SDXL-LoRA) |
|
||||
|
||||
|
||||
Make sure you have [PEFT](https://github.com/huggingface/peft) installed for better LoRA support.
|
||||
|
||||
```bash
|
||||
pip install -U peft
|
||||
```
|
||||
|
||||
## General tasks
|
||||
|
||||
In this guide, let's use the [`StableDiffusionXLPipeline`] and the [`TCDScheduler`]. Use the [`~StableDiffusionPipeline.load_lora_weights`] method to load the SDXL-compatible TCD-LoRA weights.
|
||||
|
||||
A few tips to keep in mind for TCD-LoRA inference are to:
|
||||
|
||||
- Keep the `num_inference_steps` between 4 and 50
|
||||
- Set `eta` (used to control stochasticity at each step) between 0 and 1. You should use a higher `eta` when increasing the number of inference steps, but the downside is that a larger `eta` in [`TCDScheduler`] leads to blurrier images. A value of 0.3 is recommended to produce good results.
|
||||
|
||||
<hfoptions id="tasks">
|
||||
<hfoption id="text-to-image">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="inpainting">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting, TCDScheduler
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
device = "cuda"
|
||||
base_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
pipe = AutoPipelineForInpainting.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).resize((1024, 1024))
|
||||
mask_image = load_image(mask_url).resize((1024, 1024))
|
||||
|
||||
prompt = "a tiger sitting on a park bench"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
strength=0.99, # make sure to use `strength` below 1.0
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
grid_image = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Community models
|
||||
|
||||
TCD-LoRA also works with many community finetuned models and plugins. For example, load the [animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0) checkpoint which is a community finetuned version of SDXL for generating anime images.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
base_model_id = "cagliostrolab/animagine-xl-3.0"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. Strands of vibrant, windswept hair peek out from beneath the brim of his cap."
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and fuse it with the TCD-LoRA with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.
|
||||
|
||||
> [!TIP]
|
||||
> Check out the [Merge LoRAs](merge_loras) guide to learn more about efficient merging methods.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
styled_lora_id = "TheLastBen/Papercut_SDXL"
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
|
||||
pipe.load_lora_weights(styled_lora_id, adapter_name="style")
|
||||
pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 1.0])
|
||||
|
||||
prompt = "papercut of a winter mountain, snow"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
## Adapters
|
||||
|
||||
TCD-LoRA is very versatile, and it can be combined with other adapter types like ControlNets, IP-Adapter, and AnimateDiff.
|
||||
|
||||
<hfoptions id="adapters">
|
||||
<hfoption id="ControlNet">
|
||||
|
||||
### Depth ControlNet
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
||||
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
|
||||
def get_depth_map(image):
|
||||
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
|
||||
with torch.no_grad(), torch.autocast(device):
|
||||
depth_map = depth_estimator(image).predicted_depth
|
||||
|
||||
depth_map = torch.nn.functional.interpolate(
|
||||
depth_map.unsqueeze(1),
|
||||
size=(1024, 1024),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
|
||||
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
|
||||
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
|
||||
image = torch.cat([depth_map] * 3, dim=1)
|
||||
|
||||
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
||||
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
||||
return image
|
||||
|
||||
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
controlnet_id = "diffusers/controlnet-depth-sdxl-1.0"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
controlnet_id,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
base_model_id,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "stormtrooper lecture, photorealistic"
|
||||
|
||||
image = load_image("https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png")
|
||||
depth_image = get_depth_map(image)
|
||||
|
||||
controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=depth_image,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
grid_image = make_image_grid([depth_image, image], rows=1, cols=2)
|
||||
```
|
||||
|
||||

|
||||
|
||||
### Canny ControlNet
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
controlnet_id = "diffusers/controlnet-canny-sdxl-1.0"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
controlnet_id,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
base_model_id,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "ultrarealistic shot of a furry blue bird"
|
||||
|
||||
canny_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png")
|
||||
|
||||
controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=canny_image,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
generator=torch.Generator(device=device).manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
|
||||
```
|
||||

|
||||
|
||||
<Tip>
|
||||
The inference parameters in this example might not work for all examples, so we recommend you to try different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale` and `cross_attention_kwargs` parameters and choose the best one.
|
||||
</Tip>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="IP-Adapter">
|
||||
|
||||
This example shows how to use the TCD-LoRA with the [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/tree/main) and SDXL.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
from ip_adapter import IPAdapterXL
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
image_encoder_path = "sdxl_models/image_encoder"
|
||||
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
|
||||
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
base_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
pipe.load_lora_weights(tcd_lora_id)
|
||||
pipe.fuse_lora()
|
||||
|
||||
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
|
||||
|
||||
ref_image = load_image("https://raw.githubusercontent.com/tencent-ailab/IP-Adapter/main/assets/images/woman.png").resize((512, 512))
|
||||
|
||||
prompt = "best quality, high quality, wearing sunglasses"
|
||||
|
||||
image = ip_model.generate(
|
||||
pil_image=ref_image,
|
||||
prompt=prompt,
|
||||
scale=0.5,
|
||||
num_samples=1,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=0,
|
||||
eta=0.3,
|
||||
seed=0,
|
||||
)[0]
|
||||
|
||||
grid_image = make_image_grid([ref_image, image], rows=1, cols=2)
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AnimateDiff">
|
||||
|
||||
[`AnimateDiff`] allows animating images using Stable Diffusion models. TCD-LoRA can substantially accelerate the process without degrading image quality. The quality of animation with TCD-LoRA and AnimateDiff has a more lucid outcome.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from scheduling_tcd import TCDScheduler
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5")
|
||||
pipe = AnimateDiffPipeline.from_pretrained(
|
||||
"frankjoshua/toonyou_beta6",
|
||||
motion_adapter=adapter,
|
||||
).to("cuda")
|
||||
|
||||
# set TCDScheduler
|
||||
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# load TCD LoRA
|
||||
pipe.load_lora_weights("h1t/TCD-SD15-LoRA", adapter_name="tcd")
|
||||
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-in", weight_name="diffusion_pytorch_model.safetensors", adapter_name="motion-lora")
|
||||
|
||||
pipe.set_adapters(["tcd", "motion-lora"], adapter_weights=[1.0, 1.2])
|
||||
|
||||
prompt = "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
|
||||
generator = torch.manual_seed(0)
|
||||
frames = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=5,
|
||||
guidance_scale=0,
|
||||
cross_attention_kwargs={"scale": 1},
|
||||
num_frames=24,
|
||||
eta=0.3,
|
||||
generator=generator
|
||||
).frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||

|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@@ -25,6 +25,9 @@ Let's take a look at how to use IP-Adapter's image prompting capabilities with t
|
||||
|
||||
In all the following examples, you'll see the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method. This method controls the amount of text or image conditioning to apply to the model. A value of `1.0` means the model is only conditioned on the image prompt. Lowering this value encourages the model to produce more diverse images, but they may not be as aligned with the image prompt. Typically, a value of `0.5` achieves a good balance between the two prompt types and produces good results.
|
||||
|
||||
> [!TIP]
|
||||
> In the examples below, try adding `low_cpu_mem_usage=True` to the [`~loaders.IPAdapterMixin.load_ip_adapter`] method to speed up the loading time.
|
||||
|
||||
<hfoptions id="tasks">
|
||||
<hfoption id="Text-to-image">
|
||||
|
||||
@@ -231,10 +234,21 @@ export_to_gif(frames, "gummy_bear.gif")
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!TIP]
|
||||
> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
|
||||
## Configure parameters
|
||||
|
||||
All the pipelines supporting IP-Adapter accept a `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to the disk.
|
||||
There are a couple of IP-Adapter parameters that are useful to know about and can help you with your image generation tasks. These parameters can make your workflow more efficient or give you more control over image generation.
|
||||
|
||||
### Image embeddings
|
||||
|
||||
IP-Adapter enabled pipelines provide the `ip_adapter_image_embeds` parameter to accept precomputed image embeddings. This is particularly useful in scenarios where you need to run the IP-Adapter pipeline multiple times because you have more than one image. For example, [multi IP-Adapter](#multi-ip-adapter) is a specific use case where you provide multiple styling images to generate a specific image in a specific style. Loading and encoding multiple images each time you use the pipeline would be inefficient. Instead, you can precompute and save the image embeddings to disk (which can save a lot of space if you're using high-quality images) and load them when you need them.
|
||||
|
||||
> [!TIP]
|
||||
> This parameter also gives you the flexibility to load embeddings from other sources. For example, ComfyUI image embeddings for IP-Adapters are compatible with Diffusers and should work ouf-of-the-box!
|
||||
|
||||
Call the [`~StableDiffusionPipeline.prepare_ip_adapter_image_embeds`] method to encode and generate the image embeddings. Then you can save them to disk with `torch.save`.
|
||||
|
||||
> [!TIP]
|
||||
> If you're using IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`', you can set `load_ip_adapter(image_encoder_folder=None,...)` because you don't need to load an encoder to generate the image embeddings.
|
||||
|
||||
```py
|
||||
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
|
||||
@@ -248,10 +262,7 @@ image_embeds = pipeline.prepare_ip_adapter_image_embeds(
|
||||
torch.save(image_embeds, "image_embeds.ipadpt")
|
||||
```
|
||||
|
||||
Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`
|
||||
|
||||
> [!TIP]
|
||||
> ComfyUI image embeddings for IP-Adapters are fully compatible in Diffusers and should work out-of-box.
|
||||
Now load the image embeddings by passing them to the `ip_adapter_image_embeds` parameter.
|
||||
|
||||
```py
|
||||
image_embeds = torch.load("image_embeds.ipadpt")
|
||||
@@ -264,8 +275,86 @@ images = pipeline(
|
||||
).images
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you use IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`, you can choose not to load an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
|
||||
### IP-Adapter masking
|
||||
|
||||
Binary masks specify which portion of the output image should be assigned to an IP-Adapter. This is useful for composing more than one IP-Adapter image. For each input IP-Adapter image, you must provide a binary mask an an IP-Adapter.
|
||||
|
||||
To start, preprocess the input IP-Adapter images with the [`~image_processor.IPAdapterMaskProcessor.preprocess()`] to generate their masks. For optimal results, provide the output height and width to [`~image_processor.IPAdapterMaskProcessor.preprocess()`]. This ensures masks with different aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, you don't have to set the `height` and `width`.
|
||||
|
||||
```py
|
||||
from diffusers.image_processor import IPAdapterMaskProcessor
|
||||
|
||||
mask1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask1.png")
|
||||
mask2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask2.png")
|
||||
|
||||
output_height = 1024
|
||||
output_width = 1024
|
||||
|
||||
processor = IPAdapterMaskProcessor()
|
||||
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
When there is more than one input IP-Adapter image, load them as a list to ensure each image is assigned to a different IP-Adapter. Each of the input IP-Adapter images here correspond to the masks generated above.
|
||||
|
||||
```py
|
||||
face_image1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
|
||||
face_image2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png")
|
||||
|
||||
ip_images = [[face_image1], [face_image2]]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Now pass the preprocessed masks to `cross_attention_kwargs` in the pipeline call.
|
||||
|
||||
```py
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
|
||||
pipeline.set_ip_adapter_scale([0.7] * 2)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
num_images = 1
|
||||
|
||||
image = pipeline(
|
||||
prompt="2 girls",
|
||||
ip_adapter_image=ip_images,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20,
|
||||
num_images_per_prompt=num_images,
|
||||
generator=generator,
|
||||
cross_attention_kwargs={"ip_adapter_masks": masks}
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter masking applied</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_no_attention_mask_result_seed_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">no IP-Adapter masking applied</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Specific use cases
|
||||
|
||||
@@ -279,6 +368,7 @@ Generating accurate faces is challenging because they are complex and nuanced. D
|
||||
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces
|
||||
|
||||
> [!TIP]
|
||||
>
|
||||
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
|
||||
|
||||
For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.
|
||||
@@ -502,82 +592,3 @@ image
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
|
||||
</div>
|
||||
|
||||
### IP-Adapter masking
|
||||
|
||||
Binary masks can be used to specify which portion of the output image should be assigned to an IP-Adapter.
|
||||
For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.
|
||||
|
||||
Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].
|
||||
|
||||
> [!TIP]
|
||||
> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, specifying height and width can be omitted.
|
||||
|
||||
Here an example with two masks:
|
||||
|
||||
```py
|
||||
from diffusers.image_processor import IPAdapterMaskProcessor
|
||||
|
||||
mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
|
||||
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
|
||||
|
||||
output_height = 1024
|
||||
output_width = 1024
|
||||
|
||||
processor = IPAdapterMaskProcessor()
|
||||
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
If you have more than one IP-Adapter image, load them into a list, ensuring each image is assigned to a different IP-Adapter.
|
||||
|
||||
```py
|
||||
face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
|
||||
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
|
||||
|
||||
ip_images = [[face_image1], [face_image2]]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Pass preprocessed masks to the pipeline using `cross_attention_kwargs` as shown below:
|
||||
|
||||
```py
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
|
||||
pipeline.set_ip_adapter_scale([0.7] * 2)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
num_images = 1
|
||||
|
||||
image = pipeline(
|
||||
prompt="2 girls",
|
||||
ip_adapter_image=ip_images,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20, num_images_per_prompt=num_images,
|
||||
generator=generator, cross_attention_kwargs={"ip_adapter_masks": masks}
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
|
||||
</div>
|
||||
|
||||
@@ -103,7 +103,7 @@ image
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA.
|
||||
LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA. It is also increasingly common to load and merge multiple LoRAs to create new and unique images. You can learn more about it in the in-depth [Merge LoRAs](merge_loras) guide since merging is outside the scope of this loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -165,101 +165,14 @@ To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weigh
|
||||
pipeline.unload_lora_weights()
|
||||
```
|
||||
|
||||
### Load multiple LoRAs
|
||||
|
||||
It can be fun to use multiple LoRAs together to create something entirely new and unique. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights with the original weights of the underlying model.
|
||||
|
||||
<Tip>
|
||||
|
||||
Fusing the weights can lead to a speedup in inference latency because you don't need to separately load the base model and LoRA! You can save your fused pipeline with [`~DiffusionPipeline.save_pretrained`] to avoid loading and fusing the weights every time you want to use the model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Load an initial model:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
|
||||
import torch
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
vae=vae,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Next, load the LoRA checkpoint and fuse it with the original weights. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won't work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
|
||||
If you need to reset the original model weights for any reason (use a different `lora_scale`), you should use the [`~loaders.LoraLoaderMixin.unfuse_lora`] method.
|
||||
|
||||
```py
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
|
||||
# to unfuse the LoRA weights
|
||||
pipeline.unfuse_lora()
|
||||
```
|
||||
|
||||
Then fuse this pipeline with the next set of LoRA weights:
|
||||
|
||||
```py
|
||||
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
You can't unfuse multiple LoRA checkpoints, so if you need to reset the model to its original weights, you'll need to reload it.
|
||||
|
||||
</Tip>
|
||||
|
||||
Now you can generate an image that uses the weights from both LoRAs:
|
||||
|
||||
```py
|
||||
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
|
||||
image = pipeline(prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### 🤗 PEFT
|
||||
|
||||
<Tip>
|
||||
|
||||
Read the [Inference with 🤗 PEFT](../tutorials/using_peft_for_inference) tutorial to learn more about its integration with 🤗 Diffusers and how you can easily work with and juggle multiple adapters. You'll need to install 🤗 Diffusers and PEFT from source to run the example in this section.
|
||||
|
||||
</Tip>
|
||||
|
||||
Another way you can load and use multiple LoRAs is to specify the `adapter_name` parameter in [`~loaders.LoraLoaderMixin.load_lora_weights`]. This method takes advantage of the 🤗 PEFT integration. For example, load and name both LoRA weights:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors", adapter_name="cereal")
|
||||
```
|
||||
|
||||
Now use the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] to activate both LoRAs, and you can configure how much weight each LoRA should have on the output:
|
||||
|
||||
```py
|
||||
pipeline.set_adapters(["ikea", "cereal"], adapter_weights=[0.7, 0.5])
|
||||
```
|
||||
|
||||
Then, generate an image:
|
||||
|
||||
```py
|
||||
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
|
||||
image = pipeline(prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Kohya and TheLastBen
|
||||
|
||||
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
|
||||
|
||||
Let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/):
|
||||
<hfoptions id="other-trainers">
|
||||
<hfoption id="Kohya">
|
||||
|
||||
To load a Kohya LoRA, let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/) as an example:
|
||||
|
||||
```sh
|
||||
!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors
|
||||
@@ -293,6 +206,9 @@ Some limitations of using Kohya LoRAs with 🤗 Diffusers include:
|
||||
|
||||
</Tip>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="TheLastBen">
|
||||
|
||||
Loading a checkpoint from TheLastBen is very similar. For example, to load the [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) checkpoint:
|
||||
|
||||
```py
|
||||
@@ -308,6 +224,9 @@ image = pipeline(prompt=prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[IP-Adapter](https://ip-adapter.github.io/) is a lightweight adapter that enables image prompting for any diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Merge LoRAs
|
||||
|
||||
It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
|
||||
|
||||
This guide will show you how to merge LoRAs using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.LoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
|
||||
|
||||
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style]() and [Norod78/sdxl-chalkboarddrawing-lora]() LoRAs with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
```
|
||||
|
||||
## set_adapters
|
||||
|
||||
The [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
|
||||
|
||||
```py
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
prompt = "A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
|
||||
image = pipeline(prompt, generator=generator, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_merge_set_adapters.png"/>
|
||||
</div>
|
||||
|
||||
## add_weighted_adapter
|
||||
|
||||
> [!WARNING]
|
||||
> This is an experimental method that adds PEFTs [`~peft.LoraModel.add_weighted_adapter`] method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
|
||||
|
||||
The [`~peft.LoraModel.add_weighted_adapter`] method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
|
||||
|
||||
```bash
|
||||
pip install -U diffusers peft
|
||||
```
|
||||
|
||||
There are three steps to merge LoRAs with the [`~peft.LoraModel.add_weighted_adapter`] method:
|
||||
|
||||
1. Create a [`~peft.PeftModel`] from the underlying model and LoRA checkpoint.
|
||||
2. Load a base UNet model and the LoRA adapters.
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice.
|
||||
|
||||
Let's dive deeper into what these steps entail.
|
||||
|
||||
1. Load a UNet that corresponds to the UNet in the LoRA checkpoint. In this case, both LoRAs use the SDXL UNet as their base model.
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
import torch
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
subfolder="unet",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Load the SDXL pipeline and the LoRA checkpoints, starting with the [ostris/ikea-instructions-lora-sdxl](https://huggingface.co/ostris/ikea-instructions-lora-sdxl) LoRA.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
variant="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
unet=unet
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
```
|
||||
|
||||
Now you'll create a [`~peft.PeftModel`] from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
|
||||
|
||||
```python
|
||||
from peft import get_peft_model, LoraConfig
|
||||
import copy
|
||||
|
||||
sdxl_unet = copy.deepcopy(unet)
|
||||
ikea_peft_model = get_peft_model(
|
||||
sdxl_unet,
|
||||
pipeline.unet.peft_config["ikea"],
|
||||
adapter_name="ikea"
|
||||
)
|
||||
|
||||
original_state_dict = {f"base_model.model.{k}": v for k, v in pipeline.unet.state_dict().items()}
|
||||
ikea_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
|
||||
|
||||
Repeat this process to create a [`~peft.PeftModel`] from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
|
||||
|
||||
```python
|
||||
pipeline.delete_adapters("ikea")
|
||||
sdxl_unet.delete_adapters("ikea")
|
||||
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
pipeline.set_adapters(adapter_names="feng")
|
||||
|
||||
feng_peft_model = get_peft_model(
|
||||
sdxl_unet,
|
||||
pipeline.unet.peft_config["feng"],
|
||||
adapter_name="feng"
|
||||
)
|
||||
|
||||
original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
|
||||
feng_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
```
|
||||
|
||||
2. Load a base UNet model and then load the adapters onto it.
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
base_unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
subfolder="unet",
|
||||
).to("cuda")
|
||||
|
||||
model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_safetensors=True, subfolder="ikea", adapter_name="ikea")
|
||||
model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
|
||||
```
|
||||
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
|
||||
|
||||
> [!WARNING]
|
||||
> Keep in mind the LoRAs need to have the same rank to be merged!
|
||||
|
||||
```python
|
||||
model.add_weighted_adapter(
|
||||
adapters=["ikea", "feng"],
|
||||
weights=[1.0, 1.0],
|
||||
combination_type="dare_linear",
|
||||
adapter_name="ikea-feng"
|
||||
)
|
||||
model.set_adapters("ikea-feng")
|
||||
```
|
||||
|
||||
Now you can generate an image with the merged LoRA.
|
||||
|
||||
```python
|
||||
model = model.to(dtype=torch.float16, device="cuda")
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", unet=model, variant="fp16", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ikea-feng-dare-linear.png"/>
|
||||
</div>
|
||||
|
||||
## fuse_lora
|
||||
|
||||
Both the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
|
||||
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
|
||||
For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
```
|
||||
|
||||
Fuse these LoRAs into the UNet with the [`~loaders.LoraLoaderMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won’t work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
|
||||
```py
|
||||
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
|
||||
```
|
||||
|
||||
Then you should use [`~loaders.LoraLoaderMixin.unload_lora_weights`] to unload the LoRA weights since they've already been fused with the underlying base model. Finally, call [`~DiffusionPipeline.save_pretrained`] to save the fused pipeline locally or you could call [`~DiffusionPipeline.push_to_hub`] to push the fused pipeline to the Hub.
|
||||
|
||||
```py
|
||||
pipeline.unload_lora_weights()
|
||||
# save locally
|
||||
pipeline.save_pretrained("path/to/fused-pipeline")
|
||||
# save to the Hub
|
||||
pipeline.push_to_hub("fused-ikea-feng")
|
||||
```
|
||||
|
||||
Now you can quickly load the fused pipeline and use it for inference without needing to separately load the LoRA adapters.
|
||||
|
||||
```py
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"username/fused-ikea-feng", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can call [`~loaders.LoraLoaderMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
|
||||
|
||||
```py
|
||||
pipeline.unfuse_lora()
|
||||
```
|
||||
|
||||
### torch.compile
|
||||
|
||||
[torch.compile](../optimization/torch2.0#torchcompile) can speed up your pipeline even more, but the LoRA weights must be fused first and then unloaded. Typically, the UNet is compiled because it is such a computationally intensive component of the pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# load base model and LoRAs
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
|
||||
# activate both LoRAs and set adapter weights
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
|
||||
# fuse LoRAs and unload weights
|
||||
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
# torch.compile
|
||||
pipeline.unet.to(memory_format=torch.channels_last)
|
||||
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
Learn more about torch.compile in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion#torchcompile) guide.
|
||||
|
||||
## Next steps
|
||||
|
||||
For more conceptual details about how each merging method works, take a look at the [🤗 PEFT welcomes new merging methods](https://huggingface.co/blog/peft_merging#concatenation-cat) blog post!
|
||||
@@ -273,7 +273,6 @@ Lastly, convert the image to a `PIL.Image` to see your generated image!
|
||||
```py
|
||||
>>> image = (image / 2 + 0.5).clamp(0, 1).squeeze()
|
||||
>>> image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
|
||||
>>> image = (image * 255).round().astype("uint8")
|
||||
>>> image = Image.fromarray(image)
|
||||
>>> image
|
||||
```
|
||||
|
||||
@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
|
||||
**Inference**
|
||||
The inference is the same as if you train a regular LoRA 🤗
|
||||
|
||||
## Conducting EDM-style training
|
||||
|
||||
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
simply set:
|
||||
|
||||
```diff
|
||||
+ --do_edm_style_training \
|
||||
```
|
||||
|
||||
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
|
||||
--dataset_name="linoyts/3d_icon" \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir="3d-icon-SDXL-LoRA" \
|
||||
--do_edm_style_training \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=3 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
> [!CAUTION]
|
||||
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
|
||||
|
||||
### Tips and Tricks
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1215,7 +1215,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1366,14 +1366,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1407,11 +1407,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -37,7 +39,7 @@ import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
@@ -55,6 +57,8 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EDMEulerScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -74,11 +78,25 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def determine_scheduler_type(pretrained_model_name_or_path, revision):
|
||||
model_index_filename = "model_index.json"
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
|
||||
else:
|
||||
model_index = hf_hub_download(
|
||||
repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
|
||||
)
|
||||
|
||||
with open(model_index, "r") as f:
|
||||
scheduler_type = json.load(f)["scheduler"][1]
|
||||
return scheduler_type
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_edm_style_training",
|
||||
action="store_true",
|
||||
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1117,6 +1140,8 @@ def main(args):
|
||||
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
||||
" Please use `huggingface-cli login` to authenticate with the Hub."
|
||||
)
|
||||
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
@@ -1234,7 +1259,19 @@ def main(args):
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
|
||||
if "EDM" in scheduler_type:
|
||||
args.do_edm_style_training = True
|
||||
noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
logger.info("Performing EDM-style training!")
|
||||
elif args.do_edm_style_training:
|
||||
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
logger.info("Performing EDM-style training!")
|
||||
else:
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
@@ -1252,7 +1289,12 @@ def main(args):
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
@@ -1317,7 +1359,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1522,14 +1564,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1563,11 +1605,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
@@ -1790,6 +1832,19 @@ def main(args):
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
||||
# TODO: revisit other sampling algorithms
|
||||
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
|
||||
timesteps = timesteps.to(accelerator.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
if args.train_text_encoder:
|
||||
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
|
||||
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
|
||||
@@ -1841,9 +1896,15 @@ def main(args):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
model_input = model_input * vae_scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
if latents_mean is None and latents_std is None:
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
else:
|
||||
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
@@ -1854,15 +1915,32 @@ def main(args):
|
||||
)
|
||||
|
||||
bsz = model_input.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
if not args.do_edm_style_training:
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
# in EDM formulation, the model is conditioned on the pre-conditioned noise levels
|
||||
# instead of discrete timesteps, so here we sample indices to get the noise levels
|
||||
# from `scheduler.timesteps`
|
||||
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
|
||||
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
if args.do_edm_style_training:
|
||||
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
|
||||
if "EDM" in scheduler_type:
|
||||
inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
|
||||
else:
|
||||
inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
@@ -1888,7 +1966,7 @@ def main(args):
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
@@ -1906,14 +1984,42 @@ def main(args):
|
||||
)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
|
||||
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
|
||||
weighting = None
|
||||
if args.do_edm_style_training:
|
||||
# Similar to the input preconditioning, the model predictions are also preconditioned
|
||||
# on noised model inputs (before preconditioning) and the sigmas.
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
if "EDM" in scheduler_type:
|
||||
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
|
||||
else:
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
|
||||
noisy_model_input / (sigmas**2 + 1)
|
||||
)
|
||||
# We are not doing weighting here because it tends result in numerical problems.
|
||||
# See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
# There might be other alternatives for weighting as well:
|
||||
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
|
||||
if "EDM" not in scheduler_type:
|
||||
weighting = (sigmas**-2.0).float()
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
target = model_input if args.do_edm_style_training else noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
target = (
|
||||
model_input
|
||||
if args.do_edm_style_training
|
||||
else noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
@@ -1923,10 +2029,28 @@ def main(args):
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
if weighting is not None:
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
)
|
||||
prior_loss = prior_loss.mean()
|
||||
else:
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
if weighting is not None:
|
||||
loss = torch.mean(
|
||||
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
|
||||
target.shape[0], -1
|
||||
),
|
||||
1,
|
||||
)
|
||||
loss = loss.mean()
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
@@ -2049,17 +2173,18 @@ def main(args):
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, **scheduler_args
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, **scheduler_args
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
@@ -2067,8 +2192,13 @@ def main(args):
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext()
|
||||
if "playground" in args.pretrained_model_name_or_path
|
||||
else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
with inference_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
@@ -2144,15 +2274,18 @@ def main(args):
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, **scheduler_args
|
||||
)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
@@ -105,7 +105,7 @@ pipeline_output = pipe(
|
||||
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
|
||||
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
|
||||
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral".
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
|
||||
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
|
||||
)
|
||||
|
||||
@@ -3414,15 +3414,13 @@ pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_step
|
||||
|
||||
### Rerender A Video
|
||||
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`:
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `gmflow_dir`. After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
import sys
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
```
|
||||
sys.path.insert(0, gmflow_dir)
|
||||
|
||||
After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoencoderKL, DDIMScheduler
|
||||
from diffusers.utils import export_to_video
|
||||
import numpy as np
|
||||
|
||||
@@ -513,9 +513,7 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
|
||||
@@ -418,9 +418,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
|
||||
@@ -40,7 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
@@ -50,14 +50,14 @@ class MarigoldDepthOutput(BaseOutput):
|
||||
Args:
|
||||
depth_np (`np.ndarray`):
|
||||
Predicted depth map, with depth values in the range of [0, 1].
|
||||
depth_colored (`PIL.Image.Image`):
|
||||
depth_colored (`None` or `PIL.Image.Image`):
|
||||
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
||||
uncertainty (`None` or `np.ndarray`):
|
||||
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
|
||||
"""
|
||||
|
||||
depth_np: np.ndarray
|
||||
depth_colored: Image.Image
|
||||
depth_colored: Union[None, Image.Image]
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
@@ -139,14 +139,15 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`):
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
||||
Colormap used to colorize the depth map.
|
||||
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Arguments for detailed ensembling settings.
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
|
||||
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
|
||||
values in [0, 1]. None if `color_map` is `None`
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
@@ -233,12 +234,15 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
depth_pred = depth_pred.clip(0, 1)
|
||||
|
||||
# Colorize
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
if color_map is not None:
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
else:
|
||||
depth_colored_img = None
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -27,6 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionL
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unets.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.schedulers import (
|
||||
@@ -37,7 +37,7 @@ from diffusers.schedulers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
@@ -91,10 +91,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -103,14 +101,18 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffControlNetPipelineOutput(BaseOutput):
|
||||
frames: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class AnimateDiffControlNetPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin
|
||||
):
|
||||
@@ -843,8 +845,8 @@ class AnimateDiffControlNetPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
|
||||
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
@@ -1020,7 +1022,7 @@ class AnimateDiffControlNetPipeline(
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# Denoising loop
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -1096,21 +1098,17 @@ class AnimateDiffControlNetPipeline(
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
return AnimateDiffControlNetPipelineOutput(frames=latents)
|
||||
|
||||
# Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AnimateDiffControlNetPipelineOutput(frames=video)
|
||||
return AnimateDiffPipelineOutput(frames=video)
|
||||
|
||||
@@ -158,10 +158,8 @@ def slerp(
|
||||
return v2
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -826,8 +833,8 @@ class AnimateDiffImgToVideoPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
|
||||
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
@@ -958,11 +965,10 @@ class AnimateDiffImgToVideoPipeline(
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
# 10. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 11. Offload all models
|
||||
|
||||
@@ -452,7 +452,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -21,6 +20,7 @@ import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from gmflow.gmflow import GMFlow
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
@@ -34,13 +34,6 @@ from diffusers.utils import BaseOutput, deprecate, logging
|
||||
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
sys.path.insert(0, gmflow_dir)
|
||||
from gmflow.gmflow import GMFlow # noqa: E402
|
||||
|
||||
from utils.utils import InputPadder # noqa: E402
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -119,11 +112,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False):
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):
|
||||
if image3 is None:
|
||||
image3 = image1
|
||||
padder = InputPadder(image1.shape, padding_factor=8)
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))
|
||||
results_dict = flow_model(
|
||||
image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True
|
||||
)
|
||||
@@ -307,6 +300,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder=None,
|
||||
requires_safety_checker: bool = True,
|
||||
device=None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@@ -320,6 +314,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
image_encoder,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.to(device)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
@@ -374,7 +369,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
attention_type="swin",
|
||||
ffn_dim_expansion=4,
|
||||
num_transformer_layers=6,
|
||||
).to("cuda")
|
||||
).to(self.device)
|
||||
|
||||
checkpoint = torch.utils.model_zoo.load_url(
|
||||
"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
|
||||
@@ -928,13 +923,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
|
||||
|
||||
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
|
||||
self.flow_model, first_image, image[0], first_result, False
|
||||
self.flow_model, first_image, image[0], first_result, False, self.device
|
||||
)
|
||||
blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
|
||||
|
||||
warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False, self.device
|
||||
)
|
||||
blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
|
||||
@@ -1176,3 +1171,24 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
return output_frames
|
||||
|
||||
return TextToVideoSDPipelineOutput(frames=output_frames)
|
||||
|
||||
|
||||
class InputPadder:
|
||||
"""Pads images such that dimensions are divisible by 8"""
|
||||
|
||||
def __init__(self, dims, mode="sintel", padding_factor=8):
|
||||
self.ht, self.wd = dims[-2:]
|
||||
pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
|
||||
pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
|
||||
if mode == "sintel":
|
||||
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
|
||||
else:
|
||||
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
|
||||
|
||||
def pad(self, *inputs):
|
||||
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
|
||||
|
||||
def unpad(self, x):
|
||||
ht, wd = x.shape[-2:]
|
||||
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
|
||||
return x[..., c[0] : c[1], c[2] : c[3]]
|
||||
|
||||
@@ -171,9 +171,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -308,7 +308,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1068,7 +1068,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -180,7 +180,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
logger_name = "test" if is_final_validation else "validation"
|
||||
tracker.log({logger_name: formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -928,7 +928,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -325,7 +325,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1083,7 +1083,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -285,7 +285,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1023,7 +1023,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -77,7 +77,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -303,7 +303,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1083,7 +1083,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -178,7 +178,7 @@ def log_validation(
|
||||
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -861,7 +861,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,7 +128,7 @@ def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args
|
||||
|
||||
wandb.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {args.report_to}")
|
||||
logger.warning(f"image logging not implemented for {args.report_to}")
|
||||
|
||||
return image_logs
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -178,7 +178,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -929,7 +929,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -904,7 +904,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
attention_class = CustomDiffusionXFormersAttnProcessor
|
||||
@@ -1178,7 +1178,7 @@ def main(args):
|
||||
grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
|
||||
for i in range(len(modifier_token_id[1:])):
|
||||
for i in range(1, len(modifier_token_id)):
|
||||
index_grads_to_zero = index_grads_to_zero & (
|
||||
torch.arange(len(tokenizer)) != modifier_token_id[i]
|
||||
)
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -987,7 +987,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -895,7 +895,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -114,7 +114,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -139,7 +139,7 @@ Weights for this model are available in Safetensors format.
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
"""
|
||||
if "playgroundai" in args.pretrained_model_name_or_path:
|
||||
if "playground" in base_model:
|
||||
model_description += """\n
|
||||
## License
|
||||
|
||||
@@ -148,7 +148,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
|
||||
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
@@ -162,7 +162,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
"lora" if not use_dora else "dora",
|
||||
"template:sd-lora",
|
||||
]
|
||||
if "playgroundai" in base_model:
|
||||
if "playground" in base_model:
|
||||
tags.extend(["playground", "playground-diffusers"])
|
||||
else:
|
||||
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
|
||||
@@ -206,7 +206,7 @@ def log_validation(
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with inference_ctx:
|
||||
@@ -1141,7 +1141,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1317,14 +1317,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1358,11 +1358,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
@@ -1509,7 +1509,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
tracker_name = (
|
||||
"dreambooth-lora-sd-xl"
|
||||
if "playgroundai" not in args.pretrained_model_name_or_path
|
||||
if "playground" not in args.pretrained_model_name_or_path
|
||||
else "dreambooth-lora-playground"
|
||||
)
|
||||
accelerator.init_trackers(tracker_name, config=vars(args))
|
||||
|
||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -488,7 +488,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -580,7 +580,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -177,7 +177,7 @@ def log_validation(vae, image_encoder, image_processor, unet, args, accelerator,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -534,7 +534,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -180,7 +180,7 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -219,7 +219,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
|
||||
if args.num_classes is not None:
|
||||
class_labels = list(range(args.num_classes))
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"The model is class-conditional but the number of classes is not set. The generated images will be"
|
||||
" unconditional rather than class-conditional."
|
||||
)
|
||||
@@ -266,7 +266,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -863,14 +863,14 @@ def main(args):
|
||||
elif args.model_config_name_or_path is None:
|
||||
# TODO: use default architectures from iCT paper
|
||||
if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
|
||||
f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
|
||||
)
|
||||
args.num_classes = None
|
||||
args.class_embed_type = None
|
||||
elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
|
||||
"`class_conditional` will be overridden to `False`."
|
||||
)
|
||||
@@ -996,7 +996,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -407,7 +407,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1057,7 +1057,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -574,7 +574,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -672,7 +672,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -516,7 +516,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -608,7 +608,7 @@ def main():
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
|
||||
@@ -541,7 +541,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -645,7 +645,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -901,7 +901,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
|
||||
@@ -108,7 +108,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -523,7 +523,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -687,7 +687,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -916,7 +916,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
+2
-2
@@ -410,7 +410,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
@@ -637,7 +637,7 @@ def main(args):
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_inference_steps=args.ddpm_num_inference_steps,
|
||||
output_type="numpy",
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
if args.use_ema:
|
||||
|
||||
@@ -629,7 +629,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -167,7 +167,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -932,7 +932,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -183,7 +183,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -608,7 +608,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -497,7 +497,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -64,7 +64,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug_loss",
|
||||
action="store_true",
|
||||
help="debug loss for each image, if filenames are awailable in the dataset",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -603,6 +608,7 @@ def main(args):
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
# The VAE is in float32 to avoid NaN losses.
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
vae.to(accelerator.device, dtype=torch.float32)
|
||||
else:
|
||||
@@ -616,7 +622,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -890,13 +896,17 @@ def main(args):
|
||||
tokens_one, tokens_two = tokenize_captions(examples)
|
||||
examples["input_ids_one"] = tokens_one
|
||||
examples["input_ids_two"] = tokens_two
|
||||
if args.debug_loss:
|
||||
fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
|
||||
if fnames:
|
||||
examples["filenames"] = fnames
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
@@ -905,7 +915,7 @@ def main(args):
|
||||
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
||||
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
||||
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
||||
return {
|
||||
result = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids_one": input_ids_one,
|
||||
"input_ids_two": input_ids_two,
|
||||
@@ -913,6 +923,11 @@ def main(args):
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
|
||||
filenames = [example["filenames"] for example in examples if "filenames" in example]
|
||||
if filenames:
|
||||
result["filenames"] = filenames
|
||||
return result
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
@@ -1105,7 +1120,9 @@ def main(args):
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
if args.debug_loss and "filenames" in batch:
|
||||
for fname in batch["filenames"]:
|
||||
accelerator.log({"loss_for_" + fname: loss}, step=global_step)
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -712,7 +712,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -80,7 +80,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -708,7 +708,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -966,7 +966,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -711,7 +711,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -1022,7 +1022,7 @@ def main():
|
||||
)
|
||||
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -408,7 +408,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
@@ -648,7 +648,7 @@ def main(args):
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_inference_steps=args.ddpm_num_inference_steps,
|
||||
output_type="numpy",
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
if args.use_ema:
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -184,7 +184,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -182,7 +182,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
+163
-160
@@ -1,7 +1,7 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
@@ -18,23 +18,56 @@ from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
|
||||
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
|
||||
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
|
||||
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
|
||||
parser.add_argument(
|
||||
"--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-decoder",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combined_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-combined",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument("--save_combined", action="store_true")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.skip_stage_b and args.skip_stage_c:
|
||||
raise ValueError("At least one stage should be converted")
|
||||
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
|
||||
raise ValueError("Cannot skip stages when creating a combined pipeline")
|
||||
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
if args.variant == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
@@ -52,164 +85,134 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b1
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
load_model_dict_into_meta(prior_model, state_dict)
|
||||
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
load_model_dict_into_meta(decoder, state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
if not args.skip_stage_c:
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
|
||||
if args.use_safetensors:
|
||||
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
|
||||
|
||||
with ctx():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(prior_model, prior_state_dict)
|
||||
else:
|
||||
prior_model.load_state_dict(prior_state_dict)
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.to(dtype).save_pretrained(
|
||||
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if not args.skip_stage_b:
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
|
||||
with ctx():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(decoder, decoder_state_dict)
|
||||
else:
|
||||
decoder.load_state_dict(decoder_state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.to(dtype).save_pretrained(
|
||||
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if args.save_combined:
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.to(dtype).save_pretrained(
|
||||
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPConfig,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
DDPMWuerstchenScheduler,
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
|
||||
parser.add_argument(
|
||||
"--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
|
||||
)
|
||||
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
|
||||
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument(
|
||||
"--prior_output_path",
|
||||
default="stable-cascade-prior-lite",
|
||||
type=str,
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-decoder-lite",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combined_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-combined-lite",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument("--save_combined", action="store_true")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.skip_stage_b and args.skip_stage_c:
|
||||
raise ValueError("At least one stage should be converted")
|
||||
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
|
||||
raise ValueError("Cannot skip stages when creating a combined pipeline")
|
||||
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
if args.variant == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
|
||||
|
||||
# Clip Text encoder and tokenizer
|
||||
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
config.text_config.projection_dim = config.projection_dim
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
# image processor
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
if not args.skip_stage_c:
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
|
||||
with ctx():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=1536,
|
||||
block_out_channels=[1536, 1536],
|
||||
num_attention_heads=[24, 24],
|
||||
down_num_layers_per_block=[4, 12],
|
||||
up_num_layers_per_block=[12, 4],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(prior_model, prior_state_dict)
|
||||
else:
|
||||
prior_model.load_state_dict(prior_state_dict)
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.to(dtype).save_pretrained(
|
||||
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if not args.skip_stage_b:
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
|
||||
|
||||
with ctx():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 576, 1152, 1152],
|
||||
down_num_layers_per_block=[2, 4, 14, 4],
|
||||
up_num_layers_per_block=[4, 14, 4, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[2, 2, 2, 2],
|
||||
num_attention_heads=[0, 9, 18, 18],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(decoder, decoder_state_dict)
|
||||
else:
|
||||
decoder.load_state_dict(decoder_state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.to(dtype).save_pretrained(
|
||||
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if args.save_combined:
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.to(dtype).save_pretrained(
|
||||
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
@@ -0,0 +1,139 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--slack_channel_name", default="diffusers-ci-nightly")
|
||||
|
||||
|
||||
def main(slack_channel_name=None):
|
||||
failed = []
|
||||
passed = []
|
||||
|
||||
group_info = []
|
||||
|
||||
total_num_failed = 0
|
||||
empty_file = False or len(list(Path().glob("*.log"))) == 0
|
||||
|
||||
total_empty_files = []
|
||||
|
||||
for log in Path().glob("*.log"):
|
||||
section_num_failed = 0
|
||||
i = 0
|
||||
with open(log) as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
i += 1
|
||||
if line.get("nodeid", "") != "":
|
||||
test = line["nodeid"]
|
||||
if line.get("duration", None) is not None:
|
||||
duration = f'{line["duration"]:.4f}'
|
||||
if line.get("outcome", "") == "failed":
|
||||
section_num_failed += 1
|
||||
failed.append([test, duration, log.name.split("_")[0]])
|
||||
total_num_failed += 1
|
||||
else:
|
||||
passed.append([test, duration, log.name.split("_")[0]])
|
||||
empty_file = i == 0
|
||||
group_info.append([str(log), section_num_failed, failed])
|
||||
total_empty_files.append(empty_file)
|
||||
os.remove(log)
|
||||
failed = []
|
||||
text = (
|
||||
"🌞 There were no failures!"
|
||||
if not any(total_empty_files)
|
||||
else "Something went wrong there is at least one empty file - please check GH action results."
|
||||
)
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": text,
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
message = ""
|
||||
payload = [
|
||||
{
|
||||
"type": "header",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🤗 Results of the Diffusers scheduled nightly tests.",
|
||||
},
|
||||
},
|
||||
]
|
||||
if total_num_failed > 0:
|
||||
for i, (name, num_failed, failed_tests) in enumerate(group_info):
|
||||
if num_failed > 0:
|
||||
if num_failed == 1:
|
||||
message += f"*{name}: {num_failed} failed test*\n"
|
||||
else:
|
||||
message += f"*{name}: {num_failed} failed tests*\n"
|
||||
failed_table = []
|
||||
for test in failed_tests:
|
||||
failed_table.append(test[0].split("::"))
|
||||
failed_table = tabulate(
|
||||
failed_table,
|
||||
headers=["Test Location", "Test Case", "Test Name"],
|
||||
showindex="always",
|
||||
tablefmt="grid",
|
||||
maxcolwidths=[12, 12, 12],
|
||||
)
|
||||
message += "\n```\n" + failed_table + "\n```"
|
||||
|
||||
if total_empty_files[i]:
|
||||
message += f"\n*{name}: Warning! Empty file - please check the GitHub action job *\n"
|
||||
print(f"### {message}")
|
||||
else:
|
||||
payload.append(no_error_payload)
|
||||
|
||||
if len(message) > MAX_LEN_MESSAGE:
|
||||
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
|
||||
message = message[:MAX_LEN_MESSAGE] + "..."
|
||||
|
||||
if len(message) != 0:
|
||||
md_report = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": message},
|
||||
}
|
||||
payload.append(md_report)
|
||||
action_button = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "*For more details:*"},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
payload.append(action_button)
|
||||
|
||||
date_report = {
|
||||
"type": "context",
|
||||
"elements": [
|
||||
{
|
||||
"type": "plain_text",
|
||||
"text": f"Nightly test results for {date.today()}",
|
||||
},
|
||||
],
|
||||
}
|
||||
payload.append(date_report)
|
||||
|
||||
print(payload)
|
||||
|
||||
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
|
||||
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args.slack_channel_name)
|
||||
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.27.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.28.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.27.0.dev0"
|
||||
__version__ = "0.28.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -253,6 +253,8 @@ else:
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"MusicLDMPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"PIAPipeline",
|
||||
@@ -623,6 +625,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
MusicLDMPipeline,
|
||||
PaintByExamplePipeline,
|
||||
PIAPipeline,
|
||||
|
||||
@@ -215,7 +215,7 @@ class IPAdapterMixin:
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
|
||||
@@ -430,7 +430,7 @@ class LoraLoaderMixin:
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
if not USE_PEFT_BACKEND:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
logger.warning(warn_message)
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
@@ -882,7 +882,7 @@ class LoraLoaderMixin:
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
if self.num_fused_loras > 1:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
|
||||
)
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ def build_sub_model_components(
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
@@ -64,6 +66,7 @@ def build_sub_model_components(
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
@@ -189,6 +192,30 @@ class FromSingleFileMixin:
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
model_type (`str`, *optional*):
|
||||
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
|
||||
image_size (`int`, *optional*):
|
||||
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
|
||||
num_in_channels (`int`, *optional*):
|
||||
Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
|
||||
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
|
||||
scaling_factor (`float`, *optional*):
|
||||
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
|
||||
If the scaling factor is not found in the config file, the default value 0.18215 is used.
|
||||
scheduler_type (`str`, *optional*):
|
||||
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
|
||||
prediction_type (`str`, *optional*):
|
||||
The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
@@ -276,7 +303,9 @@ class FromSingleFileMixin:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
additional_components = set_additional_components(
|
||||
class_name, original_config, checkpoint=checkpoint, model_type=model_type
|
||||
)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
|
||||
@@ -81,6 +81,87 @@ SCHEDULER_DEFAULT_CONFIG = {
|
||||
"timestep_spacing": "leading",
|
||||
}
|
||||
|
||||
|
||||
STABLE_CASCADE_DEFAULT_CONFIGS = {
|
||||
"stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
|
||||
"stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
|
||||
"stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
|
||||
"stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
|
||||
}
|
||||
|
||||
|
||||
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
|
||||
is_stage_c = "clip_txt_mapper.weight" in original_state_dict
|
||||
|
||||
if is_stage_c:
|
||||
state_dict = {}
|
||||
for key in original_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = original_state_dict[key]
|
||||
else:
|
||||
state_dict = {}
|
||||
for key in original_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = original_state_dict[key]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def infer_stable_cascade_single_file_config(checkpoint):
|
||||
is_stage_c = "clip_txt_mapper.weight" in checkpoint
|
||||
is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
|
||||
|
||||
if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
|
||||
config_type = "stage_c_lite"
|
||||
elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
|
||||
config_type = "stage_c"
|
||||
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
|
||||
config_type = "stage_b_lite"
|
||||
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
|
||||
config_type = "stage_b"
|
||||
|
||||
return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
|
||||
|
||||
|
||||
DIFFUSERS_TO_LDM_MAPPING = {
|
||||
"unet": {
|
||||
"layers": {
|
||||
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
|
||||
cache_dir=None,
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
):
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
|
||||
|
||||
return original_config, checkpoint
|
||||
|
||||
|
||||
def load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=False,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
token=None,
|
||||
cache_dir=None,
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
):
|
||||
if os.path.isfile(pretrained_model_link_or_path):
|
||||
checkpoint = load_state_dict(pretrained_model_link_or_path)
|
||||
|
||||
else:
|
||||
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
||||
checkpoint_path = _get_model_file(
|
||||
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
|
||||
|
||||
return original_config, checkpoint
|
||||
return checkpoint
|
||||
|
||||
|
||||
def infer_original_config_file(class_name, checkpoint):
|
||||
@@ -307,7 +410,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
||||
return original_config
|
||||
|
||||
|
||||
def infer_model_type(original_config, checkpoint=None, model_type=None):
|
||||
def infer_model_type(original_config, checkpoint, model_type=None):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
@@ -462,8 +565,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"down_block_types": down_block_types,
|
||||
"block_out_channels": block_out_channels,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": head_dim,
|
||||
@@ -482,7 +585,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
config["num_class_embeds"] = unet_params["num_classes"]
|
||||
|
||||
config["out_channels"] = unet_params["out_channels"]
|
||||
config["up_block_types"] = tuple(up_block_types)
|
||||
config["up_block_types"] = up_block_types
|
||||
|
||||
return config
|
||||
|
||||
@@ -530,9 +633,9 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"down_block_types": down_block_types,
|
||||
"up_block_types": up_block_types,
|
||||
"block_out_channels": block_out_channels,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": scaling_factor,
|
||||
@@ -884,7 +987,7 @@ def create_diffusers_controlnet_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1060,7 +1163,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1155,7 +1258,7 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
@@ -1176,7 +1279,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
num_in_channels=None,
|
||||
upcast_attention=False,
|
||||
upcast_attention=None,
|
||||
extract_ema=False,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
@@ -1204,7 +1307,8 @@ def create_diffusers_unet_model_from_ldm(
|
||||
)
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["in_channels"] = num_in_channels
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
if upcast_attention is not None:
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -1221,7 +1325,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1283,7 +1387,7 @@ def create_diffusers_vae_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -42,6 +42,11 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .single_file_utils import (
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
infer_stable_cascade_single_file_config,
|
||||
load_single_file_model_checkpoint,
|
||||
)
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -345,7 +350,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
|
||||
if not USE_PEFT_BACKEND:
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
@@ -384,7 +389,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
||||
if is_text_encoder_present:
|
||||
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
||||
logger.warn(warn_message)
|
||||
logger.warning(warn_message)
|
||||
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
@@ -896,3 +901,103 @@ class UNet2DConditionLoadersMixin:
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
|
||||
class FromOriginalUNetMixin:
|
||||
"""
|
||||
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config: (`dict`, *optional*):
|
||||
Dictionary containing the configuration of the model:
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables of the model.
|
||||
|
||||
"""
|
||||
class_name = cls.__name__
|
||||
if class_name != "StableCascadeUNet":
|
||||
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = infer_stable_cascade_single_file_config(checkpoint)
|
||||
model_config = cls.load_config(**config, **kwargs)
|
||||
else:
|
||||
model_config = config
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(model_config, **kwargs)
|
||||
|
||||
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
return model
|
||||
|
||||
@@ -17,8 +17,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleLinear
|
||||
from ..utils import deprecate
|
||||
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
@@ -87,9 +86,7 @@ class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
@@ -97,9 +94,12 @@ class GEGLU(nn.Module):
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
||||
def forward(self, hidden_states, *args, **kwargs):
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
|
||||
|
||||
|
||||
@@ -17,18 +17,18 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .lora import LoRACompatibleLinear
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
|
||||
|
||||
def _chunked_feed_forward(
|
||||
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
@@ -36,18 +36,10 @@ def _chunked_feed_forward(
|
||||
)
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
if lora_scale is None:
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
else:
|
||||
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
return ff_output
|
||||
|
||||
|
||||
@@ -143,7 +135,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
@@ -299,6 +291,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
@@ -326,10 +322,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 2. Prepare GLIGEN inputs
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
@@ -348,7 +341,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
@@ -394,11 +387,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(
|
||||
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
|
||||
)
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
@@ -643,7 +634,7 @@ class FeedForward(nn.Module):
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
linear_cls = nn.Linear
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
@@ -665,11 +656,10 @@ class FeedForward(nn.Module):
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
for module in self.net:
|
||||
if isinstance(module, compatible_cls):
|
||||
hidden_states = module(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = module(hidden_states)
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -20,10 +20,10 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..image_processor import IPAdapterMaskProcessor
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate, logging
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .lora import LoRACompatibleLinear, LoRALinearLayer
|
||||
from .lora import LoRALinearLayer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -181,10 +181,7 @@ class Attention(nn.Module):
|
||||
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
linear_cls = nn.Linear
|
||||
else:
|
||||
linear_cls = LoRACompatibleLinear
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
||||
@@ -741,11 +738,14 @@ class AttnProcessor:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -764,15 +764,26 @@ class AttnProcessor:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
# encoder_hidden_states = hidden_states
|
||||
batch, seq, dim = hidden_states.shape
|
||||
height = width = seq**0.5
|
||||
# reshape to (batch, height, width, dim)
|
||||
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
|
||||
# reshape to (batch, dim, height, width)
|
||||
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
|
||||
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
|
||||
# reshape to (batch, dim, seq)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
|
||||
# reshape to (batch, seq, dim)
|
||||
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
@@ -783,7 +794,7 @@ class AttnProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -914,11 +925,14 @@ class AttnAddedKVProcessor:
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
@@ -932,17 +946,17 @@ class AttnAddedKVProcessor:
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
if not attn.only_cross_attention:
|
||||
key = attn.to_k(hidden_states, *args)
|
||||
value = attn.to_v(hidden_states, *args)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
@@ -956,7 +970,7 @@ class AttnAddedKVProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -984,11 +998,14 @@ class AttnAddedKVProcessor2_0:
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
@@ -1002,7 +1019,7 @@ class AttnAddedKVProcessor2_0:
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query, out_dim=4)
|
||||
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
@@ -1011,8 +1028,8 @@ class AttnAddedKVProcessor2_0:
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
||||
|
||||
if not attn.only_cross_attention:
|
||||
key = attn.to_k(hidden_states, *args)
|
||||
value = attn.to_v(hidden_states, *args)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
key = attn.head_to_batch_dim(key, out_dim=4)
|
||||
value = attn.head_to_batch_dim(value, out_dim=4)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
@@ -1029,7 +1046,7 @@ class AttnAddedKVProcessor2_0:
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1132,11 +1149,14 @@ class XFormersAttnProcessor:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -1165,15 +1185,15 @@ class XFormersAttnProcessor:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
@@ -1186,7 +1206,7 @@ class XFormersAttnProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1217,8 +1237,13 @@ class AttnProcessor2_0:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -1242,16 +1267,26 @@ class AttnProcessor2_0:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
# encoder_hidden_states = hidden_states
|
||||
batch, seq, dim = hidden_states.shape
|
||||
height = width = seq**0.5
|
||||
# reshape to (batch, height, width, dim)
|
||||
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
|
||||
# reshape to (batch, dim, height, width)
|
||||
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
|
||||
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
|
||||
# reshape to (batch, dim, seq)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
|
||||
# reshape to (batch, seq, dim)
|
||||
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -1271,7 +1306,7 @@ class AttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1312,8 +1347,13 @@ class FusedAttnProcessor2_0:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -1337,17 +1377,16 @@ class FusedAttnProcessor2_0:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
if encoder_hidden_states is None:
|
||||
qkv = attn.to_qkv(hidden_states, *args)
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
else:
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
kv = attn.to_kv(encoder_hidden_states, *args)
|
||||
kv = attn.to_kv(encoder_hidden_states)
|
||||
split_size = kv.shape[-1] // 2
|
||||
key, value = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
@@ -1368,7 +1407,7 @@ class FusedAttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1859,7 +1898,7 @@ class LoRAAttnProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -1877,7 +1916,7 @@ class LoRAAttnProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnProcessor()
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0(nn.Module):
|
||||
@@ -1920,7 +1959,7 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -1938,7 +1977,7 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnProcessor2_0()
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
|
||||
|
||||
class LoRAXFormersAttnProcessor(nn.Module):
|
||||
@@ -1999,7 +2038,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -2017,7 +2056,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = XFormersAttnProcessor()
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
|
||||
|
||||
class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
@@ -2058,7 +2097,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -2076,7 +2115,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnAddedKVProcessor()
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
|
||||
|
||||
class IPAdapterAttnProcessor(nn.Module):
|
||||
|
||||
@@ -18,8 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from ..utils import deprecate
|
||||
from .normalization import RMSNorm
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
@@ -103,7 +102,7 @@ class Downsample2D(nn.Module):
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
conv_cls = nn.Conv2d
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
@@ -131,7 +130,10 @@ class Downsample2D(nn.Module):
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
@@ -143,13 +145,7 @@ class Downsample2D(nn.Module):
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -18,10 +18,9 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate
|
||||
from ..utils import deprecate
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -200,7 +199,7 @@ class TimestepEmbedding(nn.Module):
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
|
||||
|
||||
|
||||
@@ -204,6 +204,9 @@ class LoRALinearLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
@@ -264,6 +267,9 @@ class LoRAConv2dLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
|
||||
|
||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||
|
||||
@@ -124,9 +124,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
||||
f"at '{checkpoint_file}'. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
|
||||
)
|
||||
|
||||
|
||||
@@ -679,7 +677,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
@@ -707,7 +705,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# the weights so we don't have to do this again.
|
||||
|
||||
if "'Attention' object has no attribute" in str(e):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
||||
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
||||
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from ..utils import deprecate
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .downsampling import ( # noqa
|
||||
@@ -30,7 +30,6 @@ from .downsampling import ( # noqa
|
||||
KDownsample2D,
|
||||
downsample_2d,
|
||||
)
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .normalization import AdaGroupNorm
|
||||
from .upsampling import ( # noqa
|
||||
FirUpsample2D,
|
||||
@@ -102,7 +101,7 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
self.output_scale_factor = output_scale_factor
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
conv_cls = nn.Conv2d
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
@@ -149,12 +148,11 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
bias=conv_shortcut_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states, temb)
|
||||
@@ -166,26 +164,24 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
if hidden_states.shape[0] >= 64:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = self.upsample(input_tensor, scale=scale)
|
||||
hidden_states = self.upsample(hidden_states, scale=scale)
|
||||
input_tensor = self.upsample(input_tensor)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
elif self.downsample is not None:
|
||||
input_tensor = self.downsample(input_tensor, scale=scale)
|
||||
hidden_states = self.downsample(hidden_states, scale=scale)
|
||||
input_tensor = self.downsample(input_tensor)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states, temb)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (
|
||||
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
|
||||
)
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
@@ -267,8 +263,8 @@ class ResnetBlock2D(nn.Module):
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.skip_time_act = skip_time_act
|
||||
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv2d
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
@@ -326,12 +322,11 @@ class ResnetBlock2D(nn.Module):
|
||||
bias=conv_shortcut_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
@@ -342,38 +337,18 @@ class ResnetBlock2D(nn.Module):
|
||||
if hidden_states.shape[0] >= 64:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = (
|
||||
self.upsample(input_tensor, scale=scale)
|
||||
if isinstance(self.upsample, Upsample2D)
|
||||
else self.upsample(input_tensor)
|
||||
)
|
||||
hidden_states = (
|
||||
self.upsample(hidden_states, scale=scale)
|
||||
if isinstance(self.upsample, Upsample2D)
|
||||
else self.upsample(hidden_states)
|
||||
)
|
||||
input_tensor = self.upsample(input_tensor)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
elif self.downsample is not None:
|
||||
input_tensor = (
|
||||
self.downsample(input_tensor, scale=scale)
|
||||
if isinstance(self.downsample, Downsample2D)
|
||||
else self.downsample(input_tensor)
|
||||
)
|
||||
hidden_states = (
|
||||
self.downsample(hidden_states, scale=scale)
|
||||
if isinstance(self.downsample, Downsample2D)
|
||||
else self.downsample(hidden_states)
|
||||
)
|
||||
input_tensor = self.downsample(input_tensor)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
if not self.skip_time_act:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = (
|
||||
self.time_emb_proj(temb, scale)[:, :, None, None]
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.time_emb_proj(temb)[:, :, None, None]
|
||||
)
|
||||
temb = self.time_emb_proj(temb)[:, :, None, None]
|
||||
|
||||
if self.time_embedding_norm == "default":
|
||||
if temb is not None:
|
||||
@@ -393,12 +368,10 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (
|
||||
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
|
||||
)
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -19,14 +19,16 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from ...utils import BaseOutput, deprecate, is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
@@ -92,7 +94,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
@@ -100,13 +102,23 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
interpolation_scale: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
if patch_size is not None:
|
||||
if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
|
||||
raise NotImplementedError(
|
||||
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
|
||||
)
|
||||
elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
||||
)
|
||||
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
conv_cls = nn.Conv2d
|
||||
linear_cls = nn.Linear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
@@ -294,6 +306,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -317,9 +332,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
@@ -327,21 +339,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
@@ -404,17 +408,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
@@ -69,7 +69,7 @@ def get_down_block(
|
||||
):
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = num_attention_heads
|
||||
@@ -354,7 +354,7 @@ def get_up_block(
|
||||
) -> nn.Module:
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = num_attention_heads
|
||||
@@ -673,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
||||
)
|
||||
attention_head_dim = in_channels
|
||||
@@ -844,8 +844,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -882,7 +885,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -982,7 +985,8 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
@@ -995,7 +999,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
||||
mask = attention_mask
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
hidden_states = attn(
|
||||
@@ -1006,7 +1010,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
)
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1035,7 +1039,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
self.downsample_type = downsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1111,23 +1115,22 @@ class AttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
cross_attention_kwargs.update({"scale": lora_scale})
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
if self.downsample_type == "resnet":
|
||||
hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
|
||||
hidden_states = downsampler(hidden_states, temb=temb)
|
||||
else:
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -1236,9 +1239,11 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
output_states = ()
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
@@ -1270,7 +1275,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1288,7 +1293,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1348,8 +1353,12 @@ class DownBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -1370,13 +1379,13 @@ class DownBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1447,13 +1456,17 @@ class DownEncoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale)
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1480,7 +1493,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1545,15 +1558,18 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale)
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1579,7 +1595,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1644,18 +1660,22 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
|
||||
hidden_states = self.resnet_down(hidden_states, temb)
|
||||
for downsampler in self.downsamplers:
|
||||
skip_sample = downsampler(skip_sample)
|
||||
|
||||
@@ -1731,16 +1751,21 @@ class SkipDownBlock2D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
hidden_states = self.resnet_down(hidden_states, temb, scale)
|
||||
hidden_states = self.resnet_down(hidden_states, temb)
|
||||
for downsampler in self.downsamplers:
|
||||
skip_sample = downsampler(skip_sample)
|
||||
|
||||
@@ -1816,8 +1841,12 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -1838,13 +1867,13 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb, scale)
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1955,10 +1984,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
output_states = ()
|
||||
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
@@ -1991,7 +2021,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -2004,7 +2034,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -2058,8 +2088,12 @@ class KDownBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -2080,7 +2114,7 @@ class KDownBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -2165,8 +2199,11 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2196,7 +2233,7 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2244,7 +2281,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
self.upsample_type = upsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2316,24 +2353,28 @@ class AttnUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
if self.upsample_type == "resnet":
|
||||
hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
|
||||
hidden_states = upsampler(hidden_states, temb=temb)
|
||||
else:
|
||||
hidden_states = upsampler(hidden_states, scale=scale)
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2440,7 +2481,10 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -2494,7 +2538,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2506,7 +2550,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2567,8 +2611,13 @@ class UpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -2612,11 +2661,11 @@ class UpBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2683,11 +2732,9 @@ class UpDecoderBlock2D(nn.Module):
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -2719,7 +2766,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2783,17 +2830,14 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, scale=scale)
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2841,7 +2885,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2898,18 +2942,22 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample=None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
|
||||
hidden_states = self.attentions[0](hidden_states)
|
||||
|
||||
if skip_sample is not None:
|
||||
skip_sample = self.upsampler(skip_sample)
|
||||
@@ -2923,7 +2971,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
|
||||
skip_sample = skip_sample + skip_sample_states
|
||||
|
||||
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
@@ -3006,15 +3054,20 @@ class SkipUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample=None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if skip_sample is not None:
|
||||
skip_sample = self.upsampler(skip_sample)
|
||||
@@ -3028,7 +3081,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
|
||||
skip_sample = skip_sample + skip_sample_states
|
||||
|
||||
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
@@ -3108,8 +3161,13 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -3133,11 +3191,11 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb, scale=scale)
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -3253,8 +3311,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
||||
@@ -3292,7 +3351,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -3303,7 +3362,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -3364,8 +3423,13 @@ class KUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[-1]
|
||||
if res_hidden_states_tuple is not None:
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
|
||||
@@ -3388,7 +3452,7 @@ class KUpBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -3498,7 +3562,6 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
if res_hidden_states_tuple is not None:
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -3527,7 +3590,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -3630,6 +3693,8 @@ class KAttentionBlock(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# 1. Self-Attention
|
||||
if self.add_self_attention:
|
||||
|
||||
@@ -80,7 +80,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
@@ -109,7 +109,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
||||
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
@@ -147,9 +147,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of `cond_proj` layer in the timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
||||
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
||||
*optional*): The dimension of the `class_labels` input when
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
||||
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
||||
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
||||
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||
embeddings with the class embeddings.
|
||||
@@ -1226,7 +1226,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
@@ -1297,7 +1297,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
|
||||
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The tuple of downsample blocks to use.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
"DownBlock2D",
|
||||
)
|
||||
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
self.down_blocks = down_blocks
|
||||
|
||||
# mid
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif self.config.mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
|
||||
|
||||
# up
|
||||
up_blocks = []
|
||||
@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
|
||||
if mid_block_additional_residual is not None:
|
||||
sample += mid_block_additional_residual
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user